diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d201ce3..f62d3311 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added + +- Add ability to use Grafana Service Account Tokens for OnCall API (This is only enabled for resolution_notes +endpoint currently) @mderynck ([#3189](https://github.com/grafana/oncall/pull/3189)) +- Add ability for webhook presets to mask sensitive headers @mderynck +([#3189](https://github.com/grafana/oncall/pull/3189)) + ### Fixed +- Fixed issue that blocked saving webhooks with presets if the preset is controlling the URL @mderynck +([#3189](https://github.com/grafana/oncall/pull/3189)) - User filter doesn't display current value on Alert Groups page ([1714](https://github.com/grafana/oncall/issues/1714)) - Remove displaying rotation modal for Terraform/API based schedules - Filters polishing ([3183](https://github.com/grafana/oncall/issues/3183)) diff --git a/engine/apps/api/serializers/webhook.py b/engine/apps/api/serializers/webhook.py index 2ee9e02e..dfa53a86 100644 --- a/engine/apps/api/serializers/webhook.py +++ b/engine/apps/api/serializers/webhook.py @@ -182,7 +182,9 @@ class WebhookSerializer(LabelsSerializerMixin, serializers.ModelSerializer): for controlled_field in preset_metadata.controlled_fields: if controlled_field in self.initial_data: if self.instance: - if self.initial_data[controlled_field] != getattr(self.instance, controlled_field): + if self.initial_data[controlled_field] is not None and self.initial_data[ + controlled_field + ] != getattr(self.instance, controlled_field): raise serializers.ValidationError( detail=f"{controlled_field} is controlled by preset, cannot update" ) diff --git a/engine/apps/auth_token/auth.py b/engine/apps/auth_token/auth.py index aee403f5..8982fe21 100644 --- a/engine/apps/auth_token/auth.py +++ b/engine/apps/auth_token/auth.py @@ -8,14 +8,16 @@ from rest_framework import exceptions from rest_framework.authentication import BaseAuthentication, get_authorization_header from rest_framework.request import Request -from apps.api.permissions import RBACPermission, user_is_authorized +from apps.api.permissions import GrafanaAPIPermission, LegacyAccessControlRole, RBACPermission, user_is_authorized from apps.grafana_plugin.helpers.gcom import check_token from apps.user_management.exceptions import OrganizationDeletedException, OrganizationMovedException from apps.user_management.models import User from apps.user_management.models.organization import Organization +from settings.base import SELF_HOSTED_SETTINGS from .constants import SCHEDULE_EXPORT_TOKEN_NAME, SLACK_AUTH_TOKEN_NAME from .exceptions import InvalidToken +from .grafana.grafana_auth_token import get_service_account_token_permissions from .models import ApiAuthToken, PluginAuthToken, ScheduleExportAuthToken, SlackAuthToken, UserScheduleExportAuthToken logger = logging.getLogger(__name__) @@ -262,3 +264,71 @@ class UserScheduleExportAuthentication(BaseAuthentication): raise exceptions.AuthenticationFailed("Export token is deactivated") return auth_token.user, auth_token + + +X_GRAFANA_ORG_SLUG = "X-Grafana-Org-Slug" +X_GRAFANA_INSTANCE_SLUG = "X-Grafana-Instance-Slug" +GRAFANA_SA_PREFIX = "glsa_" + + +class GrafanaServiceAccountAuthentication(BaseAuthentication): + def authenticate(self, request): + auth = get_authorization_header(request).decode("utf-8") + if not auth: + raise exceptions.AuthenticationFailed("Invalid token.") + if not auth.startswith(GRAFANA_SA_PREFIX): + return None + + organization = self.get_organization(request) + if not organization: + raise exceptions.AuthenticationFailed("Invalid organization.") + if organization.is_moved: + raise OrganizationMovedException(organization) + if organization.deleted_at: + raise OrganizationDeletedException(organization) + + return self.authenticate_credentials(organization, auth) + + def get_organization(self, request): + org_slug = SELF_HOSTED_SETTINGS["ORG_SLUG"] + instance_slug = SELF_HOSTED_SETTINGS["STACK_SLUG"] + if settings.LICENSE == settings.CLOUD_LICENSE_NAME: + org_slug = request.headers.get(X_GRAFANA_ORG_SLUG) + if not org_slug: + raise exceptions.AuthenticationFailed(f"Missing {X_GRAFANA_ORG_SLUG}") + instance_slug = request.headers.get(X_GRAFANA_INSTANCE_SLUG) + if not instance_slug: + raise exceptions.AuthenticationFailed(f"Missing {X_GRAFANA_INSTANCE_SLUG}") + + return Organization.objects.filter(org_slug=org_slug, stack_slug=instance_slug).first() + + def authenticate_credentials(self, organization, token): + permissions = get_service_account_token_permissions(organization, token) + if not permissions: + raise exceptions.AuthenticationFailed("Invalid token.") + + role = LegacyAccessControlRole.NONE + if not organization.is_rbac_permissions_enabled: + role = self.determine_role_from_permissions(permissions) + + user = User( + organization_id=organization.pk, + name="Grafana Service Account", + username="grafana_service_account", + role=role, + permissions=[GrafanaAPIPermission(action=key) for key, _ in permissions.items()], + ) + + auth_token = ApiAuthToken(organization=organization, user=user, name="Grafana Service Account") + + return user, auth_token + + # Using default permissions as proxies for roles since we cannot explicitly get role from the service account token + def determine_role_from_permissions(self, permissions): + if "plugins:write" in permissions: + return LegacyAccessControlRole.ADMIN + if "dashboards:write" in permissions: + return LegacyAccessControlRole.EDITOR + if "dashboards:read" in permissions: + return LegacyAccessControlRole.VIEWER + return LegacyAccessControlRole.NONE diff --git a/engine/apps/auth_token/exceptions.py b/engine/apps/auth_token/exceptions.py index 0ea79b0c..ddbdff94 100644 --- a/engine/apps/auth_token/exceptions.py +++ b/engine/apps/auth_token/exceptions.py @@ -1,2 +1,6 @@ class InvalidToken(Exception): pass + + +class ServiceAccountDoesNotExist(Exception): + pass diff --git a/engine/apps/auth_token/grafana/__init__.py b/engine/apps/auth_token/grafana/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/engine/apps/auth_token/grafana/grafana_auth_token.py b/engine/apps/auth_token/grafana/grafana_auth_token.py new file mode 100644 index 00000000..07bae644 --- /dev/null +++ b/engine/apps/auth_token/grafana/grafana_auth_token.py @@ -0,0 +1,48 @@ +import typing + +from apps.auth_token.exceptions import ServiceAccountDoesNotExist +from apps.grafana_plugin.helpers import GrafanaAPIClient +from apps.user_management.models import Organization + +SA_ONCALL_API_NAME = "sa-autogen-OnCall" + + +def find_service_account( + organization: Organization, service_account_name=SA_ONCALL_API_NAME +) -> typing.Optional["GrafanaAPIClient.Types.GrafanaServiceAccount"]: + grafana_api_client = GrafanaAPIClient(api_url=organization.grafana_url, api_token=organization.api_token) + response, _ = grafana_api_client.get_service_account(service_account_name) + if response and "serviceAccounts" in response and response["serviceAccounts"]: + return response["serviceAccounts"][0] + return None + + +def create_service_account( + organization: Organization, name: str, role: str +) -> GrafanaAPIClient.Types.GrafanaServiceAccount: + grafana_api_client = GrafanaAPIClient(api_url=organization.grafana_url, api_token=organization.api_token) + response, _ = grafana_api_client.create_service_account(name, role) + return response + + +def create_service_account_token( + organization: Organization, + token_name: str, + seconds_to_live=int | None, + service_account_name=SA_ONCALL_API_NAME, +) -> typing.Optional[str]: + grafana_api_client = GrafanaAPIClient(api_url=organization.grafana_url, api_token=organization.api_token) + service_account = find_service_account(organization, service_account_name) + if not service_account: + raise ServiceAccountDoesNotExist + + response, _ = grafana_api_client.create_service_account_token(service_account["id"], token_name, seconds_to_live) + if response: + return response["key"] + return None + + +def get_service_account_token_permissions(organization: Organization, token: str) -> typing.Dict[str, typing.List[str]]: + grafana_api_client = GrafanaAPIClient(api_url=organization.grafana_url, api_token=token) + permissions, _ = grafana_api_client.get_service_account_token_permissions() + return permissions diff --git a/engine/apps/auth_token/tests/test_grafana_auth.py b/engine/apps/auth_token/tests/test_grafana_auth.py new file mode 100644 index 00000000..b1fdc800 --- /dev/null +++ b/engine/apps/auth_token/tests/test_grafana_auth.py @@ -0,0 +1,80 @@ +import typing +from unittest.mock import patch + +import pytest +from rest_framework import exceptions +from rest_framework.test import APIRequestFactory + +from apps.auth_token.auth import ( + GRAFANA_SA_PREFIX, + X_GRAFANA_INSTANCE_SLUG, + X_GRAFANA_ORG_SLUG, + GrafanaServiceAccountAuthentication, +) +from settings.base import CLOUD_LICENSE_NAME, OPEN_SOURCE_LICENSE_NAME, SELF_HOSTED_SETTINGS + + +def fake_authenticate_credentials(organization, token): + pass + + +@pytest.mark.django_db +def test_grafana_authentication_oss_inputs(make_organization, settings): + settings.LICENSE = OPEN_SOURCE_LICENSE_NAME + + headers, token = check_common_inputs() + organization = make_organization( + stack_slug=SELF_HOSTED_SETTINGS["STACK_SLUG"], org_slug=SELF_HOSTED_SETTINGS["ORG_SLUG"] + ) + request = APIRequestFactory().get("/", **headers) + with patch( + "apps.auth_token.auth.GrafanaServiceAccountAuthentication.authenticate_credentials", + wraps=fake_authenticate_credentials, + ) as mock: + GrafanaServiceAccountAuthentication().authenticate(request) + mock.assert_called_once_with(organization, token) + + +@pytest.mark.django_db +def test_grafana_authentication_cloud_inputs(make_organization, settings): + settings.LICENSE = CLOUD_LICENSE_NAME + headers, token = check_common_inputs() + + test_org_slug = "test_org_123" + test_stack_slug = "test_stack_123" + headers[f"HTTP_{X_GRAFANA_ORG_SLUG}"] = test_org_slug + headers[f"HTTP_{X_GRAFANA_INSTANCE_SLUG}"] = test_stack_slug + request = APIRequestFactory().get("/", **headers) + with pytest.raises(exceptions.AuthenticationFailed): + GrafanaServiceAccountAuthentication().authenticate(request) + + organization = make_organization(stack_slug=test_stack_slug, org_slug=test_org_slug) + with patch( + "apps.auth_token.auth.GrafanaServiceAccountAuthentication.authenticate_credentials", + wraps=fake_authenticate_credentials, + ) as mock: + GrafanaServiceAccountAuthentication().authenticate(request) + mock.assert_called_once_with(organization, token) + + +def check_common_inputs() -> (dict[str, typing.Any], str): + request = APIRequestFactory().get("/") + with pytest.raises(exceptions.AuthenticationFailed): + GrafanaServiceAccountAuthentication().authenticate(request) + + headers = { + "HTTP_AUTHORIZATION": "xyz", + } + request = APIRequestFactory().get("/", **headers) + result = GrafanaServiceAccountAuthentication().authenticate(request) + assert result is None + + token = f"{GRAFANA_SA_PREFIX}xyz" + headers = { + "HTTP_AUTHORIZATION": token, + } + request = APIRequestFactory().get("/", **headers) + with pytest.raises(exceptions.AuthenticationFailed): + GrafanaServiceAccountAuthentication().authenticate(request) + + return headers, token diff --git a/engine/apps/grafana_plugin/helpers/client.py b/engine/apps/grafana_plugin/helpers/client.py index 6597dd54..d9e7faaa 100644 --- a/engine/apps/grafana_plugin/helpers/client.py +++ b/engine/apps/grafana_plugin/helpers/client.py @@ -176,9 +176,27 @@ class GrafanaAPIClient(APIClient): avatarUrl: str memberCount: int + class GrafanaServiceAccount(typing.TypedDict): + id: int + name: str + login: str + orgId: int + isDisabled: bool + role: str + tokens: int + avatarUrl: str + + class GrafanaServiceAccountToken(typing.TypedDict): + id: int + name: str + key: str + class TeamsResponse(_BaseGrafanaAPIResponse): teams: typing.List["GrafanaAPIClient.Types.GrafanaTeam"] + class ServiceAccountResponse(_BaseGrafanaAPIResponse): + serviceAccounts: typing.List["GrafanaAPIClient.Types.GrafanaServiceAccount"] + def __init__(self, api_url: str, api_token: str) -> None: super().__init__(api_url, api_token) @@ -274,6 +292,25 @@ class GrafanaAPIClient(APIClient): def get_grafana_plugin_settings(self, recipient: str) -> APIClientResponse: return self.api_get(f"api/plugins/{recipient}/settings") + def get_service_account(self, login: str) -> APIClientResponse["GrafanaAPIClient.Types.ServiceAccountResponse"]: + return self.api_get(f"api/serviceaccounts/search?query={login}") + + def create_service_account( + self, name: str, role: str + ) -> APIClientResponse["GrafanaAPIClient.Types.GrafanaServiceAccount"]: + return self.api_post("api/serviceaccounts", {"name": name, "role": role}) + + def create_service_account_token( + self, service_account_id: int, name: str, seconds_to_live=int | None + ) -> APIClientResponse["GrafanaAPIClient.Types.GrafanaServiceAccountToken"]: + token_config = {"name": name} + if seconds_to_live: + token_config["secondsToLive"] = seconds_to_live + return self.api_post(f"api/serviceaccounts/{service_account_id}/tokens", token_config) + + def get_service_account_token_permissions(self) -> APIClientResponse[typing.Dict[str, typing.List[str]]]: + return self.api_get("api/access-control/user/permissions") + class GcomAPIClient(APIClient): ACTIVE_INSTANCE_QUERY = "instances?status=active" diff --git a/engine/apps/public_api/serializers/resolution_notes.py b/engine/apps/public_api/serializers/resolution_notes.py index 6cf7d7c9..b48f3fa6 100644 --- a/engine/apps/public_api/serializers/resolution_notes.py +++ b/engine/apps/public_api/serializers/resolution_notes.py @@ -34,7 +34,8 @@ class ResolutionNoteSerializer(EagerLoadingMixin, serializers.ModelSerializer): SELECT_RELATED = ["alert_group", "resolution_note_slack_message", "author"] def create(self, validated_data): - validated_data["author"] = self.context["request"].user + if self.context["request"].user.pk: + validated_data["author"] = self.context["request"].user validated_data["source"] = ResolutionNote.Source.WEB return super().create(validated_data) diff --git a/engine/apps/public_api/tests/test_resolution_notes.py b/engine/apps/public_api/tests/test_resolution_notes.py index 7c5e3b94..2a44e622 100644 --- a/engine/apps/public_api/tests/test_resolution_notes.py +++ b/engine/apps/public_api/tests/test_resolution_notes.py @@ -1,9 +1,13 @@ +from unittest.mock import patch + import pytest from django.urls import reverse from rest_framework import status from rest_framework.test import APIClient from apps.alerts.models import ResolutionNote +from apps.auth_token.auth import GRAFANA_SA_PREFIX, ApiTokenAuthentication, GrafanaServiceAccountAuthentication +from apps.auth_token.models import ApiAuthToken @pytest.mark.django_db @@ -273,3 +277,75 @@ def test_delete_resolution_note( assert response.status_code == status.HTTP_404_NOT_FOUND assert response.data["detail"] == "Not found." + + +@pytest.mark.django_db +def test_create_resolution_note_grafana_auth(make_organization_and_user, make_alert_receive_channel, make_alert_group): + organization, user = make_organization_and_user() + client = APIClient() + + alert_receive_channel = make_alert_receive_channel(organization) + alert_group = make_alert_group(alert_receive_channel) + + url = reverse("api-public:resolution_notes-list") + + data = { + "alert_group_id": alert_group.public_primary_key, + "text": "Test Resolution Note Message", + } + + api_token_auth = ApiTokenAuthentication() + grafana_sa_auth = GrafanaServiceAccountAuthentication() + + # GrafanaServiceAccountAuthentication handles empty auth + with patch( + "apps.auth_token.auth.ApiTokenAuthentication.authenticate", wraps=api_token_auth.authenticate + ) as mock_api_key_auth, patch( + "apps.auth_token.auth.GrafanaServiceAccountAuthentication.authenticate", wraps=grafana_sa_auth.authenticate + ) as mock_grafana_auth: + response = client.post(url, data=data, format="json") + mock_grafana_auth.assert_called_once() + mock_api_key_auth.assert_not_called() + assert response.status_code == status.HTTP_403_FORBIDDEN + + token = "abc123" + # GrafanaServiceAccountAuthentication passes through api key auth + with patch( + "apps.auth_token.auth.ApiTokenAuthentication.authenticate", wraps=api_token_auth.authenticate + ) as mock_api_key_auth, patch( + "apps.auth_token.auth.GrafanaServiceAccountAuthentication.authenticate", wraps=grafana_sa_auth.authenticate + ) as mock_grafana_auth: + response = client.post(url, data=data, format="json", HTTP_AUTHORIZATION=f"{token}") + mock_grafana_auth.assert_called_once() + mock_api_key_auth.assert_called_once() + assert response.status_code == status.HTTP_403_FORBIDDEN + + token = f"{GRAFANA_SA_PREFIX}123" + # GrafanaServiceAccountAuthentication handle invalid token + with patch( + "apps.auth_token.auth.ApiTokenAuthentication.authenticate", wraps=api_token_auth.authenticate + ) as mock_api_key_auth, patch( + "apps.auth_token.auth.GrafanaServiceAccountAuthentication.authenticate", wraps=grafana_sa_auth.authenticate + ) as mock_grafana_auth: + response = client.post(url, data=data, format="json", HTTP_AUTHORIZATION=f"{token}") + mock_grafana_auth.assert_called_once() + mock_api_key_auth.assert_not_called() + assert response.status_code == status.HTTP_403_FORBIDDEN + + success_token = ApiAuthToken(organization=organization, user=user, name="Grafana Service Account") + # GrafanaServiceAccountAuthentication handle successful token + with patch( + "apps.auth_token.auth.GrafanaServiceAccountAuthentication.authenticate", return_value=(user, success_token) + ): + response = client.post(url, data=data, format="json", HTTP_AUTHORIZATION=f"{token}") + assert response.status_code == status.HTTP_201_CREATED + resolution_note = ResolutionNote.objects.get(public_primary_key=response.data["id"]) + result = { + "id": resolution_note.public_primary_key, + "alert_group_id": alert_group.public_primary_key, + "author": user.public_primary_key, + "source": resolution_note.get_source_display(), + "created_at": response.data["created_at"], + "text": data["text"], + } + assert response.data == result diff --git a/engine/apps/public_api/views/resolution_notes.py b/engine/apps/public_api/views/resolution_notes.py index f4886efa..06252aa7 100644 --- a/engine/apps/public_api/views/resolution_notes.py +++ b/engine/apps/public_api/views/resolution_notes.py @@ -5,7 +5,8 @@ from rest_framework.viewsets import ModelViewSet from apps.alerts.models import ResolutionNote from apps.alerts.tasks import send_update_resolution_note_signal -from apps.auth_token.auth import ApiTokenAuthentication +from apps.api.permissions import RBACPermission +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers.resolution_notes import ResolutionNoteSerializer, ResolutionNoteUpdateSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from common.api_helpers.mixins import RateLimitHeadersMixin, UpdateSerializerMixin @@ -13,8 +14,18 @@ from common.api_helpers.paginators import FiftyPageSizePaginator class ResolutionNoteView(RateLimitHeadersMixin, UpdateSerializerMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) - permission_classes = (IsAuthenticated,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) + permission_classes = (IsAuthenticated, RBACPermission) + + rbac_permissions = { + "metadata": [RBACPermission.Permissions.ALERT_GROUPS_READ], + "list": [RBACPermission.Permissions.ALERT_GROUPS_READ], + "retrieve": [RBACPermission.Permissions.ALERT_GROUPS_READ], + "create": [RBACPermission.Permissions.ALERT_GROUPS_WRITE], + "update": [RBACPermission.Permissions.ALERT_GROUPS_WRITE], + "partial_update": [RBACPermission.Permissions.ALERT_GROUPS_WRITE], + "destroy": [RBACPermission.Permissions.ALERT_GROUPS_WRITE], + } throttle_classes = [UserThrottle] diff --git a/engine/apps/webhooks/presets/advanced.py b/engine/apps/webhooks/presets/advanced.py index 1983943e..4e380c54 100644 --- a/engine/apps/webhooks/presets/advanced.py +++ b/engine/apps/webhooks/presets/advanced.py @@ -1,3 +1,5 @@ +import typing + from apps.webhooks.models import Webhook from apps.webhooks.presets.preset import WebhookPreset, WebhookPresetMetadata @@ -17,3 +19,6 @@ class AdvancedWebhookPreset(WebhookPreset): def override_parameters_at_runtime(self, webhook: Webhook): pass + + def get_masked_headers(self) -> typing.List[str]: + return [] diff --git a/engine/apps/webhooks/presets/preset.py b/engine/apps/webhooks/presets/preset.py index 8e946476..9ef97c81 100644 --- a/engine/apps/webhooks/presets/preset.py +++ b/engine/apps/webhooks/presets/preset.py @@ -1,3 +1,4 @@ +import typing from abc import ABC, abstractmethod from dataclasses import dataclass from typing import List @@ -34,3 +35,8 @@ class WebhookPreset(ABC): def override_parameters_at_runtime(self, webhook: Webhook): """Implement this to write parameters before the webhook is executed (These will not be persisted)""" pass + + @abstractmethod + def get_masked_headers(self) -> typing.List[str]: + """Implement this to write sensitive header data as ******** when writing to logs""" + return [] diff --git a/engine/apps/webhooks/presets/simple.py b/engine/apps/webhooks/presets/simple.py index dc1db970..ab62f6d3 100644 --- a/engine/apps/webhooks/presets/simple.py +++ b/engine/apps/webhooks/presets/simple.py @@ -1,3 +1,5 @@ +import typing + from apps.webhooks.models import Webhook from apps.webhooks.presets.preset import WebhookPreset, WebhookPresetMetadata @@ -30,3 +32,6 @@ class SimpleWebhookPreset(WebhookPreset): def override_parameters_at_runtime(self, webhook: Webhook): pass + + def get_masked_headers(self) -> typing.List[str]: + return [] diff --git a/engine/apps/webhooks/tasks/trigger_webhook.py b/engine/apps/webhooks/tasks/trigger_webhook.py index 0e43f82c..dcc0f960 100644 --- a/engine/apps/webhooks/tasks/trigger_webhook.py +++ b/engine/apps/webhooks/tasks/trigger_webhook.py @@ -96,10 +96,12 @@ def _build_payload(webhook, alert_group, user): return data -def mask_authorization_header(headers): +def mask_authorization_header(headers, header_keys_to_mask): masked_headers = headers.copy() - if "Authorization" in masked_headers: - masked_headers["Authorization"] = WEBHOOK_FIELD_PLACEHOLDER + lower_keys = set(k.lower() for k in header_keys_to_mask) + for k in headers.keys(): + if k.lower() in lower_keys: + masked_headers[k] = WEBHOOK_FIELD_PLACEHOLDER return masked_headers @@ -114,6 +116,7 @@ def make_request(webhook, alert_group, data): "webhook": webhook, "event_data": json.dumps(data), } + masked_header_keys = ["Authorization"] exception = error = None try: @@ -121,7 +124,9 @@ def make_request(webhook, alert_group, data): if webhook.preset not in WebhookPresetOptions.WEBHOOK_PRESETS: raise Exception(f"Invalid preset {webhook.preset}") else: - WebhookPresetOptions.WEBHOOK_PRESETS[webhook.preset].override_parameters_at_runtime(webhook) + preset = WebhookPresetOptions.WEBHOOK_PRESETS[webhook.preset] + preset.override_parameters_at_runtime(webhook) + masked_header_keys.extend(preset.get_masked_headers()) if not webhook.check_integration_filter(alert_group): status["request_trigger"] = NOT_FROM_SELECTED_INTEGRATION @@ -131,7 +136,7 @@ def make_request(webhook, alert_group, data): if triggered: status["url"] = webhook.build_url(data) request_kwargs = webhook.build_request_kwargs(data, raise_data_errors=True) - display_headers = mask_authorization_header(request_kwargs.get("headers", {})) + display_headers = mask_authorization_header(request_kwargs.get("headers", {}), masked_header_keys) status["request_headers"] = json.dumps(display_headers) if "json" in request_kwargs: status["request_data"] = json.dumps(request_kwargs["json"]) diff --git a/engine/apps/webhooks/tests/test_webhook_presets.py b/engine/apps/webhooks/tests/test_webhook_presets.py index 70c95151..d73f0aab 100644 --- a/engine/apps/webhooks/tests/test_webhook_presets.py +++ b/engine/apps/webhooks/tests/test_webhook_presets.py @@ -1,8 +1,11 @@ +import json +import typing from unittest.mock import patch import pytest from apps.webhooks.models import Webhook +from apps.webhooks.models.webhook import WEBHOOK_FIELD_PLACEHOLDER from apps.webhooks.presets.preset import WebhookPreset, WebhookPresetMetadata from apps.webhooks.tasks.trigger_webhook import make_request from apps.webhooks.tests.test_trigger_webhook import MockResponse @@ -14,6 +17,8 @@ TEST_WEBHOOK_LOGO = "test_logo" TEST_WEBHOOK_PRESET_DESCRIPTION = "Description of test webhook preset" TEST_WEBHOOK_PRESET_CONTROLLED_FIELDS = ["url", "http_method", "data", "authorization_header"] TEST_WEBHOOK_AUTHORIZATION_HEADER = "Test Auth header 12345" +TEST_WEBHOOK_MASK_HEADER = "X-Secret-Header" +TEST_WEBHOOK_MASK_HEADER_VALUE = "abc123" INVALID_PRESET_ID = "invalid_preset_id" @@ -34,6 +39,12 @@ class TestWebhookPreset(WebhookPreset): def override_parameters_at_runtime(self, webhook: Webhook): webhook.authorization_header = TEST_WEBHOOK_AUTHORIZATION_HEADER + webhook.headers = json.dumps( + {"Content-Type": "application/json", TEST_WEBHOOK_MASK_HEADER: TEST_WEBHOOK_MASK_HEADER_VALUE} + ) + + def get_masked_headers(self) -> typing.List[str]: + return [TEST_WEBHOOK_MASK_HEADER] @pytest.mark.django_db @@ -124,11 +135,20 @@ def test_webhook_preset_runtime_override(make_organization, webhook_preset_api_s with patch.object(webhook, "build_url"): response = MockResponse() with patch.object(webhook, "make_request", return_value=response) as mock_make_request: - triggered, webhook_status, error, exception = make_request(webhook, None, None) + triggered, webhook_status, error, exception = make_request(webhook, None, {}) + assert mock_make_request.call_args.args[1]["headers"]["Content-Type"] == "application/json" assert mock_make_request.call_args.args[1]["headers"]["Authorization"] == TEST_WEBHOOK_AUTHORIZATION_HEADER + assert ( + mock_make_request.call_args.args[1]["headers"][TEST_WEBHOOK_MASK_HEADER] + == TEST_WEBHOOK_MASK_HEADER_VALUE + ) assert triggered assert error is None assert exception is None + webhook_status_headers = json.loads(webhook_status["request_headers"]) + assert webhook_status_headers["Content-Type"] == "application/json" + assert webhook_status_headers["Authorization"] == WEBHOOK_FIELD_PLACEHOLDER + assert webhook_status_headers[TEST_WEBHOOK_MASK_HEADER] == WEBHOOK_FIELD_PLACEHOLDER webhook.refresh_from_db() assert webhook.authorization_header is None