From 357b5c47c6034c1fb556dbdbf5540fb49fcab8b4 Mon Sep 17 00:00:00 2001 From: Michael Derynck Date: Fri, 8 Nov 2024 19:42:11 -0700 Subject: [PATCH 01/12] Limit slack block text length when rendering alert group timeline (#5246) # What this PR does Limit length of text in block being posted to slack when showing alert group timeline. ## Which issue(s) this PR closes Related to [issue link here] ## Checklist - [x] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes. --- engine/apps/slack/scenarios/alertgroup_timeline.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/engine/apps/slack/scenarios/alertgroup_timeline.py b/engine/apps/slack/scenarios/alertgroup_timeline.py index 08f74b88..7ca3a56f 100644 --- a/engine/apps/slack/scenarios/alertgroup_timeline.py +++ b/engine/apps/slack/scenarios/alertgroup_timeline.py @@ -2,6 +2,7 @@ import typing from apps.api.permissions import RBACPermission from apps.slack.chatops_proxy_routing import make_private_metadata +from apps.slack.constants import BLOCK_SECTION_TEXT_MAX_SIZE from apps.slack.scenarios import scenario_step from apps.slack.scenarios.slack_renderer import AlertGroupLogSlackRenderer from apps.slack.types import ( @@ -47,9 +48,13 @@ class OpenAlertGroupTimelineDialogStep(AlertGroupActionsMixin, scenario_step.Sce future_log_report = AlertGroupLogSlackRenderer.render_alert_group_future_log_report_text(alert_group) blocks: typing.List[Block.Section] = [] if past_log_report: - blocks.append({"type": "section", "text": {"type": "mrkdwn", "text": past_log_report}}) + blocks.append( + {"type": "section", "text": {"type": "mrkdwn", "text": past_log_report[:BLOCK_SECTION_TEXT_MAX_SIZE]}} + ) if future_log_report: - blocks.append({"type": "section", 
"text": {"type": "mrkdwn", "text": future_log_report}}) + blocks.append( + {"type": "section", "text": {"type": "mrkdwn", "text": future_log_report[:BLOCK_SECTION_TEXT_MAX_SIZE]}} + ) view: ModalView = { "blocks": blocks, From df6bb69d29c7c496d70aaa583a7046787dfe68bd Mon Sep 17 00:00:00 2001 From: Dominik Broj Date: Tue, 12 Nov 2024 16:48:47 +0100 Subject: [PATCH 02/12] fix: disable accessControlOnCall for Grafana 11.3 (#5245) # What this PR does Disable accessControlOnCall for Grafana 11.3 ## Checklist - [ ] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes. --- .github/workflows/linting-and-tests.yml | 1 + Tiltfile | 24 ++++++++++++++++++++++-- dev/helm-local.yml | 3 +++ helm/oncall/values.yaml | 3 +++ 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/.github/workflows/linting-and-tests.yml b/.github/workflows/linting-and-tests.yml index fc43b572..23688595 100644 --- a/.github/workflows/linting-and-tests.yml +++ b/.github/workflows/linting-and-tests.yml @@ -244,6 +244,7 @@ jobs: grafana_version: - 10.3.0 - 11.2.0 + - latest fail-fast: false with: grafana_version: ${{ matrix.grafana_version }} diff --git a/Tiltfile b/Tiltfile index 26442416..00d7ec41 100644 --- a/Tiltfile +++ b/Tiltfile @@ -32,12 +32,23 @@ def plugin_json(): return plugin_file return 'NOT_A_PLUGIN' +def extra_grafana_ini(): + return { + 'feature_toggles': { + 'accessControlOnCall': 'false' + } + } + def extra_env(): return { "GF_APP_URL": grafana_url, "GF_SERVER_ROOT_URL": grafana_url, "GF_FEATURE_TOGGLES_ENABLE": "externalServiceAccounts", - "ONCALL_API_URL": "http://oncall-dev-engine:8080" + "ONCALL_API_URL": "http://oncall-dev-engine:8080", + + # Enables managed service accounts for plugin authentication in Grafana >= 11.3 + # 
https://grafana.com/docs/grafana/latest/setup-grafana/configure-grafana/#managed_service_accounts_enabled + "GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED": "true", } def extra_deps(): @@ -132,7 +143,16 @@ def load_grafana(): "GF_APP_URL": grafana_url, # older versions of grafana need this "GF_SERVER_ROOT_URL": grafana_url, "GF_FEATURE_TOGGLES_ENABLE": "externalServiceAccounts", - "ONCALL_API_URL": "http://oncall-dev-engine:8080" + "ONCALL_API_URL": "http://oncall-dev-engine:8080", + + # Enables managed service accounts for plugin authentication in Grafana >= 11.3 + # https://grafana.com/docs/grafana/latest/setup-grafana/configure-grafana/#managed_service_accounts_enabled + "GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED": "true", + }, + extra_grafana_ini={ + "feature_toggles": { + "accessControlOnCall": "false" + } }, ) # --- GRAFANA END ---- diff --git a/dev/helm-local.yml b/dev/helm-local.yml index 33a28790..8655df43 100644 --- a/dev/helm-local.yml +++ b/dev/helm-local.yml @@ -47,6 +47,8 @@ externalGrafana: grafana: enabled: false grafana.ini: + feature_toggles: + accessControlOnCall: false server: domain: localhost:3000 root_url: "%(protocol)s://%(domain)s" @@ -71,6 +73,7 @@ grafana: value: oncallpassword env: GF_FEATURE_TOGGLES_ENABLE: externalServiceAccounts + GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED: true GF_SECURITY_ADMIN_PASSWORD: oncall GF_SECURITY_ADMIN_USER: oncall GF_PLUGINS_ALLOW_LOADING_UNSIGNED_PLUGINS: grafana-oncall-app diff --git a/helm/oncall/values.yaml b/helm/oncall/values.yaml index 8ca59a26..826e0a5b 100644 --- a/helm/oncall/values.yaml +++ b/helm/oncall/values.yaml @@ -639,6 +639,9 @@ grafana: serve_from_sub_path: true feature_toggles: enable: externalServiceAccounts + accessControlOnCall: false + env: + GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED: true persistence: enabled: true # Disable psp as PodSecurityPolicy is deprecated in v1.21+, unavailable in v1.25+ From 9338cff0ef36661cdd1440724d8d163ad27fdc65 Mon Sep 17 00:00:00 2001 From: Michael Derynck 
Date: Thu, 14 Nov 2024 09:19:30 -0700 Subject: [PATCH 03/12] fix: disable accessControlOnCall for Grafana 11.3 in docker compose (#5255) # What this PR does Disable accessControlOnCall for Grafana 11.3 in docker compose Similar to https://github.com/grafana/oncall/pull/5245 ## Checklist - [ ] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes. --- docker-compose-developer.yml | 1 + docker-compose-mysql-rabbitmq.yml | 10 ++++++++++ docker-compose.yml | 10 ++++++++++ 3 files changed, 21 insertions(+) diff --git a/docker-compose-developer.yml b/docker-compose-developer.yml index b751ab1e..ee668df7 100644 --- a/docker-compose-developer.yml +++ b/docker-compose-developer.yml @@ -324,6 +324,7 @@ services: GF_PLUGINS_ALLOW_LOADING_UNSIGNED_PLUGINS: grafana-oncall-app GF_FEATURE_TOGGLES_ENABLE: externalServiceAccounts ONCALL_API_URL: http://host.docker.internal:8080 + GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED: true env_file: - ./dev/.env.${DB}.dev ports: diff --git a/docker-compose-mysql-rabbitmq.yml b/docker-compose-mysql-rabbitmq.yml index f587902e..60b320e8 100644 --- a/docker-compose-mysql-rabbitmq.yml +++ b/docker-compose-mysql-rabbitmq.yml @@ -144,6 +144,7 @@ services: GF_SECURITY_ADMIN_PASSWORD: ${GRAFANA_PASSWORD:-admin} GF_PLUGINS_ALLOW_LOADING_UNSIGNED_PLUGINS: grafana-oncall-app GF_INSTALL_PLUGINS: grafana-oncall-app + GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED: true deploy: resources: limits: @@ -156,7 +157,16 @@ services: condition: service_healthy profiles: - with_grafana + configs: + - source: grafana.ini + target: /etc/grafana/grafana.ini volumes: dbdata: rabbitmqdata: + +configs: + grafana.ini: + content: | + [feature_toggles] + accessControlOnCall = false diff --git a/docker-compose.yml b/docker-compose.yml
index b115199f..c54c2fb3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -94,6 +94,7 @@ services: GF_SECURITY_ADMIN_PASSWORD: ${GRAFANA_PASSWORD:-admin} GF_PLUGINS_ALLOW_LOADING_UNSIGNED_PLUGINS: grafana-oncall-app GF_INSTALL_PLUGINS: grafana-oncall-app + GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED: true volumes: - grafana_data:/var/lib/grafana deploy: @@ -103,9 +104,18 @@ services: cpus: "0.5" profiles: - with_grafana + configs: + - source: grafana.ini + target: /etc/grafana/grafana.ini volumes: grafana_data: prometheus_data: oncall_data: redis_data: + +configs: + grafana.ini: + content: | + [feature_toggles] + accessControlOnCall = false From 208db9cdb7a45a35867949aa3e97a8fbd59bb02a Mon Sep 17 00:00:00 2001 From: Salvatore Giordano Date: Fri, 15 Nov 2024 11:29:00 +0100 Subject: [PATCH 04/12] remove add_stack_slug_to_message_title utility from push notification titles (#5258) # What this PR does We noticed that the backend was adding the stack name to the notification title only on Android. We thought it makes sense to add the stack name only if the user has more than 1 stack connected, but that's not doable right now since the backend doesn't know how many stacks are connected in the app. Also we took a look at the analytics for the app and basically 95% of the users have only 1 stack connected. This pr removes the stack name from the notifications title. If in the future we think it makes sense to add it conditionally based on the number of stacks we can open another pr, but given the very little amount of users with more than 1 stack I think this is not needed. ## Checklist - [x] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes. 
--- engine/apps/mobile_app/demo_push.py | 4 ++-- .../apps/mobile_app/tasks/going_oncall_notification.py | 9 ++------- engine/apps/mobile_app/tasks/new_alert_group.py | 9 ++------- engine/apps/mobile_app/tasks/new_shift_swap_request.py | 9 ++------- .../tests/tasks/test_going_oncall_notification.py | 3 +-- .../tests/tasks/test_new_shift_swap_request.py | 7 ++----- engine/apps/mobile_app/tests/test_demo_push.py | 7 +++---- 7 files changed, 14 insertions(+), 34 deletions(-) diff --git a/engine/apps/mobile_app/demo_push.py b/engine/apps/mobile_app/demo_push.py index 19daca5b..01194c14 100644 --- a/engine/apps/mobile_app/demo_push.py +++ b/engine/apps/mobile_app/demo_push.py @@ -8,7 +8,7 @@ from firebase_admin.messaging import APNSPayload, Aps, ApsAlert, CriticalSound, from apps.mobile_app.exceptions import DeviceNotSet from apps.mobile_app.types import FCMMessageData, MessageType, Platform -from apps.mobile_app.utils import add_stack_slug_to_message_title, construct_fcm_message, send_push_notification +from apps.mobile_app.utils import construct_fcm_message, send_push_notification from apps.user_management.models import User if typing.TYPE_CHECKING: @@ -47,7 +47,7 @@ def _get_test_escalation_fcm_message(user: User, device_to_notify: "FCMDevice", apns_sound_name = mobile_app_user_settings.get_notification_sound_name(message_type, Platform.IOS) fcm_message_data: FCMMessageData = { - "title": add_stack_slug_to_message_title(get_test_push_title(critical), user.organization), + "title": get_test_push_title(critical), "orgName": user.organization.stack_slug, # Pass user settings, so the Android app can use them to play the correct sound and volume "default_notification_sound_name": mobile_app_user_settings.get_notification_sound_name( diff --git a/engine/apps/mobile_app/tasks/going_oncall_notification.py b/engine/apps/mobile_app/tasks/going_oncall_notification.py index 214fa19d..34fd4160 100644 --- a/engine/apps/mobile_app/tasks/going_oncall_notification.py +++ 
b/engine/apps/mobile_app/tasks/going_oncall_notification.py @@ -12,12 +12,7 @@ from django.utils import timezone from firebase_admin.messaging import APNSPayload, Aps, ApsAlert, CriticalSound, Message from apps.mobile_app.types import FCMMessageData, MessageType, Platform -from apps.mobile_app.utils import ( - MAX_RETRIES, - add_stack_slug_to_message_title, - construct_fcm_message, - send_push_notification, -) +from apps.mobile_app.utils import MAX_RETRIES, construct_fcm_message, send_push_notification from apps.schedules.models.on_call_schedule import OnCallSchedule, ScheduleEvent from apps.user_management.models import User from common.cache import ensure_cache_key_allocates_to_the_same_hash_slot @@ -82,7 +77,7 @@ def _get_fcm_message( notification_subtitle = _get_notification_subtitle(schedule, schedule_event, mobile_app_user_settings) data: FCMMessageData = { - "title": add_stack_slug_to_message_title(notification_title, user.organization), + "title": notification_title, "subtitle": notification_subtitle, "orgName": user.organization.stack_slug, "info_notification_sound_name": mobile_app_user_settings.get_notification_sound_name( diff --git a/engine/apps/mobile_app/tasks/new_alert_group.py b/engine/apps/mobile_app/tasks/new_alert_group.py index e33e9111..2b759f5f 100644 --- a/engine/apps/mobile_app/tasks/new_alert_group.py +++ b/engine/apps/mobile_app/tasks/new_alert_group.py @@ -8,12 +8,7 @@ from firebase_admin.messaging import APNSPayload, Aps, ApsAlert, CriticalSound, from apps.alerts.models import AlertGroup from apps.mobile_app.alert_rendering import get_push_notification_subtitle, get_push_notification_title from apps.mobile_app.types import FCMMessageData, MessageType, Platform -from apps.mobile_app.utils import ( - MAX_RETRIES, - add_stack_slug_to_message_title, - construct_fcm_message, - send_push_notification, -) +from apps.mobile_app.utils import MAX_RETRIES, construct_fcm_message, send_push_notification from apps.user_management.models import User 
from common.custom_celery_tasks import shared_dedicated_queue_retry_task @@ -46,7 +41,7 @@ def _get_fcm_message(alert_group: AlertGroup, user: User, device_to_notify: "FCM apns_sound_name = mobile_app_user_settings.get_notification_sound_name(message_type, Platform.IOS) fcm_message_data: FCMMessageData = { - "title": add_stack_slug_to_message_title(alert_title, alert_group.channel.organization), + "title": alert_title, "subtitle": alert_subtitle, "orgId": alert_group.channel.organization.public_primary_key, "orgName": alert_group.channel.organization.stack_slug, diff --git a/engine/apps/mobile_app/tasks/new_shift_swap_request.py b/engine/apps/mobile_app/tasks/new_shift_swap_request.py index a6d49c8b..3ab71674 100644 --- a/engine/apps/mobile_app/tasks/new_shift_swap_request.py +++ b/engine/apps/mobile_app/tasks/new_shift_swap_request.py @@ -10,12 +10,7 @@ from django.utils import timezone from firebase_admin.messaging import APNSPayload, Aps, ApsAlert, CriticalSound, Message from apps.mobile_app.types import FCMMessageData, MessageType, Platform -from apps.mobile_app.utils import ( - MAX_RETRIES, - add_stack_slug_to_message_title, - construct_fcm_message, - send_push_notification, -) +from apps.mobile_app.utils import MAX_RETRIES, construct_fcm_message, send_push_notification from apps.schedules.models import ShiftSwapRequest from apps.user_management.models import User from common.custom_celery_tasks import shared_dedicated_queue_retry_task @@ -121,7 +116,7 @@ def _get_fcm_message( route = f"/schedules/{shift_swap_request.schedule.public_primary_key}/ssrs/{shift_swap_request.public_primary_key}" data: FCMMessageData = { - "title": add_stack_slug_to_message_title(notification_title, user.organization), + "title": notification_title, "subtitle": notification_subtitle, "orgName": user.organization.stack_slug, "route": route, diff --git a/engine/apps/mobile_app/tests/tasks/test_going_oncall_notification.py 
b/engine/apps/mobile_app/tests/tasks/test_going_oncall_notification.py index 2541d507..051e4ffb 100644 --- a/engine/apps/mobile_app/tests/tasks/test_going_oncall_notification.py +++ b/engine/apps/mobile_app/tests/tasks/test_going_oncall_notification.py @@ -18,7 +18,6 @@ from apps.mobile_app.tasks.going_oncall_notification import ( conditionally_send_going_oncall_push_notifications_for_schedule, ) from apps.mobile_app.types import MessageType, Platform -from apps.mobile_app.utils import add_stack_slug_to_message_title from apps.schedules.models import OnCallScheduleCalendar, OnCallScheduleICal, OnCallScheduleWeb from apps.schedules.models.on_call_schedule import ScheduleEvent @@ -228,7 +227,7 @@ def test_get_fcm_message( maus = MobileAppUserSettings.objects.create(user=user, time_zone=user_tz) data = { - "title": add_stack_slug_to_message_title(mock_notification_title, organization), + "title": mock_notification_title, "subtitle": mock_notification_subtitle, "orgName": organization.stack_slug, "info_notification_sound_name": maus.get_notification_sound_name(MessageType.INFO, Platform.ANDROID), diff --git a/engine/apps/mobile_app/tests/tasks/test_new_shift_swap_request.py b/engine/apps/mobile_app/tests/tasks/test_new_shift_swap_request.py index 452b9895..f77674f8 100644 --- a/engine/apps/mobile_app/tests/tasks/test_new_shift_swap_request.py +++ b/engine/apps/mobile_app/tests/tasks/test_new_shift_swap_request.py @@ -19,7 +19,6 @@ from apps.mobile_app.tasks.new_shift_swap_request import ( notify_shift_swap_requests, notify_user_about_shift_swap_request, ) -from apps.mobile_app.utils import add_stack_slug_to_message_title from apps.schedules.models import CustomOnCallShift, OnCallScheduleWeb, ShiftSwapRequest from apps.user_management.models import User from apps.user_management.models.user import default_working_hours @@ -288,7 +287,7 @@ def test_notify_user_about_shift_swap_request( message: Message = mock_send_push_notification.call_args.args[1] assert 
message.data["type"] == "oncall.info" - assert message.data["title"] == add_stack_slug_to_message_title("New shift swap request", organization) + assert message.data["title"] == "New shift swap request" assert message.data["subtitle"] == "John Doe, Test Schedule" assert ( message.data["route"] @@ -487,9 +486,7 @@ def test_notify_beneficiary_about_taken_shift_swap_request( message: Message = mock_send_push_notification.call_args.args[1] assert message.data["type"] == "oncall.info" - assert message.data["title"] == add_stack_slug_to_message_title( - "Your shift swap request has been taken", organization - ) + assert message.data["title"] == "Your shift swap request has been taken" assert message.data["subtitle"] == schedule_name assert ( message.data["route"] diff --git a/engine/apps/mobile_app/tests/test_demo_push.py b/engine/apps/mobile_app/tests/test_demo_push.py index 769691f7..abf5f6eb 100644 --- a/engine/apps/mobile_app/tests/test_demo_push.py +++ b/engine/apps/mobile_app/tests/test_demo_push.py @@ -2,7 +2,6 @@ import pytest from apps.mobile_app.demo_push import _get_test_escalation_fcm_message, get_test_push_title from apps.mobile_app.models import FCMDevice, MobileAppUserSettings -from apps.mobile_app.utils import add_stack_slug_to_message_title @pytest.mark.django_db @@ -34,7 +33,7 @@ def test_test_escalation_fcm_message_user_settings( # Check expected test push content assert message.apns.payload.aps.badge is None assert message.apns.payload.aps.alert.title == get_test_push_title(critical=False) - assert message.data["title"] == add_stack_slug_to_message_title(get_test_push_title(critical=False), organization) + assert message.data["title"] == get_test_push_title(critical=False) assert message.data["type"] == "oncall.message" @@ -68,7 +67,7 @@ def test_escalation_fcm_message_user_settings_critical( # Check expected test push content assert message.apns.payload.aps.badge is None assert message.apns.payload.aps.alert.title == 
get_test_push_title(critical=True) - assert message.data["title"] == add_stack_slug_to_message_title(get_test_push_title(critical=True), organization) + assert message.data["title"] == get_test_push_title(critical=True) assert message.data["type"] == "oncall.critical_message" @@ -94,4 +93,4 @@ def test_escalation_fcm_message_user_settings_critical_override_dnd_disabled( # Check expected test push content assert message.apns.payload.aps.badge is None assert message.apns.payload.aps.alert.title == get_test_push_title(critical=True) - assert message.data["title"] == add_stack_slug_to_message_title(get_test_push_title(critical=True), organization) + assert message.data["title"] == get_test_push_title(critical=True) From 10dc454c7b61a1bc98d6313a56e654c090c0abcc Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Mon, 18 Nov 2024 09:44:32 +0000 Subject: [PATCH 05/12] Inbound email improvements (#5259) # What this PR does * Allows to use multiple inbound email ESPs at the same time by setting the `INBOUND_EMAIL_ESP` env variable to `amazon_ses,mailgun` for example * Adds a new ESP `amazon_ses_validated` that performs SNS message validation (`django-anymail` doesn't implement it: [comment](https://github.com/anymail/django-anymail/blob/35383c7140289e82b39ada5980077898aa07d18d/anymail/webhooks/amazon_ses.py#L107-L108)) ## Which issue(s) this PR closes Related to https://github.com/grafana/oncall-private/issues/2905 ## Checklist - [x] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes.
--- engine/apps/email/inbound.py | 83 ++-- engine/apps/email/tests/test_inbound_email.py | 450 ++++++++++++++++++ .../apps/email/validate_amazon_sns_message.py | 99 ++++ engine/settings/base.py | 1 + 4 files changed, 600 insertions(+), 33 deletions(-) create mode 100644 engine/apps/email/validate_amazon_sns_message.py diff --git a/engine/apps/email/inbound.py b/engine/apps/email/inbound.py index 1780f00c..185234c5 100644 --- a/engine/apps/email/inbound.py +++ b/engine/apps/email/inbound.py @@ -1,27 +1,42 @@ import logging +from functools import cached_property from typing import Optional, TypedDict -from anymail.exceptions import AnymailInvalidAddress, AnymailWebhookValidationFailure +from anymail.exceptions import AnymailAPIError, AnymailInvalidAddress, AnymailWebhookValidationFailure from anymail.inbound import AnymailInboundMessage from anymail.signals import AnymailInboundEvent from anymail.webhooks import amazon_ses, mailgun, mailjet, mandrill, postal, postmark, sendgrid, sparkpost from django.http import HttpResponse, HttpResponseNotAllowed from django.utils import timezone from rest_framework import status -from rest_framework.request import Request from rest_framework.response import Response from rest_framework.views import APIView from apps.base.utils import live_settings +from apps.email.validate_amazon_sns_message import validate_amazon_sns_message from apps.integrations.mixins import AlertChannelDefiningMixin from apps.integrations.tasks import create_alert logger = logging.getLogger(__name__) +class AmazonSESValidatedInboundWebhookView(amazon_ses.AmazonSESInboundWebhookView): + # disable "Your Anymail webhooks are insecure and open to anyone on the web." 
warning + warn_if_no_basic_auth = False + + def validate_request(self, request): + """Add SNS message validation to Amazon SES inbound webhook view, which is not implemented in Anymail.""" + + super().validate_request(request) + sns_message = self._parse_sns_message(request) + if not validate_amazon_sns_message(sns_message): + raise AnymailWebhookValidationFailure("SNS message validation failed") + + # {: (, ), ...} INBOUND_EMAIL_ESP_OPTIONS = { "amazon_ses": (amazon_ses.AmazonSESInboundWebhookView, None), + "amazon_ses_validated": (AmazonSESValidatedInboundWebhookView, None), "mailgun": (mailgun.MailgunInboundWebhookView, "webhook_signing_key"), "mailjet": (mailjet.MailjetInboundWebhookView, "webhook_secret"), "mandrill": (mandrill.MandrillCombinedWebhookView, "webhook_key"), @@ -62,38 +77,33 @@ class InboundEmailWebhookView(AlertChannelDefiningMixin, APIView): return super().dispatch(request, alert_channel_key=integration_token) def post(self, request): - timestamp = timezone.now().isoformat() - for message in self.get_messages_from_esp_request(request): - payload = self.get_alert_payload_from_email_message(message) - create_alert.delay( - title=payload["subject"], - message=payload["message"], - alert_receive_channel_pk=request.alert_receive_channel.pk, - image_url=None, - link_to_upstream_details=None, - integration_unique_data=None, - raw_request_data=payload, - received_at=timestamp, - ) - + payload = self.get_alert_payload_from_email_message(self.message) + create_alert.delay( + title=payload["subject"], + message=payload["message"], + alert_receive_channel_pk=request.alert_receive_channel.pk, + image_url=None, + link_to_upstream_details=None, + integration_unique_data=None, + raw_request_data=payload, + received_at=timezone.now().isoformat(), + ) return Response("OK", status=status.HTTP_200_OK) def get_integration_token_from_request(self, request) -> Optional[str]: - messages = self.get_messages_from_esp_request(request) - if not messages: + if not 
self.message: return None - message = messages[0] # First try envelope_recipient field. # According to AnymailInboundMessage it's provided not by all ESPs. - if message.envelope_recipient: - recipients = message.envelope_recipient.split(",") + if self.message.envelope_recipient: + recipients = self.message.envelope_recipient.split(",") for recipient in recipients: # if there is more than one recipient, the first matching the expected domain will be used try: token, domain = recipient.strip().split("@") except ValueError: logger.error( - f"get_integration_token_from_request: envelope_recipient field has unexpected format: {message.envelope_recipient}" + f"get_integration_token_from_request: envelope_recipient field has unexpected format: {self.message.envelope_recipient}" ) continue if domain == live_settings.INBOUND_EMAIL_DOMAIN: @@ -113,20 +123,27 @@ class InboundEmailWebhookView(AlertChannelDefiningMixin, APIView): # return cc.address.split("@")[0] return None - def get_messages_from_esp_request(self, request: Request) -> list[AnymailInboundMessage]: - view_class, secret_name = INBOUND_EMAIL_ESP_OPTIONS[live_settings.INBOUND_EMAIL_ESP] + @cached_property + def message(self) -> AnymailInboundMessage | None: + esps = live_settings.INBOUND_EMAIL_ESP.split(",") + for esp in esps: + view_class, secret_name = INBOUND_EMAIL_ESP_OPTIONS[esp] - kwargs = {secret_name: live_settings.INBOUND_EMAIL_WEBHOOK_SECRET} if secret_name else {} - view = view_class(**kwargs) + kwargs = {secret_name: live_settings.INBOUND_EMAIL_WEBHOOK_SECRET} if secret_name else {} + view = view_class(**kwargs) - try: - view.run_validators(request) - events = view.parse_events(request) - except AnymailWebhookValidationFailure as e: - logger.info(f"get_messages_from_esp_request: inbound email webhook validation failed: {e}") - return [] + try: + view.run_validators(self.request) + events = view.parse_events(self.request) + except (AnymailWebhookValidationFailure, AnymailAPIError) as e: + 
logger.info(f"inbound email webhook validation failed for ESP {esp}: {e}") + continue - return [event.message for event in events if isinstance(event, AnymailInboundEvent)] + messages = [event.message for event in events if isinstance(event, AnymailInboundEvent)] + if messages: + return messages[0] + + return None def check_inbound_email_settings_set(self): """ diff --git a/engine/apps/email/tests/test_inbound_email.py b/engine/apps/email/tests/test_inbound_email.py index 81a76e92..35bccd10 100644 --- a/engine/apps/email/tests/test_inbound_email.py +++ b/engine/apps/email/tests/test_inbound_email.py @@ -1,13 +1,295 @@ +import datetime +import hashlib +import hmac import json +from base64 import b64encode from textwrap import dedent +from unittest.mock import ANY, Mock, patch import pytest from anymail.inbound import AnymailInboundMessage +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding, rsa +from cryptography.x509 import CertificateBuilder, NameOID +from django.conf import settings from django.urls import reverse from rest_framework import status from rest_framework.test import APIClient +from apps.alerts.models import AlertReceiveChannel from apps.email.inbound import InboundEmailWebhookView +from apps.integrations.tasks import create_alert + +PRIVATE_KEY = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, +) +ISSUER_NAME = x509.Name( + [ + x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Test"), + x509.NameAttribute(NameOID.LOCALITY_NAME, "Test"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Amazon"), + x509.NameAttribute(NameOID.COMMON_NAME, "Test"), + ] +) +CERTIFICATE = ( + CertificateBuilder() + .subject_name(ISSUER_NAME) + .issuer_name(ISSUER_NAME) + .public_key(PRIVATE_KEY.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now() 
- datetime.timedelta(days=1)) + .not_valid_after(datetime.datetime.now() + datetime.timedelta(days=10)) + .sign(PRIVATE_KEY, hashes.SHA256()) + .public_bytes(serialization.Encoding.PEM) +) +AMAZON_SNS_TOPIC_ARN = "arn:aws:sns:us-east-2:123456789012:test" +SIGNING_CERT_URL = "https://sns.us-east-2.amazonaws.com/SimpleNotificationService-example.pem" + + +def _sns_inbound_email_payload_and_headers(sender_email, to_email, subject, message): + content = ( + f"From: Sender Name <{sender_email}>\n" + f"To: {to_email}\n" + f"Subject: {subject}\n" + "Date: Tue, 5 Nov 2024 16:05:39 +0000\n" + "Message-ID: \n\n" + f"{message}\r\n" + ) + + message = { + "notificationType": "Received", + "mail": { + "timestamp": "2024-11-05T16:05:52.387Z", + "source": sender_email, + "messageId": "example-message-id-5678", + "destination": [to_email], + "headersTruncated": False, + "headers": [ + {"name": "Return-Path", "value": f"<{sender_email}>"}, + { + "name": "Received", + "value": ( + f"from mail.example.com (mail.example.com [203.0.113.1]) " + f"by inbound-smtp.us-east-2.amazonaws.com with SMTP id example-id " + f"for {to_email}; Tue, 05 Nov 2024 16:05:52 +0000 (UTC)" + ), + }, + {"name": "X-SES-Spam-Verdict", "value": "PASS"}, + {"name": "X-SES-Virus-Verdict", "value": "PASS"}, + { + "name": "Received-SPF", + "value": ( + "pass (spfCheck: domain of example.com designates 203.0.113.1 as permitted sender) " + f"client-ip=203.0.113.1; envelope-from={sender_email}; helo=mail.example.com;" + ), + }, + { + "name": "Authentication-Results", + "value": ( + "amazonses.com; spf=pass (spfCheck: domain of example.com designates 203.0.113.1 as permitted sender) " + f"client-ip=203.0.113.1; envelope-from={sender_email}; helo=mail.example.com; " + "dkim=pass header.i=@example.com; dmarc=pass header.from=example.com;" + ), + }, + {"name": "X-SES-RECEIPT", "value": "example-receipt-data"}, + {"name": "X-SES-DKIM-SIGNATURE", "value": "example-dkim-signature"}, + { + "name": "Received", + "value": ( + 
f"by mail.example.com with SMTP id example-id for <{to_email}>; " + "Tue, 05 Nov 2024 08:05:52 -0800 (PST)" + ), + }, + { + "name": "DKIM-Signature", + "value": ( + "v=1; a=rsa-sha256; c=relaxed/relaxed; d=example.com; s=default; t=1234567890; " + "bh=examplehash; h=From:To:Subject:Date:Message-ID; b=example-signature" + ), + }, + {"name": "X-Google-DKIM-Signature", "value": "example-google-dkim-signature"}, + {"name": "X-Gm-Message-State", "value": "example-message-state"}, + {"name": "X-Google-Smtp-Source", "value": "example-smtp-source"}, + { + "name": "X-Received", + "value": "by 2002:a17:example with SMTP id example-id; Tue, 05 Nov 2024 08:05:50 -0800 (PST)", + }, + {"name": "MIME-Version", "value": "1.0"}, + {"name": "From", "value": f"Sender Name <{sender_email}>"}, + {"name": "Date", "value": "Tue, 5 Nov 2024 16:05:39 +0000"}, + {"name": "Message-ID", "value": ""}, + {"name": "Subject", "value": subject}, + {"name": "To", "value": to_email}, + { + "name": "Content-Type", + "value": 'multipart/alternative; boundary="00000000000036b9f706262c9312"', + }, + ], + "commonHeaders": { + "returnPath": sender_email, + "from": [f"Sender Name <{sender_email}>"], + "date": "Tue, 5 Nov 2024 16:05:39 +0000", + "to": [to_email], + "messageId": "", + "subject": subject, + }, + }, + "receipt": { + "timestamp": "2024-11-05T16:05:52.387Z", + "processingTimeMillis": 638, + "recipients": [to_email], + "spamVerdict": {"status": "PASS"}, + "virusVerdict": {"status": "PASS"}, + "spfVerdict": {"status": "PASS"}, + "dkimVerdict": {"status": "PASS"}, + "dmarcVerdict": {"status": "PASS"}, + "action": { + "type": "SNS", + "topicArn": "arn:aws:sns:us-east-2:123456789012:test", + "encoding": "BASE64", + }, + }, + "content": b64encode(content.encode()).decode(), + } + + payload = { + "Type": "Notification", + "MessageId": "example-message-id-1234", + "TopicArn": AMAZON_SNS_TOPIC_ARN, + "Subject": "Amazon SES Email Receipt Notification", + "Message": json.dumps(message), + "Timestamp": 
"2024-11-05T16:05:53.041Z", + "SignatureVersion": "1", + "SigningCertURL": SIGNING_CERT_URL, + "UnsubscribeURL": ( + "https://sns.us-east-2.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=" + "arn:aws:sns:us-east-2:123456789012:test:example-subscription-id" + ), + } + # Sign the payload + canonical_message = "".join( + f"{key}\n{payload[key]}\n" for key in ("Message", "MessageId", "Subject", "Timestamp", "TopicArn", "Type") + ) + signature = PRIVATE_KEY.sign( + canonical_message.encode(), + padding.PKCS1v15(), + hashes.SHA1(), + ) + payload["Signature"] = b64encode(signature).decode() + + headers = { + "X-Amz-Sns-Message-Type": "Notification", + "X-Amz-Sns-Message-Id": "example-message-id-1234", + } + return payload, headers + + +def _mailgun_inbound_email_payload(sender_email, to_email, subject, message): + timestamp, token = "1731341416", "example-token" + signature = hmac.new( + key=settings.INBOUND_EMAIL_WEBHOOK_SECRET.encode("ascii"), + msg="{}{}".format(timestamp, token).encode("ascii"), + digestmod=hashlib.sha256, + ).hexdigest() + + return { + "Content-Type": 'multipart/alternative; boundary="000000000000267130626a556e5"', + "Date": "Mon, 11 Nov 2024 16:10:03 +0000", + "Dkim-Signature": ( + "v=1; a=rsa-sha256; c=relaxed/relaxed; d=example.com; s=default; " + "t=1731341415; x=1731946215; darn=example.com; " + "h=to:subject:message-id:date:from:mime-version:from:to:cc:subject " + ":date:message-id:reply-to; bh=examplebh; b=exampleb" + ), + "From": f"Sender Name <{sender_email}>", + "Message-Id": "", + "Mime-Version": "1.0", + "Received": ( + f"by mail.example.com with SMTP id example-id for <{to_email}>; " "Mon, 11 Nov 2024 08:10:15 -0800 (PST)" + ), + "Subject": subject, + "To": to_email, + "X-Envelope-From": sender_email, + "X-Gm-Message-State": "example-message-state", + "X-Google-Dkim-Signature": ( + "v=1; a=rsa-sha256; c=relaxed/relaxed; d=1e100.net; s=20230601; " + "t=1731341415; x=1731946215; " + 
"h=to:subject:message-id:date:from:mime-version:x-gm-message-state " + ":from:to:cc:subject:date:message-id:reply-to; bh=examplebh; b=exampleb" + ), + "X-Google-Smtp-Source": "example-smtp-source", + "X-Mailgun-Incoming": "Yes", + "X-Received": "by 2002:a17:example with SMTP id example-id; Mon, 11 Nov 2024 08:10:14 -0800 (PST)", + "body-html": f'
{message}
\r\n', + "body-plain": f"{message}\r\n", + "from": f"Sender Name <{sender_email}>", + "message-headers": json.dumps( + [ + ["X-Mailgun-Incoming", "Yes"], + ["X-Envelope-From", sender_email], + [ + "Received", + ( + "from mail.example.com (mail.example.com [203.0.113.1]) " + "by example.com with SMTP id example-id; " + "Mon, 11 Nov 2024 16:10:15 GMT" + ), + ], + [ + "Received", + ( + f"by mail.example.com with SMTP id example-id for <{to_email}>; " + "Mon, 11 Nov 2024 08:10:15 -0800 (PST)" + ), + ], + [ + "Dkim-Signature", + ( + "v=1; a=rsa-sha256; c=relaxed/relaxed; d=example.com; s=default; " + "t=1731341415; x=1731946215; darn=example.com; " + "h=to:subject:message-id:date:from:mime-version:from:to:cc:subject " + ":date:message-id:reply-to; bh=examplebh; b=exampleb" + ), + ], + [ + "X-Google-Dkim-Signature", + ( + "v=1; a=rsa-sha256; c=relaxed/relaxed; d=1e100.net; s=20230601; " + "t=1731341415; x=1731946215; " + "h=to:subject:message-id:date:from:mime-version:x-gm-message-state " + ":from:to:cc:subject:date:message-id:reply-to; bh=examplebh; b=exampleb" + ), + ], + ["X-Gm-Message-State", "example-message-state"], + ["X-Google-Smtp-Source", "example-smtp-source"], + [ + "X-Received", + "by 2002:a17:example with SMTP id example-id; Mon, 11 Nov 2024 08:10:14 -0800 (PST)", + ], + ["Mime-Version", "1.0"], + ["From", f"Sender Name <{sender_email}>"], + ["Date", "Mon, 11 Nov 2024 16:10:03 +0000"], + ["Message-Id", ""], + ["Subject", subject], + ["To", to_email], + [ + "Content-Type", + 'multipart/alternative; boundary="000000000000267130626a556e5"', + ], + ] + ), + "recipient": to_email, + "sender": sender_email, + "signature": signature, + "stripped-html": f'
{message}
\n', + "stripped-text": f"{message}\n", + "subject": subject, + "timestamp": timestamp, + "token": token, + } @pytest.mark.parametrize( @@ -141,3 +423,171 @@ def test_get_sender_from_email_message(sender_value, expected_result): view = InboundEmailWebhookView() result = view.get_sender_from_email_message(email) assert result == expected_result + + +@patch.object(create_alert, "delay") +@pytest.mark.django_db +def test_amazon_ses_pass(create_alert_mock, settings, make_organization, make_alert_receive_channel): + settings.INBOUND_EMAIL_ESP = "amazon_ses,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + + organization = make_organization() + alert_receive_channel = make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + sender_email = "sender@example.com" + to_email = "test-token@inbound.example.com" + subject = "Test email" + message = "This is a test email message body." 
+ sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( + sender_email=sender_email, + to_email=to_email, + subject=subject, + message=message, + ) + + client = APIClient() + response = client.post( + reverse("integrations:inbound_email_webhook"), + data=sns_payload, + headers=sns_headers, + format="json", + ) + + assert response.status_code == status.HTTP_200_OK + create_alert_mock.assert_called_once_with( + title=subject, + message=message, + alert_receive_channel_pk=alert_receive_channel.pk, + image_url=None, + link_to_upstream_details=None, + integration_unique_data=None, + raw_request_data={ + "subject": subject, + "message": message, + "sender": sender_email, + }, + received_at=ANY, + ) + + +@patch("requests.get", return_value=Mock(content=CERTIFICATE)) +@patch.object(create_alert, "delay") +@pytest.mark.django_db +def test_amazon_ses_validated_pass( + mock_create_alert, mock_requests_get, settings, make_organization, make_alert_receive_channel +): + settings.INBOUND_EMAIL_ESP = "amazon_ses_validated,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN = AMAZON_SNS_TOPIC_ARN + + organization = make_organization() + alert_receive_channel = make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + sender_email = "sender@example.com" + to_email = "test-token@inbound.example.com" + subject = "Test email" + message = "This is a test email message body." 
+ sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( + sender_email=sender_email, + to_email=to_email, + subject=subject, + message=message, + ) + + client = APIClient() + response = client.post( + reverse("integrations:inbound_email_webhook"), + data=sns_payload, + headers=sns_headers, + format="json", + ) + + assert response.status_code == status.HTTP_200_OK + mock_create_alert.assert_called_once_with( + title=subject, + message=message, + alert_receive_channel_pk=alert_receive_channel.pk, + image_url=None, + link_to_upstream_details=None, + integration_unique_data=None, + raw_request_data={ + "subject": subject, + "message": message, + "sender": sender_email, + }, + received_at=ANY, + ) + + mock_requests_get.assert_called_once_with(SIGNING_CERT_URL, timeout=5) + + +@patch.object(create_alert, "delay") +@pytest.mark.django_db +def test_mailgun_pass(create_alert_mock, settings, make_organization, make_alert_receive_channel): + settings.INBOUND_EMAIL_ESP = "amazon_ses,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + + organization = make_organization() + alert_receive_channel = make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + sender_email = "sender@example.com" + to_email = "test-token@inbound.example.com" + subject = "Test email" + message = "This is a test email message body." 
+ + mailgun_payload = _mailgun_inbound_email_payload( + sender_email=sender_email, + to_email=to_email, + subject=subject, + message=message, + ) + + client = APIClient() + response = client.post( + reverse("integrations:inbound_email_webhook"), + data=mailgun_payload, + format="multipart", + ) + + assert response.status_code == status.HTTP_200_OK + create_alert_mock.assert_called_once_with( + title=subject, + message=message, + alert_receive_channel_pk=alert_receive_channel.pk, + image_url=None, + link_to_upstream_details=None, + integration_unique_data=None, + raw_request_data={ + "subject": subject, + "message": message, + "sender": sender_email, + }, + received_at=ANY, + ) + + +@pytest.mark.django_db +def test_multiple_esps_fail(settings): + settings.INBOUND_EMAIL_ESP = "amazon_ses,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + + client = APIClient() + response = client.post(reverse("integrations:inbound_email_webhook"), data={}) + + assert response.status_code == status.HTTP_400_BAD_REQUEST diff --git a/engine/apps/email/validate_amazon_sns_message.py b/engine/apps/email/validate_amazon_sns_message.py new file mode 100644 index 00000000..f3d2aec4 --- /dev/null +++ b/engine/apps/email/validate_amazon_sns_message.py @@ -0,0 +1,99 @@ +import logging +import re +from base64 import b64decode +from urllib.parse import urlparse + +import requests +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15 +from cryptography.hazmat.primitives.hashes import SHA1, SHA256 +from cryptography.x509 import NameOID, load_pem_x509_certificate +from django.conf import settings + +logger = logging.getLogger(__name__) + +HOST_PATTERN = re.compile(r"^sns\.[a-zA-Z0-9\-]{3,}\.amazonaws\.com(\.cn)?$") +REQUIRED_KEYS = ( + "Message", + "MessageId", + "Timestamp", + "TopicArn", + "Type", + "Signature", + "SigningCertURL", + "SignatureVersion", +) 
+SIGNING_KEYS_NOTIFICATION = ("Message", "MessageId", "Subject", "Timestamp", "TopicArn", "Type") +SIGNING_KEYS_SUBSCRIPTION = ("Message", "MessageId", "SubscribeURL", "Timestamp", "Token", "TopicArn", "Type") + + +def validate_amazon_sns_message(message: dict) -> bool: + """ + Validate an AWS SNS message. Based on: + - https://docs.aws.amazon.com/sns/latest/dg/sns-verify-signature-of-message.html + - https://github.com/aws/aws-js-sns-message-validator/blob/a6ba4d646dc60912653357660301f3b25f94d686/index.js + - https://github.com/aws/aws-php-sns-message-validator/blob/3cee0fc1aee5538e1bd677654b09fad811061d0b/src/MessageValidator.php + """ + + # Check if the message has all the required keys + if not all(key in message for key in REQUIRED_KEYS): + logger.warning("Missing required keys in the message, got: %s", message.keys()) + return False + + # Check TopicArn + if message["TopicArn"] != settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN: + logger.warning("Invalid TopicArn: %s", message["TopicArn"]) + return False + + # Construct the canonical message + if message["Type"] == "Notification": + signing_keys = SIGNING_KEYS_NOTIFICATION + elif message["Type"] in ("SubscriptionConfirmation", "UnsubscribeConfirmation"): + signing_keys = SIGNING_KEYS_SUBSCRIPTION + else: + logger.warning("Invalid message type: %s", message["Type"]) + return False + canonical_message = "".join(f"{key}\n{message[key]}\n" for key in signing_keys if key in message).encode() + + # Check if SigningCertURL is a valid SNS URL + signing_cert_url = message["SigningCertURL"] + parsed_url = urlparse(signing_cert_url) + if ( + parsed_url.scheme != "https" + or not HOST_PATTERN.match(parsed_url.netloc) + or not parsed_url.path.endswith(".pem") + ): + logger.warning("Invalid SigningCertURL: %s", signing_cert_url) + return False + + # Fetch the certificate + try: + response = requests.get(signing_cert_url, timeout=5) + response.raise_for_status() + certificate_bytes = response.content + except 
requests.RequestException as e: + logger.warning("Failed to fetch the certificate from %s: %s", signing_cert_url, e) + return False + + # Verify the certificate issuer + certificate = load_pem_x509_certificate(certificate_bytes) + if certificate.issuer.get_attributes_for_oid(NameOID.ORGANIZATION_NAME)[0].value != "Amazon": + logger.warning("Invalid certificate issuer: %s", certificate.issuer) + return False + + # Verify the signature + signature = b64decode(message["Signature"]) + if message["SignatureVersion"] == "1": + hash_algorithm = SHA1() + elif message["SignatureVersion"] == "2": + hash_algorithm = SHA256() + else: + logger.warning("Invalid SignatureVersion: %s", message["SignatureVersion"]) + return False + try: + certificate.public_key().verify(signature, canonical_message, PKCS1v15(), hash_algorithm) + except InvalidSignature: + logger.warning("Invalid signature") + return False + + return True diff --git a/engine/settings/base.py b/engine/settings/base.py index 5b6eba8f..25ef7dc1 100644 --- a/engine/settings/base.py +++ b/engine/settings/base.py @@ -867,6 +867,7 @@ if FEATURE_EMAIL_INTEGRATION_ENABLED: INBOUND_EMAIL_ESP = os.getenv("INBOUND_EMAIL_ESP") INBOUND_EMAIL_DOMAIN = os.getenv("INBOUND_EMAIL_DOMAIN") INBOUND_EMAIL_WEBHOOK_SECRET = os.getenv("INBOUND_EMAIL_WEBHOOK_SECRET") +INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN = os.getenv("INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN") INSTALLED_ONCALL_INTEGRATIONS = [ # Featured From 5fbc3d058ca8ef0febe15e688307a1cab419e7fe Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Mon, 18 Nov 2024 12:09:05 +0000 Subject: [PATCH 06/12] Inbound email improvements (continued) (#5263) follow up to https://github.com/grafana/oncall/pull/5259: * Auto confirm SNS subsriptions for ESP `amazon_ses_validated` * Add a couple of tests for SNS message validation (try with wrong SNS topic ARN, try with wrong singature) --- engine/apps/email/inbound.py | 11 +- engine/apps/email/tests/test_inbound_email.py | 148 +++++++++++++----- 2 files 
changed, 115 insertions(+), 44 deletions(-) diff --git a/engine/apps/email/inbound.py b/engine/apps/email/inbound.py index 185234c5..6c86e194 100644 --- a/engine/apps/email/inbound.py +++ b/engine/apps/email/inbound.py @@ -2,6 +2,7 @@ import logging from functools import cached_property from typing import Optional, TypedDict +import requests from anymail.exceptions import AnymailAPIError, AnymailInvalidAddress, AnymailWebhookValidationFailure from anymail.inbound import AnymailInboundMessage from anymail.signals import AnymailInboundEvent @@ -26,12 +27,14 @@ class AmazonSESValidatedInboundWebhookView(amazon_ses.AmazonSESInboundWebhookVie def validate_request(self, request): """Add SNS message validation to Amazon SES inbound webhook view, which is not implemented in Anymail.""" - - super().validate_request(request) - sns_message = self._parse_sns_message(request) - if not validate_amazon_sns_message(sns_message): + if not validate_amazon_sns_message(self._parse_sns_message(request)): raise AnymailWebhookValidationFailure("SNS message validation failed") + def auto_confirm_sns_subscription(self, sns_message): + """This method is called after validate_request, so we can be sure that the message is valid.""" + response = requests.get(sns_message["SubscribeURL"]) + response.raise_for_status() + # {: (, ), ...} INBOUND_EMAIL_ESP_OPTIONS = { diff --git a/engine/apps/email/tests/test_inbound_email.py b/engine/apps/email/tests/test_inbound_email.py index 35bccd10..252b5292 100644 --- a/engine/apps/email/tests/test_inbound_email.py +++ b/engine/apps/email/tests/test_inbound_email.py @@ -47,6 +47,10 @@ CERTIFICATE = ( ) AMAZON_SNS_TOPIC_ARN = "arn:aws:sns:us-east-2:123456789012:test" SIGNING_CERT_URL = "https://sns.us-east-2.amazonaws.com/SimpleNotificationService-example.pem" +SENDER_EMAIL = "sender@example.com" +TO_EMAIL = "test-token@inbound.example.com" +SUBJECT = "Test email" +MESSAGE = "This is a test email message body." 
def _sns_inbound_email_payload_and_headers(sender_email, to_email, subject, message): @@ -439,15 +443,11 @@ def test_amazon_ses_pass(create_alert_mock, settings, make_organization, make_al token="test-token", ) - sender_email = "sender@example.com" - to_email = "test-token@inbound.example.com" - subject = "Test email" - message = "This is a test email message body." sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( - sender_email=sender_email, - to_email=to_email, - subject=subject, - message=message, + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, ) client = APIClient() @@ -460,16 +460,16 @@ def test_amazon_ses_pass(create_alert_mock, settings, make_organization, make_al assert response.status_code == status.HTTP_200_OK create_alert_mock.assert_called_once_with( - title=subject, - message=message, + title=SUBJECT, + message=MESSAGE, alert_receive_channel_pk=alert_receive_channel.pk, image_url=None, link_to_upstream_details=None, integration_unique_data=None, raw_request_data={ - "subject": subject, - "message": message, - "sender": sender_email, + "subject": SUBJECT, + "message": MESSAGE, + "sender": SENDER_EMAIL, }, received_at=ANY, ) @@ -493,15 +493,11 @@ def test_amazon_ses_validated_pass( token="test-token", ) - sender_email = "sender@example.com" - to_email = "test-token@inbound.example.com" - subject = "Test email" - message = "This is a test email message body." 
sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( - sender_email=sender_email, - to_email=to_email, - subject=subject, - message=message, + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, ) client = APIClient() @@ -514,16 +510,16 @@ def test_amazon_ses_validated_pass( assert response.status_code == status.HTTP_200_OK mock_create_alert.assert_called_once_with( - title=subject, - message=message, + title=SUBJECT, + message=MESSAGE, alert_receive_channel_pk=alert_receive_channel.pk, image_url=None, link_to_upstream_details=None, integration_unique_data=None, raw_request_data={ - "subject": subject, - "message": message, - "sender": sender_email, + "subject": SUBJECT, + "message": MESSAGE, + "sender": SENDER_EMAIL, }, received_at=ANY, ) @@ -531,6 +527,83 @@ def test_amazon_ses_validated_pass( mock_requests_get.assert_called_once_with(SIGNING_CERT_URL, timeout=5) +@patch("requests.get", return_value=Mock(content=CERTIFICATE)) +@patch.object(create_alert, "delay") +@pytest.mark.django_db +def test_amazon_ses_validated_fail_wrong_sns_topic_arn( + mock_create_alert, mock_requests_get, settings, make_organization, make_alert_receive_channel +): + settings.INBOUND_EMAIL_ESP = "amazon_ses_validated,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN = "arn:aws:sns:us-east-2:123456789013:test" + + organization = make_organization() + make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, + ) + + client = APIClient() + response = client.post( + reverse("integrations:inbound_email_webhook"), + data=sns_payload, + headers=sns_headers, + format="json", + ) + + assert response.status_code 
== status.HTTP_400_BAD_REQUEST + mock_create_alert.assert_not_called() + mock_requests_get.assert_not_called() + + +@patch("requests.get", return_value=Mock(content=CERTIFICATE)) +@patch.object(create_alert, "delay") +@pytest.mark.django_db +def test_amazon_ses_validated_fail_wrong_signature( + mock_create_alert, mock_requests_get, settings, make_organization, make_alert_receive_channel +): + settings.INBOUND_EMAIL_ESP = "amazon_ses_validated,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN = AMAZON_SNS_TOPIC_ARN + + organization = make_organization() + make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, + ) + sns_payload["Signature"] = "invalid-signature" + + client = APIClient() + response = client.post( + reverse("integrations:inbound_email_webhook"), + data=sns_payload, + headers=sns_headers, + format="json", + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + mock_create_alert.assert_not_called() + mock_requests_get.assert_called_once_with(SIGNING_CERT_URL, timeout=5) + + @patch.object(create_alert, "delay") @pytest.mark.django_db def test_mailgun_pass(create_alert_mock, settings, make_organization, make_alert_receive_channel): @@ -545,16 +618,11 @@ def test_mailgun_pass(create_alert_mock, settings, make_organization, make_alert token="test-token", ) - sender_email = "sender@example.com" - to_email = "test-token@inbound.example.com" - subject = "Test email" - message = "This is a test email message body." 
- mailgun_payload = _mailgun_inbound_email_payload( - sender_email=sender_email, - to_email=to_email, - subject=subject, - message=message, + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, ) client = APIClient() @@ -566,16 +634,16 @@ def test_mailgun_pass(create_alert_mock, settings, make_organization, make_alert assert response.status_code == status.HTTP_200_OK create_alert_mock.assert_called_once_with( - title=subject, - message=message, + title=SUBJECT, + message=MESSAGE, alert_receive_channel_pk=alert_receive_channel.pk, image_url=None, link_to_upstream_details=None, integration_unique_data=None, raw_request_data={ - "subject": subject, - "message": message, - "sender": sender_email, + "subject": SUBJECT, + "message": MESSAGE, + "sender": SENDER_EMAIL, }, received_at=ANY, ) From 0c811e0249cb1c6d8b91e9443646149b92d4c478 Mon Sep 17 00:00:00 2001 From: Matias Bordese Date: Mon, 18 Nov 2024 17:29:23 -0300 Subject: [PATCH 07/12] fix: update `next_shifts_per_user` to only list users with upcoming shifts (#5264) Related to https://github.com/grafana/irm/issues/343 --- engine/apps/api/tests/test_schedules.py | 40 +++++++++++++++++++++---- engine/apps/api/views/schedule.py | 18 ++++++----- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/engine/apps/api/tests/test_schedules.py b/engine/apps/api/tests/test_schedules.py index 4a29dc9d..8efcb6b2 100644 --- a/engine/apps/api/tests/test_schedules.py +++ b/engine/apps/api/tests/test_schedules.py @@ -1442,8 +1442,9 @@ def test_next_shifts_per_user( ("B", "UTC"), ("C", None), ("D", "America/Montevideo"), + ("E", None), ) - user_a, user_b, user_c, user_d = ( + user_a, user_b, user_c, user_d, user_e = ( make_user_for_organization(organization, username=i, _timezone=tz) for i, tz in users ) @@ -1469,8 +1470,7 @@ def test_next_shifts_per_user( ) on_call_shift.add_rolling_users([[user]]) - # override in the past: 17-18 / D - # won't be listed, but user D will still be included in 
the response + # override in the past, won't be listed: 17-18 / D override_data = { "start": tomorrow - timezone.timedelta(days=3), "rotation_start": tomorrow - timezone.timedelta(days=3), @@ -1483,6 +1483,7 @@ def test_next_shifts_per_user( override.add_rolling_users([[user_d]]) # override: 17-18 / C + # this is before C's shift, so it will be listed as upcoming override_data = { "start": tomorrow + timezone.timedelta(hours=17), "rotation_start": tomorrow + timezone.timedelta(hours=17), @@ -1494,11 +1495,26 @@ def test_next_shifts_per_user( ) override.add_rolling_users([[user_c]]) + # override: 17-18 / E + fifteend_days_later = tomorrow + timezone.timedelta(days=15) + override_data = { + "start": fifteend_days_later + timezone.timedelta(hours=17), + "rotation_start": fifteend_days_later + timezone.timedelta(hours=17), + "duration": timezone.timedelta(hours=1), + "schedule": schedule, + } + override = make_on_call_shift( + organization=organization, shift_type=CustomOnCallShift.TYPE_OVERRIDE, **override_data + ) + override.add_rolling_users([[user_e]]) + # final schedule: 7-12: B, 15-16: A, 16-17: B, 17-18: C (override), 18-20: C schedule.refresh_ical_final_schedule() url = reverse("api-internal:schedule-next-shifts-per-user", kwargs={"pk": schedule.public_primary_key}) - response = client.get(url, format="json", **make_user_auth_headers(admin, token)) + + # check for users with shifts in the next week + response = client.get(url + "?days=7", format="json", **make_user_auth_headers(admin, token)) assert response.status_code == status.HTTP_200_OK expected = { @@ -1517,13 +1533,27 @@ def test_next_shifts_per_user( tomorrow + timezone.timedelta(hours=18), user_c.timezone, ), - user_d.public_primary_key: (None, None, user_d.timezone), } returned_data = { u: (ev.get("start"), ev.get("end"), ev.get("user_timezone")) for u, ev in response.data["users"].items() } assert returned_data == expected + # by default it will check for shifts in the next 45 days + response = 
client.get(url, format="json", **make_user_auth_headers(admin, token)) + assert response.status_code == status.HTTP_200_OK + + # include user E with the override + expected[user_e.public_primary_key] = ( + fifteend_days_later + timezone.timedelta(hours=17), + fifteend_days_later + timezone.timedelta(hours=18), + user_e.timezone, + ) + returned_data = { + u: (ev.get("start"), ev.get("end"), ev.get("user_timezone")) for u, ev in response.data["users"].items() + } + assert returned_data == expected + @pytest.mark.django_db def test_next_shifts_per_user_ical_schedule_using_emails( diff --git a/engine/apps/api/views/schedule.py b/engine/apps/api/views/schedule.py index 78635290..e30aa8cb 100644 --- a/engine/apps/api/views/schedule.py +++ b/engine/apps/api/views/schedule.py @@ -388,20 +388,22 @@ class ScheduleView( @action(detail=True, methods=["get"]) def next_shifts_per_user(self, request, pk): """Return next shift for users in schedule.""" + days = self.request.query_params.get("days") + days = int(days) if days else 30 now = timezone.now() - datetime_end = now + datetime.timedelta(days=30) + datetime_end = now + datetime.timedelta(days=days) schedule = self.get_object(annotate=False) + users = {} events = schedule.final_events(now, datetime_end) - - # include user TZ information for every user - users = {u.public_primary_key: {"user_timezone": u.timezone} for u in schedule.related_users()} + users_tz = {u.public_primary_key: u.timezone for u in schedule.related_users()} added_users = set() for e in events: - user = e["users"][0]["pk"] if e["users"] else None - if user is not None and user not in added_users and user in users and e["end"] > now: - users[user].update(e) - added_users.add(user) + user_ppk = e["users"][0]["pk"] if e["users"] else None + if user_ppk is not None and user_ppk not in users and user_ppk in users_tz and e["end"] > now: + users[user_ppk] = e + users[user_ppk]["user_timezone"] = users_tz[user_ppk] + added_users.add(user_ppk) result = {"users": 
users} return Response(result, status=status.HTTP_200_OK) From 2bcbac8454904ae8e0e8783d41b32cb19ad2d7eb Mon Sep 17 00:00:00 2001 From: Matias Bordese Date: Tue, 19 Nov 2024 09:52:23 -0300 Subject: [PATCH 08/12] Enable service account token auth for public API (#5254) Related to https://github.com/grafana/oncall-private/issues/2826 Continuing work started in https://github.com/grafana/oncall/pull/5211, this adds support for Grafana service accounts tokens for API authentication (except alert group actions which will still require a user behind). Next steps would be updating the go client and the terraform provider to allow service account token auth for OnCall resources. Following proposal 1.1 from [doc](https://docs.google.com/document/d/1I3nFbsUEkiNPphBXT-kWefIeramTY71qqZ1OA06Kmls/edit?usp=sharing). --- ...065_alertreceivechannel_service_account.py | 20 ++ .../alerts/models/alert_receive_channel.py | 14 +- engine/apps/api/permissions.py | 1 + engine/apps/auth_token/auth.py | 43 +--- .../auth_token/grafana/grafana_auth_token.py | 6 + .../migrations/0007_serviceaccounttoken.py | 29 +++ engine/apps/auth_token/models/__init__.py | 1 + .../models/service_account_token.py | 110 +++++++++ engine/apps/auth_token/tests/helpers.py | 18 ++ .../auth_token/tests/test_grafana_auth.py | 229 +++++++++++++++++- engine/apps/grafana_plugin/helpers/client.py | 3 + .../public_api/serializers/integrations.py | 5 +- .../public_api/tests/test_alert_groups.py | 34 +++ .../public_api/tests/test_integrations.py | 44 ++++ .../public_api/tests/test_rbac_permissions.py | 104 ++++++++ .../public_api/tests/test_resolution_notes.py | 6 +- engine/apps/public_api/views/alert_groups.py | 25 +- engine/apps/public_api/views/alerts.py | 4 +- .../public_api/views/escalation_chains.py | 4 +- .../public_api/views/escalation_policies.py | 4 +- engine/apps/public_api/views/integrations.py | 4 +- .../apps/public_api/views/on_call_shifts.py | 4 +- engine/apps/public_api/views/organizations.py | 4 +- 
engine/apps/public_api/views/routes.py | 4 +- engine/apps/public_api/views/schedules.py | 8 +- engine/apps/public_api/views/shift_swap.py | 4 +- .../apps/public_api/views/slack_channels.py | 4 +- engine/apps/public_api/views/teams.py | 4 +- engine/apps/public_api/views/user_groups.py | 4 +- engine/apps/public_api/views/users.py | 8 +- engine/apps/public_api/views/webhooks.py | 4 +- .../migrations/0027_serviceaccount.py | 26 ++ .../apps/user_management/models/__init__.py | 1 + .../user_management/models/service_account.py | 55 +++++ .../apps/user_management/tests/factories.py | 10 +- engine/conftest.py | 36 ++- engine/engine/middlewares.py | 6 +- 37 files changed, 816 insertions(+), 74 deletions(-) create mode 100644 engine/apps/alerts/migrations/0065_alertreceivechannel_service_account.py create mode 100644 engine/apps/auth_token/migrations/0007_serviceaccounttoken.py create mode 100644 engine/apps/auth_token/models/service_account_token.py create mode 100644 engine/apps/auth_token/tests/helpers.py create mode 100644 engine/apps/user_management/migrations/0027_serviceaccount.py create mode 100644 engine/apps/user_management/models/service_account.py diff --git a/engine/apps/alerts/migrations/0065_alertreceivechannel_service_account.py b/engine/apps/alerts/migrations/0065_alertreceivechannel_service_account.py new file mode 100644 index 00000000..306d8a04 --- /dev/null +++ b/engine/apps/alerts/migrations/0065_alertreceivechannel_service_account.py @@ -0,0 +1,20 @@ +# Generated by Django 4.2.15 on 2024-11-12 13:13 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('user_management', '0027_serviceaccount'), + ('alerts', '0064_migrate_resolutionnoteslackmessage_slack_channel_id'), + ] + + operations = [ + migrations.AddField( + model_name='alertreceivechannel', + name='service_account', + field=models.ForeignKey(blank=True, null=True, 
on_delete=django.db.models.deletion.SET_NULL, related_name='alert_receive_channels', to='user_management.serviceaccount'), + ), + ] diff --git a/engine/apps/alerts/models/alert_receive_channel.py b/engine/apps/alerts/models/alert_receive_channel.py index 4fd926ac..a8cb1494 100644 --- a/engine/apps/alerts/models/alert_receive_channel.py +++ b/engine/apps/alerts/models/alert_receive_channel.py @@ -234,6 +234,13 @@ class AlertReceiveChannel(IntegrationOptionsMixin, MaintainableObject): author = models.ForeignKey( "user_management.User", on_delete=models.SET_NULL, related_name="alert_receive_channels", blank=True, null=True ) + service_account = models.ForeignKey( + "user_management.ServiceAccount", + on_delete=models.SET_NULL, + related_name="alert_receive_channels", + blank=True, + null=True, + ) team = models.ForeignKey( "user_management.Team", on_delete=models.SET_NULL, @@ -764,15 +771,16 @@ def listen_for_alertreceivechannel_model_save( from apps.heartbeat.models import IntegrationHeartBeat if created: - write_resource_insight_log(instance=instance, author=instance.author, event=EntityEvent.CREATED) + author = instance.author or instance.service_account + write_resource_insight_log(instance=instance, author=author, event=EntityEvent.CREATED) default_filter = ChannelFilter(alert_receive_channel=instance, filtering_term=None, is_default=True) default_filter.save() - write_resource_insight_log(instance=default_filter, author=instance.author, event=EntityEvent.CREATED) + write_resource_insight_log(instance=default_filter, author=author, event=EntityEvent.CREATED) TEN_MINUTES = 600 # this is timeout for cloud heartbeats if instance.is_available_for_integration_heartbeat: heartbeat = IntegrationHeartBeat.objects.create(alert_receive_channel=instance, timeout_seconds=TEN_MINUTES) - write_resource_insight_log(instance=heartbeat, author=instance.author, event=EntityEvent.CREATED) + write_resource_insight_log(instance=heartbeat, author=author, event=EntityEvent.CREATED) 
metrics_add_integrations_to_cache([instance], instance.organization) diff --git a/engine/apps/api/permissions.py b/engine/apps/api/permissions.py index 852506a1..d9dad6b3 100644 --- a/engine/apps/api/permissions.py +++ b/engine/apps/api/permissions.py @@ -18,6 +18,7 @@ if typing.TYPE_CHECKING: RBAC_PERMISSIONS_ATTR = "rbac_permissions" RBAC_OBJECT_PERMISSIONS_ATTR = "rbac_object_permissions" + ViewSetOrAPIView = typing.Union[ViewSet, APIView] diff --git a/engine/apps/auth_token/auth.py b/engine/apps/auth_token/auth.py index dc6ccf7a..3a7e25d6 100644 --- a/engine/apps/auth_token/auth.py +++ b/engine/apps/auth_token/auth.py @@ -9,7 +9,6 @@ from rest_framework import exceptions from rest_framework.authentication import BaseAuthentication, get_authorization_header from rest_framework.request import Request -from apps.api.permissions import GrafanaAPIPermissions, LegacyAccessControlRole from apps.grafana_plugin.helpers.gcom import check_token from apps.grafana_plugin.sync_data import SyncPermission, SyncUser from apps.user_management.exceptions import OrganizationDeletedException, OrganizationMovedException @@ -20,13 +19,13 @@ from settings.base import SELF_HOSTED_SETTINGS from .constants import GOOGLE_OAUTH2_AUTH_TOKEN_NAME, SCHEDULE_EXPORT_TOKEN_NAME, SLACK_AUTH_TOKEN_NAME from .exceptions import InvalidToken -from .grafana.grafana_auth_token import get_service_account_token_permissions from .models import ( ApiAuthToken, GoogleOAuth2Token, IntegrationBacksyncAuthToken, PluginAuthToken, ScheduleExportAuthToken, + ServiceAccountToken, SlackAuthToken, UserScheduleExportAuthToken, ) @@ -336,8 +335,8 @@ class UserScheduleExportAuthentication(BaseAuthentication): return auth_token.user, auth_token +X_GRAFANA_URL = "X-Grafana-URL" X_GRAFANA_INSTANCE_ID = "X-Grafana-Instance-ID" -GRAFANA_SA_PREFIX = "glsa_" class GrafanaServiceAccountAuthentication(BaseAuthentication): @@ -345,7 +344,7 @@ class GrafanaServiceAccountAuthentication(BaseAuthentication): auth = 
get_authorization_header(request).decode("utf-8") if not auth: raise exceptions.AuthenticationFailed("Invalid token.") - if not auth.startswith(GRAFANA_SA_PREFIX): + if not auth.startswith(ServiceAccountToken.GRAFANA_SA_PREFIX): return None organization = self.get_organization(request) @@ -359,6 +358,13 @@ class GrafanaServiceAccountAuthentication(BaseAuthentication): return self.authenticate_credentials(organization, auth) def get_organization(self, request): + grafana_url = request.headers.get(X_GRAFANA_URL) + if grafana_url: + organization = Organization.objects.filter(grafana_url=grafana_url).first() + if not organization: + raise exceptions.AuthenticationFailed("Invalid Grafana URL.") + return organization + if settings.LICENSE == settings.CLOUD_LICENSE_NAME: instance_id = request.headers.get(X_GRAFANA_INSTANCE_ID) if not instance_id: @@ -370,36 +376,13 @@ class GrafanaServiceAccountAuthentication(BaseAuthentication): return Organization.objects.filter(org_slug=org_slug, stack_slug=instance_slug).first() def authenticate_credentials(self, organization, token): - permissions = get_service_account_token_permissions(organization, token) - if not permissions: + try: + user, auth_token = ServiceAccountToken.validate_token(organization, token) + except InvalidToken: raise exceptions.AuthenticationFailed("Invalid token.") - role = LegacyAccessControlRole.NONE - if not organization.is_rbac_permissions_enabled: - role = self.determine_role_from_permissions(permissions) - - user = User( - organization_id=organization.pk, - name="Grafana Service Account", - username="grafana_service_account", - role=role, - permissions=GrafanaAPIPermissions.construct_permissions(permissions.keys()), - ) - - auth_token = ApiAuthToken(organization=organization, user=user, name="Grafana Service Account") - return user, auth_token - # Using default permissions as proxies for roles since we cannot explicitly get role from the service account token - def determine_role_from_permissions(self, 
permissions): - if "plugins:write" in permissions: - return LegacyAccessControlRole.ADMIN - if "dashboards:write" in permissions: - return LegacyAccessControlRole.EDITOR - if "dashboards:read" in permissions: - return LegacyAccessControlRole.VIEWER - return LegacyAccessControlRole.NONE - class IntegrationBacksyncAuthentication(BaseAuthentication): model = IntegrationBacksyncAuthToken diff --git a/engine/apps/auth_token/grafana/grafana_auth_token.py b/engine/apps/auth_token/grafana/grafana_auth_token.py index 07bae644..6576e417 100644 --- a/engine/apps/auth_token/grafana/grafana_auth_token.py +++ b/engine/apps/auth_token/grafana/grafana_auth_token.py @@ -46,3 +46,9 @@ def get_service_account_token_permissions(organization: Organization, token: str grafana_api_client = GrafanaAPIClient(api_url=organization.grafana_url, api_token=token) permissions, _ = grafana_api_client.get_service_account_token_permissions() return permissions + + +def get_service_account_details(organization: Organization, token: str) -> typing.Dict[str, typing.List[str]]: + grafana_api_client = GrafanaAPIClient(api_url=organization.grafana_url, api_token=token) + user_data, _ = grafana_api_client.get_current_user() + return user_data diff --git a/engine/apps/auth_token/migrations/0007_serviceaccounttoken.py b/engine/apps/auth_token/migrations/0007_serviceaccounttoken.py new file mode 100644 index 00000000..920b9ada --- /dev/null +++ b/engine/apps/auth_token/migrations/0007_serviceaccounttoken.py @@ -0,0 +1,29 @@ +# Generated by Django 4.2.15 on 2024-11-12 13:13 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('user_management', '0027_serviceaccount'), + ('auth_token', '0006_googleoauth2token'), + ] + + operations = [ + migrations.CreateModel( + name='ServiceAccountToken', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('token_key', 
models.CharField(db_index=True, max_length=8)), + ('digest', models.CharField(max_length=128)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('revoked_at', models.DateTimeField(null=True)), + ('service_account', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='tokens', to='user_management.serviceaccount')), + ], + options={ + 'unique_together': {('token_key', 'service_account', 'digest')}, + }, + ), + ] diff --git a/engine/apps/auth_token/models/__init__.py b/engine/apps/auth_token/models/__init__.py index 272adbda..42cc60c5 100644 --- a/engine/apps/auth_token/models/__init__.py +++ b/engine/apps/auth_token/models/__init__.py @@ -4,5 +4,6 @@ from .google_oauth2_token import GoogleOAuth2Token # noqa: F401 from .integration_backsync_auth_token import IntegrationBacksyncAuthToken # noqa: F401 from .plugin_auth_token import PluginAuthToken # noqa: F401 from .schedule_export_auth_token import ScheduleExportAuthToken # noqa: F401 +from .service_account_token import ServiceAccountToken # noqa: F401 from .slack_auth_token import SlackAuthToken # noqa: F401 from .user_schedule_export_auth_token import UserScheduleExportAuthToken # noqa: F401 diff --git a/engine/apps/auth_token/models/service_account_token.py b/engine/apps/auth_token/models/service_account_token.py new file mode 100644 index 00000000..716dc55d --- /dev/null +++ b/engine/apps/auth_token/models/service_account_token.py @@ -0,0 +1,110 @@ +import binascii +from hmac import compare_digest + +from django.db import models + +from apps.api.permissions import GrafanaAPIPermissions, LegacyAccessControlRole +from apps.auth_token import constants +from apps.auth_token.crypto import hash_token_string +from apps.auth_token.exceptions import InvalidToken +from apps.auth_token.grafana.grafana_auth_token import ( + get_service_account_details, + get_service_account_token_permissions, +) +from apps.auth_token.models import BaseAuthToken +from apps.user_management.models import 
ServiceAccount, ServiceAccountUser + + +class ServiceAccountTokenManager(models.Manager): + def get_queryset(self): + return super().get_queryset().select_related("service_account__organization") + + +class ServiceAccountToken(BaseAuthToken): + GRAFANA_SA_PREFIX = "glsa_" + + objects = ServiceAccountTokenManager() + + service_account: "ServiceAccount" + service_account = models.ForeignKey(ServiceAccount, on_delete=models.CASCADE, related_name="tokens") + + class Meta: + unique_together = ("token_key", "service_account", "digest") + + @property + def organization(self): + return self.service_account.organization + + @classmethod + def validate_token(cls, organization, token): + # require RBAC enabled to allow service account auth + if not organization.is_rbac_permissions_enabled: + raise InvalidToken + + # Grafana API request: get permissions and confirm token is valid + permissions = get_service_account_token_permissions(organization, token) + if not permissions: + # NOTE: a token can be disabled/re-enabled (not setting as revoked in oncall DB for now) + raise InvalidToken + + # check if we have already seen this token + validated_token = None + service_account = None + prefix_length = len(cls.GRAFANA_SA_PREFIX) + token_key = token[prefix_length : prefix_length + constants.TOKEN_KEY_LENGTH] + try: + hashable_token = binascii.hexlify(token.encode()).decode() + digest = hash_token_string(hashable_token) + except (TypeError, binascii.Error): + raise InvalidToken + for existing_token in cls.objects.filter(service_account__organization=organization, token_key=token_key): + if compare_digest(digest, existing_token.digest): + validated_token = existing_token + service_account = existing_token.service_account + break + + if not validated_token: + # if it didn't match an existing token, create a new one + # make request to Grafana API api/user using token + service_account_data = get_service_account_details(organization, token) + if not service_account_data: + # Grafana 
versions < 11.3 return 403 trying to get user details with service account token + # use some default values + service_account_data = { + "login": "grafana_service_account", + "uid": None, # "service-account:7" + } + + grafana_id = 0 # default to zero for old Grafana versions (to keep service account unique) + if service_account_data["uid"] is not None: + # extract service account Grafana ID + try: + grafana_id = int(service_account_data["uid"].split(":")[-1]) + except ValueError: + pass + + # get or create service account + service_account, _ = ServiceAccount.objects.get_or_create( + organization=organization, + grafana_id=grafana_id, + defaults={ + "login": service_account_data["login"], + }, + ) + # create token + validated_token, _ = cls.objects.get_or_create( + service_account=service_account, + token_key=token_key, + digest=digest, + ) + + user = ServiceAccountUser( + organization=organization, + service_account=service_account, + username=service_account.username, + public_primary_key=service_account.public_primary_key, + role=LegacyAccessControlRole.NONE, + permissions=GrafanaAPIPermissions.construct_permissions(permissions.keys()), + ) + + return user, validated_token diff --git a/engine/apps/auth_token/tests/helpers.py b/engine/apps/auth_token/tests/helpers.py new file mode 100644 index 00000000..bcecce6f --- /dev/null +++ b/engine/apps/auth_token/tests/helpers.py @@ -0,0 +1,18 @@ +import json + +import httpretty + + +def setup_service_account_api_mocks(organization, perms=None, user_data=None, perms_status=200, user_status=200): + # requires enabling httpretty + if perms is None: + perms = {} + mock_response = httpretty.Response(status=perms_status, body=json.dumps(perms)) + perms_url = f"{organization.grafana_url}/api/access-control/user/permissions" + httpretty.register_uri(httpretty.GET, perms_url, responses=[mock_response]) + + if user_data is None: + user_data = {"login": "some-login", "uid": "service-account:42"} + mock_response = 
httpretty.Response(status=user_status, body=json.dumps(user_data)) + user_url = f"{organization.grafana_url}/api/user" + httpretty.register_uri(httpretty.GET, user_url, responses=[mock_response]) diff --git a/engine/apps/auth_token/tests/test_grafana_auth.py b/engine/apps/auth_token/tests/test_grafana_auth.py index 5b78636c..3a8ec56c 100644 --- a/engine/apps/auth_token/tests/test_grafana_auth.py +++ b/engine/apps/auth_token/tests/test_grafana_auth.py @@ -1,11 +1,16 @@ import typing from unittest.mock import patch +import httpretty import pytest from rest_framework import exceptions from rest_framework.test import APIRequestFactory -from apps.auth_token.auth import GRAFANA_SA_PREFIX, X_GRAFANA_INSTANCE_ID, GrafanaServiceAccountAuthentication +from apps.api.permissions import LegacyAccessControlRole +from apps.auth_token.auth import X_GRAFANA_INSTANCE_ID, GrafanaServiceAccountAuthentication +from apps.auth_token.models import ServiceAccountToken +from apps.auth_token.tests.helpers import setup_service_account_api_mocks +from apps.user_management.models import ServiceAccountUser from settings.base import CLOUD_LICENSE_NAME, OPEN_SOURCE_LICENSE_NAME, SELF_HOSTED_SETTINGS @@ -53,7 +58,7 @@ def test_grafana_authentication_cloud_inputs(make_organization, settings): mock.assert_called_once_with(organization, token) -def check_common_inputs() -> (dict[str, typing.Any], str): +def check_common_inputs() -> tuple[dict[str, typing.Any], str]: request = APIRequestFactory().get("/") with pytest.raises(exceptions.AuthenticationFailed): GrafanaServiceAccountAuthentication().authenticate(request) @@ -65,7 +70,7 @@ def check_common_inputs() -> (dict[str, typing.Any], str): result = GrafanaServiceAccountAuthentication().authenticate(request) assert result is None - token = f"{GRAFANA_SA_PREFIX}xyz" + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz" headers = { "HTTP_AUTHORIZATION": token, } @@ -74,3 +79,221 @@ def check_common_inputs() -> (dict[str, typing.Any], str): 
GrafanaServiceAccountAuthentication().authenticate(request) return headers, token + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_missing_org(): + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz" + headers = { + "HTTP_AUTHORIZATION": token, + } + request = APIRequestFactory().get("/", **headers) + + with pytest.raises(exceptions.AuthenticationFailed) as exc: + GrafanaServiceAccountAuthentication().authenticate(request) + assert exc.value.detail == "Invalid organization." + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_invalid_grafana_url(): + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz" + headers = { + "HTTP_AUTHORIZATION": token, + "HTTP_X_GRAFANA_URL": "http://grafana.test", # no org for this URL + } + request = APIRequestFactory().get("/", **headers) + + with pytest.raises(exceptions.AuthenticationFailed) as exc: + GrafanaServiceAccountAuthentication().authenticate(request) + assert exc.value.detail == "Invalid Grafana URL." + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_rbac_disabled_fails(make_organization): + organization = make_organization(grafana_url="http://grafana.test") + if organization.is_rbac_permissions_enabled: + return + + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz" + headers = { + "HTTP_AUTHORIZATION": token, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + with pytest.raises(exceptions.AuthenticationFailed) as exc: + GrafanaServiceAccountAuthentication().authenticate(request) + assert exc.value.detail == "Invalid token." 
+ + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_permissions_call_fails(make_organization): + organization = make_organization(grafana_url="http://grafana.test") + if not organization.is_rbac_permissions_enabled: + return + + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz" + headers = { + "HTTP_AUTHORIZATION": token, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + # setup Grafana API responses + # permissions endpoint returns a 401 + setup_service_account_api_mocks(organization, perms_status=401) + + with pytest.raises(exceptions.AuthenticationFailed) as exc: + GrafanaServiceAccountAuthentication().authenticate(request) + assert exc.value.detail == "Invalid token." + + last_request = httpretty.last_request() + assert last_request.method == "GET" + expected_url = f"{organization.grafana_url}/api/access-control/user/permissions" + assert last_request.url == expected_url + # the request uses the given token + assert last_request.headers["Authorization"] == f"Bearer {token}" + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_existing_token( + make_organization, make_service_account_for_organization, make_token_for_service_account +): + organization = make_organization(grafana_url="http://grafana.test") + if not organization.is_rbac_permissions_enabled: + return + service_account = make_service_account_for_organization(organization) + token_string = "glsa_the-token" + token = make_token_for_service_account(service_account, token_string) + + headers = { + "HTTP_AUTHORIZATION": token_string, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + # setup Grafana API responses + setup_service_account_api_mocks(organization, {"some-perm": "value"}) + + user, auth_token = 
GrafanaServiceAccountAuthentication().authenticate(request) + + assert isinstance(user, ServiceAccountUser) + assert user.service_account == service_account + assert user.public_primary_key == service_account.public_primary_key + assert user.username == service_account.username + assert user.role == LegacyAccessControlRole.NONE + assert auth_token == token + + last_request = httpretty.last_request() + assert last_request.method == "GET" + expected_url = f"{organization.grafana_url}/api/access-control/user/permissions" + assert last_request.url == expected_url + # the request uses the given token + assert last_request.headers["Authorization"] == f"Bearer {token_string}" + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_token_created(make_organization): + organization = make_organization(grafana_url="http://grafana.test") + if not organization.is_rbac_permissions_enabled: + return + token_string = "glsa_the-token" + + headers = { + "HTTP_AUTHORIZATION": token_string, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + # setup Grafana API responses + permissions = {"some-perm": "value"} + user_data = {"login": "some-login", "uid": "service-account:42"} + setup_service_account_api_mocks(organization, permissions, user_data) + + user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request) + + assert isinstance(user, ServiceAccountUser) + service_account = user.service_account + assert service_account.organization == organization + assert user.public_primary_key == service_account.public_primary_key + assert user.username == service_account.username + assert service_account.grafana_id == 42 + assert service_account.login == "some-login" + assert user.role == LegacyAccessControlRole.NONE + assert user.permissions == [{"action": p} for p in permissions] + assert auth_token.service_account == user.service_account + + 
perms_request, user_request = httpretty.latest_requests() + for req in (perms_request, user_request): + assert req.method == "GET" + assert req.headers["Authorization"] == f"Bearer {token_string}" + perms_url = f"{organization.grafana_url}/api/access-control/user/permissions" + assert perms_request.url == perms_url + user_url = f"{organization.grafana_url}/api/user" + assert user_request.url == user_url + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_token_created_older_grafana(make_organization): + organization = make_organization(grafana_url="http://grafana.test") + if not organization.is_rbac_permissions_enabled: + return + token_string = "glsa_the-token" + + headers = { + "HTTP_AUTHORIZATION": token_string, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + # setup Grafana API responses + permissions = {"some-perm": "value"} + # User API fails for older Grafana versions + setup_service_account_api_mocks(organization, permissions, user_status=400) + + user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request) + + assert isinstance(user, ServiceAccountUser) + service_account = user.service_account + assert service_account.organization == organization + # use fallback data + assert service_account.grafana_id == 0 + assert service_account.login == "grafana_service_account" + assert auth_token.service_account == user.service_account + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_token_reuse_service_account(make_organization, make_service_account_for_organization): + organization = make_organization(grafana_url="http://grafana.test") + if not organization.is_rbac_permissions_enabled: + return + service_account = make_service_account_for_organization(organization) + token_string = "glsa_the-token" + + headers = { + "HTTP_AUTHORIZATION": 
token_string, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + # setup Grafana API responses + permissions = {"some-perm": "value"} + user_data = { + "login": service_account.login, + "uid": f"service-account:{service_account.grafana_id}", + } + setup_service_account_api_mocks(organization, permissions, user_data) + + user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request) + + assert isinstance(user, ServiceAccountUser) + assert user.service_account == service_account + assert auth_token.service_account == service_account diff --git a/engine/apps/grafana_plugin/helpers/client.py b/engine/apps/grafana_plugin/helpers/client.py index 2beafa8b..17d1cabd 100644 --- a/engine/apps/grafana_plugin/helpers/client.py +++ b/engine/apps/grafana_plugin/helpers/client.py @@ -315,6 +315,9 @@ class GrafanaAPIClient(APIClient): def get_grafana_irm_plugin_settings(self) -> APIClientResponse["GrafanaAPIClient.Types.PluginSettings"]: return self.get_grafana_plugin_settings(PluginID.IRM) + def get_current_user(self) -> APIClientResponse[typing.Dict[str, typing.List[str]]]: + return self.api_get("api/user") + def get_service_account(self, login: str) -> APIClientResponse["GrafanaAPIClient.Types.ServiceAccountResponse"]: return self.api_get(f"api/serviceaccounts/search?query={login}") diff --git a/engine/apps/public_api/serializers/integrations.py b/engine/apps/public_api/serializers/integrations.py index b16aeb54..0cbf4605 100644 --- a/engine/apps/public_api/serializers/integrations.py +++ b/engine/apps/public_api/serializers/integrations.py @@ -7,6 +7,7 @@ from apps.alerts.grafana_alerting_sync_manager.grafana_alerting_sync import Graf from apps.alerts.models import AlertReceiveChannel from apps.base.messaging import get_messaging_backends from apps.integrations.legacy_prefix import has_legacy_prefix, remove_legacy_prefix +from apps.user_management.models import ServiceAccountUser from 
common.api_helpers.custom_fields import TeamPrimaryKeyRelatedField from common.api_helpers.exceptions import BadRequest from common.api_helpers.mixins import PHONE_CALL, SLACK, SMS, TELEGRAM, WEB, EagerLoadingMixin @@ -123,11 +124,13 @@ class IntegrationSerializer(EagerLoadingMixin, serializers.ModelSerializer, Main connection_error = GrafanaAlertingSyncManager.check_for_connection_errors(organization) if connection_error: raise serializers.ValidationError(connection_error) + user = self.context["request"].user with transaction.atomic(): try: instance = AlertReceiveChannel.create( **validated_data, - author=self.context["request"].user, + author=user if not isinstance(user, ServiceAccountUser) else None, + service_account=user.service_account if isinstance(user, ServiceAccountUser) else None, organization=organization, ) except AlertReceiveChannel.DuplicateDirectPagingError: diff --git a/engine/apps/public_api/tests/test_alert_groups.py b/engine/apps/public_api/tests/test_alert_groups.py index 71421cd3..e3cc872e 100644 --- a/engine/apps/public_api/tests/test_alert_groups.py +++ b/engine/apps/public_api/tests/test_alert_groups.py @@ -1,5 +1,6 @@ from unittest.mock import patch +import httpretty import pytest from django.urls import reverse from django.utils import timezone @@ -9,6 +10,8 @@ from rest_framework.test import APIClient from apps.alerts.constants import ActionSource from apps.alerts.models import AlertGroup, AlertReceiveChannel from apps.alerts.tasks import delete_alert_group, wipe +from apps.api import permissions +from apps.auth_token.tests.helpers import setup_service_account_api_mocks def construct_expected_response_from_alert_groups(alert_groups): @@ -736,3 +739,34 @@ def test_alert_group_unsilence( assert alert_group.silenced == silenced assert response.status_code == status_code assert response_msg == response.json()["detail"] + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def 
test_actions_disabled_for_service_accounts( + make_organization, + make_service_account_for_organization, + make_token_for_service_account, + make_escalation_chain, +): + organization = make_organization(grafana_url="http://grafana.test") + service_account = make_service_account_for_organization(organization) + token_string = "glsa_token" + make_token_for_service_account(service_account, token_string) + make_escalation_chain(organization) + + perms = { + permissions.RBACPermission.Permissions.ALERT_GROUPS_WRITE.value: ["*"], + } + setup_service_account_api_mocks(organization, perms=perms) + + client = APIClient() + disabled_actions = ["acknowledge", "unacknowledge", "resolve", "unresolve", "silence", "unsilence"] + for action in disabled_actions: + url = reverse(f"api-public:alert_groups-{action}", kwargs={"pk": "ABCDEFG"}) + response = client.post( + url, + HTTP_AUTHORIZATION=f"{token_string}", + HTTP_X_GRAFANA_URL=organization.grafana_url, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN diff --git a/engine/apps/public_api/tests/test_integrations.py b/engine/apps/public_api/tests/test_integrations.py index b021df33..796942eb 100644 --- a/engine/apps/public_api/tests/test_integrations.py +++ b/engine/apps/public_api/tests/test_integrations.py @@ -1,9 +1,12 @@ +import httpretty import pytest from django.urls import reverse from rest_framework import status from rest_framework.test import APIClient from apps.alerts.models import AlertReceiveChannel +from apps.api import permissions +from apps.auth_token.tests.helpers import setup_service_account_api_mocks from apps.base.tests.messaging_backend import TestOnlyBackend TEST_MESSAGING_BACKEND_FIELD = TestOnlyBackend.backend_id.lower() @@ -104,6 +107,47 @@ def test_create_integration( assert response.status_code == status.HTTP_201_CREATED +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_create_integration_via_service_account( + make_organization, + 
make_service_account_for_organization, + make_token_for_service_account, + make_escalation_chain, +): + organization = make_organization(grafana_url="http://grafana.test") + service_account = make_service_account_for_organization(organization) + token_string = "glsa_token" + make_token_for_service_account(service_account, token_string) + make_escalation_chain(organization) + + perms = { + permissions.RBACPermission.Permissions.INTEGRATIONS_WRITE.value: ["*"], + } + setup_service_account_api_mocks(organization, perms) + + client = APIClient() + data_for_create = { + "type": "grafana", + "name": "grafana_created", + "team_id": None, + } + url = reverse("api-public:integrations-list") + response = client.post( + url, + data=data_for_create, + format="json", + HTTP_AUTHORIZATION=f"{token_string}", + HTTP_X_GRAFANA_URL=organization.grafana_url, + ) + if not organization.is_rbac_permissions_enabled: + assert response.status_code == status.HTTP_403_FORBIDDEN + else: + assert response.status_code == status.HTTP_201_CREATED + integration = AlertReceiveChannel.objects.get(public_primary_key=response.data["id"]) + assert integration.service_account == service_account + + @pytest.mark.django_db def test_integration_name_uniqueness( make_organization_and_user_with_token, diff --git a/engine/apps/public_api/tests/test_rbac_permissions.py b/engine/apps/public_api/tests/test_rbac_permissions.py index 9829550d..95154ab4 100644 --- a/engine/apps/public_api/tests/test_rbac_permissions.py +++ b/engine/apps/public_api/tests/test_rbac_permissions.py @@ -1,5 +1,7 @@ +import json from unittest.mock import patch +import httpretty import pytest from django.urls import reverse from rest_framework import status @@ -9,6 +11,13 @@ from rest_framework.test import APIClient from apps.api.permissions import GrafanaAPIPermission, LegacyAccessControlRole, get_most_authorized_role from apps.public_api.urls import router +VIEWS_REQUIRING_USER_AUTH = ( + "EscalationView", + "PersonalNotificationView", 
+ "MakeCallView", + "SendSMSView", +) + @pytest.mark.parametrize( "rbac_enabled,role,give_perm", @@ -96,3 +105,98 @@ def test_rbac_permissions( with patch(method_path, return_value=success): response = client.generic(path=url, method=http_method, HTTP_AUTHORIZATION=token) assert response.status_code == expected + + +@pytest.mark.parametrize( + "rbac_enabled,role,give_perm", + [ + # rbac disabled: auth is disabled + (False, LegacyAccessControlRole.ADMIN, None), + # rbac enabled: having role None, check the perm is required + (True, LegacyAccessControlRole.NONE, False), + (True, LegacyAccessControlRole.NONE, True), + ], +) +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_service_account_auth( + make_organization, + make_service_account_for_organization, + make_token_for_service_account, + rbac_enabled, + role, + give_perm, +): + # APIView default actions + # (name, http method, detail-based) + default_actions = { + "create": ("post", False), + "list": ("get", False), + "retrieve": ("get", True), + "update": ("put", True), + "partial_update": ("patch", True), + "destroy": ("delete", True), + } + + organization = make_organization(grafana_url="http://grafana.test") + service_account = make_service_account_for_organization(organization) + token_string = "glsa_token" + make_token_for_service_account(service_account, token_string) + + if organization.is_rbac_permissions_enabled != rbac_enabled: + # skip if the organization's rbac_enabled is not the expected by the test + return + + client = APIClient() + # check all actions for all public API viewsets + for _, viewset, _basename in router.registry: + if viewset.__name__ == "ActionView": + # old actions (webhooks) are deprecated, no RBAC or service account support + continue + for viewset_method_name, required_perms in viewset.rbac_permissions.items(): + # setup Grafana API permissions response + if rbac_enabled: + permissions = {"perm": "value"} + expected = 
status.HTTP_403_FORBIDDEN + if give_perm: + permissions = {perm.value: "value" for perm in required_perms} + expected = status.HTTP_200_OK + mock_response = httpretty.Response(status=200, body=json.dumps(permissions)) + perms_url = f"{organization.grafana_url}/api/access-control/user/permissions" + httpretty.register_uri(httpretty.GET, perms_url, responses=[mock_response]) + else: + # service account auth is disabled + expected = status.HTTP_403_FORBIDDEN + + # iterate over all viewset actions, making an API request for each, + # using the user's token and confirming the response status code + if viewset_method_name in default_actions: + http_method, detail = default_actions[viewset_method_name] + else: + action_method = getattr(viewset, viewset_method_name) + http_method = list(action_method.mapping.keys())[0] + detail = action_method.detail + + method_path = f"{viewset.__module__}.{viewset.__name__}.{viewset_method_name}" + success = Response(status=status.HTTP_200_OK) + kwargs = {"pk": "NONEXISTENT"} if detail else None + if viewset_method_name in default_actions and detail: + url = reverse(f"api-public:{_basename}-detail", kwargs=kwargs) + elif viewset_method_name in default_actions and not detail: + url = reverse(f"api-public:{_basename}-list", kwargs=kwargs) + else: + name = viewset_method_name.replace("_", "-") + url = reverse(f"api-public:{_basename}-{name}", kwargs=kwargs) + + with patch(method_path, return_value=success): + headers = { + "HTTP_AUTHORIZATION": token_string, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + response = client.generic(path=url, method=http_method, **headers) + assert ( + response.status_code == expected + if viewset.__name__ not in VIEWS_REQUIRING_USER_AUTH + # user-specific APIs do not support service account auth + else status.HTTP_403_FORBIDDEN + ) diff --git a/engine/apps/public_api/tests/test_resolution_notes.py b/engine/apps/public_api/tests/test_resolution_notes.py index c3a89a1d..7a730e18 100644 --- 
a/engine/apps/public_api/tests/test_resolution_notes.py +++ b/engine/apps/public_api/tests/test_resolution_notes.py @@ -6,8 +6,8 @@ from rest_framework import status from rest_framework.test import APIClient from apps.alerts.models import ResolutionNote -from apps.auth_token.auth import GRAFANA_SA_PREFIX, ApiTokenAuthentication, GrafanaServiceAccountAuthentication -from apps.auth_token.models import ApiAuthToken +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication +from apps.auth_token.models import ApiAuthToken, ServiceAccountToken @pytest.mark.django_db @@ -366,7 +366,7 @@ def test_create_resolution_note_grafana_auth(make_organization_and_user, make_al mock_api_key_auth.assert_called_once() assert response.status_code == status.HTTP_403_FORBIDDEN - token = f"{GRAFANA_SA_PREFIX}123" + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}123" # GrafanaServiceAccountAuthentication handle invalid token with patch( "apps.auth_token.auth.ApiTokenAuthentication.authenticate", wraps=api_token_auth.authenticate diff --git a/engine/apps/public_api/views/alert_groups.py b/engine/apps/public_api/views/alert_groups.py index d4f4a302..fc5d01d0 100644 --- a/engine/apps/public_api/views/alert_groups.py +++ b/engine/apps/public_api/views/alert_groups.py @@ -12,12 +12,13 @@ from apps.alerts.models import AlertGroup, AlertReceiveChannel from apps.alerts.tasks import delete_alert_group, wipe from apps.api.label_filtering import parse_label_query from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.constants import VALID_DATE_FOR_DELETE_INCIDENT from apps.public_api.helpers import is_valid_group_creation_date, team_has_slack_token_for_deleting from apps.public_api.serializers import AlertGroupSerializer from apps.public_api.throttlers.user_throttle import UserThrottle -from 
common.api_helpers.exceptions import BadRequest +from apps.user_management.models import ServiceAccountUser +from common.api_helpers.exceptions import BadRequest, Forbidden from common.api_helpers.filters import ( NO_TEAM_VALUE, ByTeamModelFieldFilterMixin, @@ -57,7 +58,7 @@ class AlertGroupView( mixins.DestroyModelMixin, GenericViewSet, ): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { @@ -170,6 +171,9 @@ class AlertGroupView( @action(methods=["post"], detail=True) def acknowledge(self, request, pk): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to acknowledge alert groups") + alert_group = self.get_object() if alert_group.acknowledged: @@ -189,6 +193,9 @@ class AlertGroupView( @action(methods=["post"], detail=True) def unacknowledge(self, request, pk): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to unacknowledge alert groups") + alert_group = self.get_object() if not alert_group.acknowledged: @@ -208,6 +215,9 @@ class AlertGroupView( @action(methods=["post"], detail=True) def resolve(self, request, pk): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to resolve alert groups") + alert_group = self.get_object() if alert_group.resolved: @@ -225,6 +235,9 @@ class AlertGroupView( @action(methods=["post"], detail=True) def unresolve(self, request, pk): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to unresolve alert groups") + alert_group = self.get_object() if not alert_group.resolved: @@ -241,6 +254,9 @@ class AlertGroupView( @action(methods=["post"], detail=True) def silence(self, request, pk=None): + if isinstance(request.user, 
ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to silence alert groups") + alert_group = self.get_object() delay = request.data.get("delay") @@ -267,6 +283,9 @@ class AlertGroupView( @action(methods=["post"], detail=True) def unsilence(self, request, pk=None): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to unsilence alert groups") + alert_group = self.get_object() if not alert_group.silenced: diff --git a/engine/apps/public_api/views/alerts.py b/engine/apps/public_api/views/alerts.py index b96d51c5..0f3d1d46 100644 --- a/engine/apps/public_api/views/alerts.py +++ b/engine/apps/public_api/views/alerts.py @@ -7,7 +7,7 @@ from rest_framework.viewsets import GenericViewSet from apps.alerts.models import Alert from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers.alerts import AlertSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from common.api_helpers.mixins import RateLimitHeadersMixin @@ -19,7 +19,7 @@ class AlertFilter(filters.FilterSet): class AlertView(RateLimitHeadersMixin, mixins.ListModelMixin, GenericViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/escalation_chains.py b/engine/apps/public_api/views/escalation_chains.py index 84bb7162..52a1cc44 100644 --- a/engine/apps/public_api/views/escalation_chains.py +++ b/engine/apps/public_api/views/escalation_chains.py @@ -5,7 +5,7 @@ from rest_framework.viewsets import ModelViewSet from apps.alerts.models import EscalationChain from apps.api.permissions import RBACPermission -from 
apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import EscalationChainSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from common.api_helpers.filters import ByTeamFilter @@ -15,7 +15,7 @@ from common.insight_log import EntityEvent, write_resource_insight_log class EscalationChainView(RateLimitHeadersMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/escalation_policies.py b/engine/apps/public_api/views/escalation_policies.py index ddbaeae8..e91e52f4 100644 --- a/engine/apps/public_api/views/escalation_policies.py +++ b/engine/apps/public_api/views/escalation_policies.py @@ -5,7 +5,7 @@ from rest_framework.viewsets import ModelViewSet from apps.alerts.models import EscalationPolicy from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import EscalationPolicySerializer, EscalationPolicyUpdateSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from common.api_helpers.mixins import RateLimitHeadersMixin, UpdateSerializerMixin @@ -14,7 +14,7 @@ from common.insight_log import EntityEvent, write_resource_insight_log class EscalationPolicyView(RateLimitHeadersMixin, UpdateSerializerMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/integrations.py 
b/engine/apps/public_api/views/integrations.py index 26c55224..e8ec9a85 100644 --- a/engine/apps/public_api/views/integrations.py +++ b/engine/apps/public_api/views/integrations.py @@ -5,7 +5,7 @@ from rest_framework.viewsets import ModelViewSet from apps.alerts.models import AlertReceiveChannel from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import IntegrationSerializer, IntegrationUpdateSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from common.api_helpers.exceptions import BadRequest @@ -24,7 +24,7 @@ class IntegrationView( MaintainableObjectMixin, ModelViewSet, ): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/on_call_shifts.py b/engine/apps/public_api/views/on_call_shifts.py index e825ea35..2e091e94 100644 --- a/engine/apps/public_api/views/on_call_shifts.py +++ b/engine/apps/public_api/views/on_call_shifts.py @@ -5,7 +5,7 @@ from rest_framework.permissions import IsAuthenticated from rest_framework.viewsets import ModelViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import CustomOnCallShiftSerializer, CustomOnCallShiftUpdateSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from apps.schedules.models import CustomOnCallShift @@ -16,7 +16,7 @@ from common.insight_log import EntityEvent, write_resource_insight_log class CustomOnCallShiftView(RateLimitHeadersMixin, UpdateSerializerMixin, ModelViewSet): - authentication_classes = 
(ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/organizations.py b/engine/apps/public_api/views/organizations.py index 1df2f63a..473d79de 100644 --- a/engine/apps/public_api/views/organizations.py +++ b/engine/apps/public_api/views/organizations.py @@ -3,7 +3,7 @@ from rest_framework.settings import api_settings from rest_framework.viewsets import ReadOnlyModelViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import OrganizationSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from apps.user_management.models import Organization @@ -15,7 +15,7 @@ class OrganizationView( RateLimitHeadersMixin, ReadOnlyModelViewSet, ): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/routes.py b/engine/apps/public_api/views/routes.py index 79461527..19ddc105 100644 --- a/engine/apps/public_api/views/routes.py +++ b/engine/apps/public_api/views/routes.py @@ -7,7 +7,7 @@ from rest_framework.viewsets import ModelViewSet from apps.alerts.models import ChannelFilter from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import ChannelFilterSerializer, ChannelFilterUpdateSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from common.api_helpers.exceptions import BadRequest @@ -17,7 +17,7 
@@ from common.insight_log import EntityEvent, write_resource_insight_log class ChannelFilterView(RateLimitHeadersMixin, UpdateSerializerMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/schedules.py b/engine/apps/public_api/views/schedules.py index 6dcca6fd..5960ad48 100644 --- a/engine/apps/public_api/views/schedules.py +++ b/engine/apps/public_api/views/schedules.py @@ -9,7 +9,11 @@ from rest_framework.views import Response from rest_framework.viewsets import ModelViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication, ScheduleExportAuthentication +from apps.auth_token.auth import ( + ApiTokenAuthentication, + GrafanaServiceAccountAuthentication, + ScheduleExportAuthentication, +) from apps.public_api.custom_renderers import CalendarRenderer from apps.public_api.serializers import PolymorphicScheduleSerializer, PolymorphicScheduleUpdateSerializer from apps.public_api.serializers.schedules_base import FinalShiftQueryParamsSerializer @@ -28,7 +32,7 @@ logger = logging.getLogger(__name__) class OnCallScheduleChannelView(RateLimitHeadersMixin, UpdateSerializerMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/shift_swap.py b/engine/apps/public_api/views/shift_swap.py index 07f978e5..c46c1419 100644 --- a/engine/apps/public_api/views/shift_swap.py +++ b/engine/apps/public_api/views/shift_swap.py @@ -10,7 +10,7 @@ from rest_framework.serializers import BaseSerializer from apps.api.permissions import AuthenticatedRequest, RBACPermission from 
apps.api.views.shift_swap import BaseShiftSwapViewSet -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.throttlers.user_throttle import UserThrottle from apps.schedules.models import ShiftSwapRequest from apps.user_management.models import User @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) class ShiftSwapViewSet(RateLimitHeadersMixin, BaseShiftSwapViewSet): # set authentication and permission classes - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/slack_channels.py b/engine/apps/public_api/views/slack_channels.py index 77581f3d..35f38402 100644 --- a/engine/apps/public_api/views/slack_channels.py +++ b/engine/apps/public_api/views/slack_channels.py @@ -3,7 +3,7 @@ from rest_framework.permissions import IsAuthenticated from rest_framework.viewsets import GenericViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers.slack_channel import SlackChannelSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from apps.slack.models import SlackChannel @@ -12,7 +12,7 @@ from common.api_helpers.paginators import FiftyPageSizePaginator class SlackChannelView(RateLimitHeadersMixin, mixins.ListModelMixin, GenericViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/teams.py b/engine/apps/public_api/views/teams.py index 
490e74ef..6d399bad 100644 --- a/engine/apps/public_api/views/teams.py +++ b/engine/apps/public_api/views/teams.py @@ -3,7 +3,7 @@ from rest_framework.mixins import ListModelMixin, RetrieveModelMixin from rest_framework.permissions import IsAuthenticated from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers.teams import TeamSerializer from apps.public_api.tf_sync import is_request_from_terraform, sync_teams_on_tf_request from apps.public_api.throttlers.user_throttle import UserThrottle @@ -14,7 +14,7 @@ from common.api_helpers.paginators import FiftyPageSizePaginator class TeamView(PublicPrimaryKeyMixin, RetrieveModelMixin, ListModelMixin, viewsets.GenericViewSet): serializer_class = TeamSerializer - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/user_groups.py b/engine/apps/public_api/views/user_groups.py index ced7f626..bb1dac7f 100644 --- a/engine/apps/public_api/views/user_groups.py +++ b/engine/apps/public_api/views/user_groups.py @@ -3,7 +3,7 @@ from rest_framework.permissions import IsAuthenticated from rest_framework.viewsets import GenericViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers.user_groups import UserGroupSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from apps.slack.models import SlackUserGroup @@ -12,7 +12,7 @@ from common.api_helpers.paginators import FiftyPageSizePaginator class UserGroupView(RateLimitHeadersMixin, mixins.ListModelMixin, 
GenericViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/users.py b/engine/apps/public_api/views/users.py index 97315fe2..129096e5 100644 --- a/engine/apps/public_api/views/users.py +++ b/engine/apps/public_api/views/users.py @@ -6,7 +6,11 @@ from rest_framework.views import Response from rest_framework.viewsets import ReadOnlyModelViewSet from apps.api.permissions import LegacyAccessControlRole, RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication, UserScheduleExportAuthentication +from apps.auth_token.auth import ( + ApiTokenAuthentication, + GrafanaServiceAccountAuthentication, + UserScheduleExportAuthentication, +) from apps.public_api.custom_renderers import CalendarRenderer from apps.public_api.serializers import FastUserSerializer, UserSerializer from apps.public_api.tf_sync import is_request_from_terraform, sync_users_on_tf_request @@ -35,7 +39,7 @@ class UserFilter(filters.FilterSet): class UserView(RateLimitHeadersMixin, ShortSerializerMixin, ReadOnlyModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/webhooks.py b/engine/apps/public_api/views/webhooks.py index 8f75148b..b1a6a47b 100644 --- a/engine/apps/public_api/views/webhooks.py +++ b/engine/apps/public_api/views/webhooks.py @@ -6,7 +6,7 @@ from rest_framework.response import Response from rest_framework.viewsets import ModelViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from 
apps.public_api.serializers.webhooks import ( WebhookCreateSerializer, WebhookResponseSerializer, @@ -21,7 +21,7 @@ from common.insight_log import EntityEvent, write_resource_insight_log class WebhooksView(RateLimitHeadersMixin, UpdateSerializerMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/user_management/migrations/0027_serviceaccount.py b/engine/apps/user_management/migrations/0027_serviceaccount.py new file mode 100644 index 00000000..dc9e520b --- /dev/null +++ b/engine/apps/user_management/migrations/0027_serviceaccount.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2.15 on 2024-11-12 13:13 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('user_management', '0026_auto_20241017_1919'), + ] + + operations = [ + migrations.CreateModel( + name='ServiceAccount', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('grafana_id', models.PositiveIntegerField()), + ('login', models.CharField(max_length=300)), + ('organization', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='service_accounts', to='user_management.organization')), + ], + options={ + 'unique_together': {('grafana_id', 'organization')}, + }, + ), + ] diff --git a/engine/apps/user_management/models/__init__.py b/engine/apps/user_management/models/__init__.py index e2bcd4c7..2fd5a9aa 100644 --- a/engine/apps/user_management/models/__init__.py +++ b/engine/apps/user_management/models/__init__.py @@ -1,4 +1,5 @@ from .user import User # noqa: F401, isort: skip from .organization import Organization # noqa: F401 from .region import Region # noqa: F401 +from .service_account import ServiceAccount, 
ServiceAccountUser # noqa: F401 from .team import Team # noqa: F401 diff --git a/engine/apps/user_management/models/service_account.py b/engine/apps/user_management/models/service_account.py new file mode 100644 index 00000000..5082f7b9 --- /dev/null +++ b/engine/apps/user_management/models/service_account.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from typing import List + +from django.db import models + +from apps.user_management.models import Organization + + +@dataclass +class ServiceAccountUser: + """Authenticated service account in public API requests.""" + + service_account: "ServiceAccount" + organization: "Organization" # required for insight logs interface + username: str # required for insight logs interface + public_primary_key: str # required for insight logs interface + role: str # required for permissions check + permissions: List[str] # required for permissions check + + @property + def id(self): + return self.service_account.id + + @property + def pk(self): + return self.service_account.id + + @property + def organization_id(self): + return self.organization.id + + @property + def is_authenticated(self): + return True + + +class ServiceAccount(models.Model): + organization: "Organization" + + grafana_id = models.PositiveIntegerField() + organization = models.ForeignKey(Organization, on_delete=models.CASCADE, related_name="service_accounts") + login = models.CharField(max_length=300) + + class Meta: + unique_together = ("grafana_id", "organization") + + @property + def username(self): + # required for insight logs interface + return self.login + + @property + def public_primary_key(self): + # required for insight logs interface + return f"service-account:{self.grafana_id}" diff --git a/engine/apps/user_management/tests/factories.py b/engine/apps/user_management/tests/factories.py index ccfbb858..a33aefac 100644 --- a/engine/apps/user_management/tests/factories.py +++ b/engine/apps/user_management/tests/factories.py @@ -1,6 +1,6 @@ 
import factory -from apps.user_management.models import Organization, Region, Team, User +from apps.user_management.models import Organization, Region, ServiceAccount, Team, User from common.utils import UniqueFaker @@ -41,3 +41,11 @@ class RegionFactory(factory.DjangoModelFactory): class Meta: model = Region + + +class ServiceAccountFactory(factory.DjangoModelFactory): + grafana_id = UniqueFaker("pyint") + login = UniqueFaker("user_name") + + class Meta: + model = ServiceAccount diff --git a/engine/conftest.py b/engine/conftest.py index a95383dd..0b66e3ad 100644 --- a/engine/conftest.py +++ b/engine/conftest.py @@ -1,3 +1,4 @@ +import binascii import datetime import json import os @@ -46,11 +47,14 @@ from apps.api.permissions import ( LegacyAccessControlRole, RBACPermission, ) +from apps.auth_token import constants as auth_token_constants +from apps.auth_token.crypto import hash_token_string from apps.auth_token.models import ( ApiAuthToken, GoogleOAuth2Token, IntegrationBacksyncAuthToken, PluginAuthToken, + ServiceAccountToken, SlackAuthToken, ) from apps.base.models.user_notification_policy_log_record import ( @@ -102,7 +106,13 @@ from apps.telegram.tests.factories import ( TelegramVerificationCodeFactory, ) from apps.user_management.models.user import User, listen_for_user_model_save -from apps.user_management.tests.factories import OrganizationFactory, RegionFactory, TeamFactory, UserFactory +from apps.user_management.tests.factories import ( + OrganizationFactory, + RegionFactory, + ServiceAccountFactory, + TeamFactory, + UserFactory, +) from apps.webhooks.presets.preset_options import WebhookPresetOptions from apps.webhooks.tests.factories import CustomWebhookFactory, WebhookResponseFactory from apps.webhooks.tests.test_webhook_presets import ( @@ -252,6 +262,30 @@ def make_user_for_organization(make_user): return _make_user_for_organization +@pytest.fixture +def make_service_account_for_organization(make_user): + def 
_make_service_account_for_organization(organization, **kwargs): + return ServiceAccountFactory(organization=organization, **kwargs) + + return _make_service_account_for_organization + + +@pytest.fixture +def make_token_for_service_account(): + def _make_token_for_service_account(service_account, token_string): + prefix_length = len(ServiceAccountToken.GRAFANA_SA_PREFIX) + token_key = token_string[prefix_length : prefix_length + auth_token_constants.TOKEN_KEY_LENGTH] + hashable_token = binascii.hexlify(token_string.encode()).decode() + digest = hash_token_string(hashable_token) + return ServiceAccountToken.objects.create( + service_account=service_account, + token_key=token_key, + digest=digest, + ) + + return _make_token_for_service_account + + @pytest.fixture def make_token_for_organization(): def _make_token_for_organization(organization): diff --git a/engine/engine/middlewares.py b/engine/engine/middlewares.py index c3da3c4c..0173323b 100644 --- a/engine/engine/middlewares.py +++ b/engine/engine/middlewares.py @@ -28,9 +28,13 @@ class RequestTimeLoggingMiddleware(MiddlewareMixin): ) if hasattr(request, "user") and request.user and request.user.id and hasattr(request.user, "organization"): user_id = request.user.id + if hasattr(request.user, "service_account"): + message += f"service_account_id={user_id} " + else: + message += f"user_id={user_id} " org_id = request.user.organization.id org_slug = request.user.organization.org_slug - message += f"user_id={user_id} org_id={org_id} org_slug={org_slug} " + message += f"org_id={org_id} org_slug={org_slug} " if request.path.startswith("/integrations/v1"): split_path = request.path.split("/") integration_type = split_path[3] From 1bd30b3cf8e3bd6804da7ae323c73e85781aa8b4 Mon Sep 17 00:00:00 2001 From: Joey Orlando Date: Tue, 19 Nov 2024 14:23:48 -0500 Subject: [PATCH 09/12] chore: remove deprecated `AlertGroupPostMortem` model + recently refactored/deprecated slack channel related columns (#5240) # What this PR does - 
`AlertGroupPostMortem` has no references in the codebase.. I stumbled across it while working on https://github.com/grafana/oncall/pull/5224 and decided to just remove it - Removing old Slack channel related `VARCHAR` columns; these were refactored to foreign key references to `slack_slackchannel` table in following PRs: - https://github.com/grafana/oncall/pull/5224 - https://github.com/grafana/oncall/pull/5199 - https://github.com/grafana/oncall/pull/5191 ## Checklist - [x] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes. --- .../migrations/0001_squashed_initial.py | 2 +- ...hannelfilter__slack_channel_id_and_more.py | 26 ++++++++++ engine/apps/alerts/models/channel_filter.py | 4 -- engine/apps/alerts/models/resolution_note.py | 50 +++---------------- .../0020_remove_oncallschedule_channel.py | 19 +++++++ .../apps/schedules/models/on_call_schedule.py | 2 - ...ove_organization_general_log_channel_id.py | 19 +++++++ .../user_management/models/organization.py | 3 -- 8 files changed, 73 insertions(+), 52 deletions(-) create mode 100644 engine/apps/alerts/migrations/0066_remove_channelfilter__slack_channel_id_and_more.py create mode 100644 engine/apps/schedules/migrations/0020_remove_oncallschedule_channel.py create mode 100644 engine/apps/user_management/migrations/0028_remove_organization_general_log_channel_id.py diff --git a/engine/apps/alerts/migrations/0001_squashed_initial.py b/engine/apps/alerts/migrations/0001_squashed_initial.py index 0c96d7d4..8426d263 100644 --- a/engine/apps/alerts/migrations/0001_squashed_initial.py +++ b/engine/apps/alerts/migrations/0001_squashed_initial.py @@ -119,7 +119,7 @@ class Migration(migrations.Migration): name='AlertGroupPostmortem', fields=[ ('id', 
models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('public_primary_key', models.CharField(default=apps.alerts.models.resolution_note.generate_public_primary_key_for_alert_group_postmortem, max_length=20, unique=True, validators=[django.core.validators.MinLengthValidator(13)])), + ('public_primary_key', models.CharField(max_length=20, unique=True, validators=[django.core.validators.MinLengthValidator(13)])), ('created_at', models.DateTimeField(auto_now_add=True)), ('last_modified', models.DateTimeField(auto_now=True)), ('text', models.TextField(default=None, max_length=3000, null=True)), diff --git a/engine/apps/alerts/migrations/0066_remove_channelfilter__slack_channel_id_and_more.py b/engine/apps/alerts/migrations/0066_remove_channelfilter__slack_channel_id_and_more.py new file mode 100644 index 00000000..03c5f534 --- /dev/null +++ b/engine/apps/alerts/migrations/0066_remove_channelfilter__slack_channel_id_and_more.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2.16 on 2024-11-06 21:11 + +from django.db import migrations +import django_migration_linter as linter + + +class Migration(migrations.Migration): + + dependencies = [ + ('alerts', '0065_alertreceivechannel_service_account'), + ] + + operations = [ + linter.IgnoreMigration(), + migrations.RemoveField( + model_name='channelfilter', + name='_slack_channel_id', + ), + migrations.RemoveField( + model_name='resolutionnoteslackmessage', + name='_slack_channel_id', + ), + migrations.DeleteModel( + name='AlertGroupPostmortem', + ), + ] diff --git a/engine/apps/alerts/models/channel_filter.py b/engine/apps/alerts/models/channel_filter.py index f7cb302f..3ea2ea8b 100644 --- a/engine/apps/alerts/models/channel_filter.py +++ b/engine/apps/alerts/models/channel_filter.py @@ -69,9 +69,6 @@ class ChannelFilter(OrderedModel): notify_in_slack = models.BooleanField(null=True, default=True) notify_in_telegram = models.BooleanField(null=True, default=False) - - # TODO: remove 
_slack_channel_id in future release - _slack_channel_id = models.CharField(max_length=100, null=True, default=None) slack_channel = models.ForeignKey( "slack.SlackChannel", null=True, @@ -79,7 +76,6 @@ class ChannelFilter(OrderedModel): on_delete=models.SET_NULL, related_name="+", ) - telegram_channel = models.ForeignKey( "telegram.TelegramToOrganizationConnector", on_delete=models.SET_NULL, diff --git a/engine/apps/alerts/models/resolution_note.py b/engine/apps/alerts/models/resolution_note.py index e2f3586a..90e65166 100644 --- a/engine/apps/alerts/models/resolution_note.py +++ b/engine/apps/alerts/models/resolution_note.py @@ -14,20 +14,7 @@ from common.utils import clean_markup if typing.TYPE_CHECKING: from apps.alerts.models import AlertGroup from apps.slack.models import SlackChannel - - -def generate_public_primary_key_for_alert_group_postmortem(): - prefix = "P" - new_public_primary_key = generate_public_primary_key(prefix) - - failure_counter = 0 - while AlertGroupPostmortem.objects.filter(public_primary_key=new_public_primary_key).exists(): - new_public_primary_key = increase_public_primary_key_length( - failure_counter=failure_counter, prefix=prefix, model_name="AlertGroupPostmortem" - ) - failure_counter += 1 - - return new_public_primary_key + from apps.user_management.models import User def generate_public_primary_key_for_resolution_note(): @@ -75,9 +62,6 @@ class ResolutionNoteSlackMessage(models.Model): related_name="added_resolution_note_slack_messages", ) text = models.TextField(max_length=3000, default=None, null=True) - - # TODO: remove _slack_channel_id in future release - _slack_channel_id = models.CharField(max_length=100, null=True, default=None) slack_channel = models.ForeignKey( "slack.SlackChannel", null=True, @@ -85,7 +69,6 @@ class ResolutionNoteSlackMessage(models.Model): on_delete=models.SET_NULL, related_name="+", ) - ts = models.CharField(max_length=100, null=True, default=None) thread_ts = models.CharField(max_length=100, 
null=True, default=None) permalink = models.CharField(max_length=250, null=True, default=None) @@ -130,6 +113,7 @@ class ResolutionNoteQueryset(models.QuerySet): class ResolutionNote(models.Model): alert_group: "AlertGroup" + author: typing.Optional["User"] resolution_note_slack_message: typing.Optional[ResolutionNoteSlackMessage] objects = ResolutionNoteQueryset.as_manager() @@ -213,29 +197,11 @@ class ResolutionNote(models.Model): return result - def author_verbal(self, mention): + def author_verbal(self, mention: bool) -> str: """ - Postmortems to resolution notes included migrating AlertGroupPostmortem to ResolutionNotes. - But AlertGroupPostmortem has no author field. So this method was introduces as workaround. + Postmortems to resolution notes included migrating `AlertGroupPostmortem` to `ResolutionNote`s. + But `AlertGroupPostmortem` has no author field. So this method was introduced as a workaround. + + (see git history for more details on what `AlertGroupPostmortem` was) """ - if self.author is not None: - return self.author.get_username_with_slack_verbal(mention) - else: - return "" - - -class AlertGroupPostmortem(models.Model): - public_primary_key = models.CharField( - max_length=20, - validators=[MinLengthValidator(settings.PUBLIC_PRIMARY_KEY_MIN_LENGTH + 1)], - unique=True, - default=generate_public_primary_key_for_alert_group_postmortem, - ) - alert_group = models.ForeignKey( - "alerts.AlertGroup", - on_delete=models.CASCADE, - related_name="postmortem_text", - ) - created_at = models.DateTimeField(auto_now_add=True) - last_modified = models.DateTimeField(auto_now=True) - text = models.TextField(max_length=3000, default=None, null=True) + return "" if self.author is None else self.author.get_username_with_slack_verbal(mention) diff --git a/engine/apps/schedules/migrations/0020_remove_oncallschedule_channel.py b/engine/apps/schedules/migrations/0020_remove_oncallschedule_channel.py new file mode 100644 index 00000000..e4d19138 --- /dev/null +++ 
b/engine/apps/schedules/migrations/0020_remove_oncallschedule_channel.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.16 on 2024-11-06 21:13 + +from django.db import migrations +import django_migration_linter as linter + + +class Migration(migrations.Migration): + + dependencies = [ + ('schedules', '0019_auto_20241021_1735'), + ] + + operations = [ + linter.IgnoreMigration(), + migrations.RemoveField( + model_name='oncallschedule', + name='channel', + ), + ] diff --git a/engine/apps/schedules/models/on_call_schedule.py b/engine/apps/schedules/models/on_call_schedule.py index 544ec847..e57cf4bc 100644 --- a/engine/apps/schedules/models/on_call_schedule.py +++ b/engine/apps/schedules/models/on_call_schedule.py @@ -209,8 +209,6 @@ class OnCallSchedule(PolymorphicModel): name = models.CharField(max_length=200) - # TODO: drop this field in a subsequent release, this has been migrated to slack_channel field - channel = models.CharField(max_length=100, null=True, default=None) slack_channel = models.ForeignKey( "slack.SlackChannel", null=True, diff --git a/engine/apps/user_management/migrations/0028_remove_organization_general_log_channel_id.py b/engine/apps/user_management/migrations/0028_remove_organization_general_log_channel_id.py new file mode 100644 index 00000000..6d415bdb --- /dev/null +++ b/engine/apps/user_management/migrations/0028_remove_organization_general_log_channel_id.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.16 on 2024-11-06 21:11 + +from django.db import migrations +import django_migration_linter as linter + + +class Migration(migrations.Migration): + + dependencies = [ + ('user_management', '0027_serviceaccount'), + ] + + operations = [ + linter.IgnoreMigration(), + migrations.RemoveField( + model_name='organization', + name='general_log_channel_id', + ), + ] diff --git a/engine/apps/user_management/models/organization.py b/engine/apps/user_management/models/organization.py index aac0aeae..2fbeefca 100644 --- 
a/engine/apps/user_management/models/organization.py +++ b/engine/apps/user_management/models/organization.py @@ -162,9 +162,6 @@ class Organization(MaintainableObject): slack_team_identity = models.ForeignKey( "slack.SlackTeamIdentity", on_delete=models.PROTECT, null=True, default=None, related_name="organizations" ) - - # TODO: drop this field in a subsequent release, this has been migrated to default_slack_channel field - general_log_channel_id = models.CharField(max_length=100, null=True, default=None) default_slack_channel = models.ForeignKey( "slack.SlackChannel", null=True, From 2024ee7f78aee69e7b7a4d7371c31cb593b8f3ee Mon Sep 17 00:00:00 2001 From: Michael Derynck Date: Tue, 19 Nov 2024 15:23:15 -0700 Subject: [PATCH 10/12] feat: Auto retry escalation on failed audit (#5265) # What this PR does Automatically retries escalation when alert groups fail auditing. This is the same effect as the continue_escalation command without any of the extra arguments. ## Checklist - [x] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes. 
--- .../alerts/tasks/check_escalation_finished.py | 41 +++++- .../test_check_escalation_finished_task.py | 123 ++++++++++++++++++ engine/settings/base.py | 2 + 3 files changed, 165 insertions(+), 1 deletion(-) diff --git a/engine/apps/alerts/tasks/check_escalation_finished.py b/engine/apps/alerts/tasks/check_escalation_finished.py index 9f3fb62d..8ae6d814 100644 --- a/engine/apps/alerts/tasks/check_escalation_finished.py +++ b/engine/apps/alerts/tasks/check_escalation_finished.py @@ -2,7 +2,9 @@ import datetime import typing import requests +from celery import uuid as celery_uuid from django.conf import settings +from django.core.cache import cache from django.db.models import Avg, F, Max, Q from django.utils import timezone @@ -174,6 +176,42 @@ def check_personal_notifications_task() -> None: task_logger.info(f"personal_notifications_triggered={triggered} personal_notifications_completed={completed}") +# Retries an alert group that has failed auditing if it is within the retry limit +# Returns whether an alert group escalation is being retried +def retry_audited_alert_group(alert_group) -> bool: + cache_key = f"audited-alert-group-retry-count-{alert_group.id}" + retry_count = cache.get(cache_key, 0) + if retry_count >= settings.AUDITED_ALERT_GROUP_MAX_RETRIES: + task_logger.info(f"Not retrying audited alert_group={alert_group.id} max retries exceeded.") + return False + + if alert_group.is_silenced_for_period: + task_logger.info(f"Not retrying audited alert_group={alert_group.id} as it is silenced.") + return False + + if not alert_group.escalation_snapshot: + task_logger.info(f"Not retrying audited alert_group={alert_group.id} as its escalation snapshot is empty.") + return False + + retry_count += 1 + cache.set(cache_key, retry_count, timeout=3600) + + task_id = celery_uuid() + alert_group.active_escalation_id = task_id + alert_group.save(update_fields=["active_escalation_id"]) + + from apps.alerts.tasks import escalate_alert_group + + 
escalate_alert_group.apply_async( + args=(alert_group.pk,), + immutable=True, + task_id=task_id, + eta=alert_group.next_step_eta, + ) + task_logger.info(f"Retrying audited alert_group={alert_group.id} attempt={retry_count}") + return True + + @shared_log_exception_on_failure_task def check_escalation_finished_task() -> None: """ @@ -221,7 +259,8 @@ def check_escalation_finished_task() -> None: try: audit_alert_group_escalation(alert_group) except AlertGroupEscalationPolicyExecutionAuditException: - alert_group_ids_that_failed_audit.append(str(alert_group.id)) + if not retry_audited_alert_group(alert_group): + alert_group_ids_that_failed_audit.append(str(alert_group.id)) failed_alert_groups_count = len(alert_group_ids_that_failed_audit) success_ratio = ( diff --git a/engine/apps/alerts/tests/test_check_escalation_finished_task.py b/engine/apps/alerts/tests/test_check_escalation_finished_task.py index 8aa5cbbd..229fabff 100644 --- a/engine/apps/alerts/tests/test_check_escalation_finished_task.py +++ b/engine/apps/alerts/tests/test_check_escalation_finished_task.py @@ -6,12 +6,14 @@ from django.test import override_settings from django.utils import timezone from apps.alerts.models import EscalationPolicy +from apps.alerts.tasks import escalate_alert_group from apps.alerts.tasks.check_escalation_finished import ( AlertGroupEscalationPolicyExecutionAuditException, audit_alert_group_escalation, check_alert_group_personal_notifications_task, check_escalation_finished_task, check_personal_notifications_task, + retry_audited_alert_group, send_alert_group_escalation_auditor_task_heartbeat, ) from apps.base.models import UserNotificationPolicy, UserNotificationPolicyLogRecord @@ -580,3 +582,124 @@ def test_check_escalation_finished_task_calls_audit_alert_group_personal_notific check_personal_notifications_task() assert "personal_notifications_triggered=6 personal_notifications_completed=2" in caplog.text + + 
+@patch("apps.alerts.tasks.check_escalation_finished.audit_alert_group_escalation") +@patch("apps.alerts.tasks.check_escalation_finished.retry_audited_alert_group") +@patch("apps.alerts.tasks.check_escalation_finished.send_alert_group_escalation_auditor_task_heartbeat") +@pytest.mark.django_db +def test_invoke_retry_from_check_escalation_finished_task( + mocked_send_alert_group_escalation_auditor_task_heartbeat, + mocked_retry_audited_alert_group, + mocked_audit_alert_group_escalation, + make_organization_and_user, + make_alert_receive_channel, + make_alert_group_that_started_at_specific_date, +): + organization, _ = make_organization_and_user() + alert_receive_channel = make_alert_receive_channel(organization) + + # Pass audit (should not be counted in final message or go to retry function) + alert_group1 = make_alert_group_that_started_at_specific_date(alert_receive_channel, received_delta=1) + # Fail audit but not retrying (should be counted in final message) + alert_group2 = make_alert_group_that_started_at_specific_date(alert_receive_channel, received_delta=5) + # Fail audit but retry (should not be counted in final message) + alert_group3 = make_alert_group_that_started_at_specific_date(alert_receive_channel, received_delta=10) + + def _mocked_audit_alert_group_escalation(alert_group): + if alert_group.id == alert_group2.id or alert_group.id == alert_group3.id: + raise AlertGroupEscalationPolicyExecutionAuditException(f"{alert_group2.id} failed audit") + + mocked_audit_alert_group_escalation.side_effect = _mocked_audit_alert_group_escalation + + def _mocked_retry_audited_alert_group(alert_group): + if alert_group.id == alert_group2.id: + return False + return True + + mocked_retry_audited_alert_group.side_effect = _mocked_retry_audited_alert_group + + with pytest.raises(AlertGroupEscalationPolicyExecutionAuditException) as exc: + check_escalation_finished_task() + + error_msg = str(exc.value) + + assert "The following alert group id(s) failed auditing:" in 
error_msg + assert str(alert_group1.id) not in error_msg + assert str(alert_group2.id) in error_msg + assert str(alert_group3.id) not in error_msg + + assert mocked_retry_audited_alert_group.call_count == 2 + mocked_send_alert_group_escalation_auditor_task_heartbeat.assert_not_called() + + +@patch.object(escalate_alert_group, "apply_async") +@override_settings(AUDITED_ALERT_GROUP_MAX_RETRIES=1) +@pytest.mark.django_db +def test_retry_audited_alert_group( + mocked_escalate_alert_group, + make_organization_and_user, + make_user_for_organization, + make_user_notification_policy, + make_escalation_chain, + make_escalation_policy, + make_channel_filter, + make_alert_receive_channel, + make_alert_group_that_started_at_specific_date, +): + organization, user = make_organization_and_user() + make_user_notification_policy( + user=user, + step=UserNotificationPolicy.Step.NOTIFY, + notify_by=UserNotificationPolicy.NotificationChannel.SLACK, + ) + + alert_receive_channel = make_alert_receive_channel(organization) + escalation_chain = make_escalation_chain(organization) + channel_filter = make_channel_filter(alert_receive_channel, escalation_chain=escalation_chain) + notify_to_multiple_users_step = make_escalation_policy( + escalation_chain=channel_filter.escalation_chain, + escalation_policy_step=EscalationPolicy.STEP_NOTIFY_MULTIPLE_USERS, + ) + notify_to_multiple_users_step.notify_to_users_queue.set([user]) + + alert_group1 = make_alert_group_that_started_at_specific_date(alert_receive_channel, channel_filter=channel_filter) + alert_group1.raw_escalation_snapshot = alert_group1.build_raw_escalation_snapshot() + alert_group1.raw_escalation_snapshot["last_active_escalation_policy_order"] = 1 + alert_group1.save() + + # Retry should occur + is_retrying = retry_audited_alert_group(alert_group1) + assert is_retrying + mocked_escalate_alert_group.assert_called() + mocked_escalate_alert_group.reset_mock() + + # No retry as attempts == max + is_retrying = 
retry_audited_alert_group(alert_group1) + assert not is_retrying + mocked_escalate_alert_group.assert_not_called() + mocked_escalate_alert_group.reset_mock() + + alert_group2 = make_alert_group_that_started_at_specific_date(alert_receive_channel, channel_filter=channel_filter) + # No retry because no escalation snapshot + is_retrying = retry_audited_alert_group(alert_group2) + assert not is_retrying + mocked_escalate_alert_group.assert_not_called() + mocked_escalate_alert_group.reset_mock() + + alert_group3 = make_alert_group_that_started_at_specific_date( + alert_receive_channel, + channel_filter=channel_filter, + silenced=True, + silenced_at=timezone.now(), + silenced_by_user=user, + silenced_until=(now + timezone.timedelta(hours=1)), + ) + alert_group3.raw_escalation_snapshot = alert_group1.build_raw_escalation_snapshot() + alert_group3.raw_escalation_snapshot["last_active_escalation_policy_order"] = 1 + alert_group3.save() + + # No retry because alert group silenced + is_retrying = retry_audited_alert_group(alert_group3) + assert not is_retrying + mocked_escalate_alert_group.assert_not_called() diff --git a/engine/settings/base.py b/engine/settings/base.py index 25ef7dc1..2b3cc971 100644 --- a/engine/settings/base.py +++ b/engine/settings/base.py @@ -988,3 +988,5 @@ CUSTOM_RATELIMITS = getenv_custom_ratelimit("CUSTOM_RATELIMITS", default={}) SYNC_V2_MAX_TASKS = getenv_integer("SYNC_V2_MAX_TASKS", 6) SYNC_V2_PERIOD_SECONDS = getenv_integer("SYNC_V2_PERIOD_SECONDS", 240) SYNC_V2_BATCH_SIZE = getenv_integer("SYNC_V2_BATCH_SIZE", 500) + +AUDITED_ALERT_GROUP_MAX_RETRIES = getenv_integer("AUDITED_ALERT_GROUP_MAX_RETRIES", 1) From 336b924a0811e1209e6543ca5caaa769519958c0 Mon Sep 17 00:00:00 2001 From: Jack Baldry Date: Wed, 20 Nov 2024 10:05:03 +0000 Subject: [PATCH 11/12] Fix first heading level (#5269) --- docs/sources/configure/jinja2-templating/_index.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git 
a/docs/sources/configure/jinja2-templating/_index.md b/docs/sources/configure/jinja2-templating/_index.md index 68837858..6cc158f7 100644 --- a/docs/sources/configure/jinja2-templating/_index.md +++ b/docs/sources/configure/jinja2-templating/_index.md @@ -23,8 +23,7 @@ refs: destination: /docs/grafana-cloud/alerting-and-irm/oncall/configure/integrations/references/webhook/ --- - -## Configure templates +# Configure templates Grafana OnCall integrates with your monitoring systems using webhooks with JSON payloads. By default, these webhooks deliver raw JSON payloads. From fda05a6cc43b509b8aed651355a659e558a1aca9 Mon Sep 17 00:00:00 2001 From: Joey Orlando Date: Wed, 20 Nov 2024 11:17:04 -0500 Subject: [PATCH 12/12] chore: remove deprecated `slack_channel` and `heartbeat` integration types (#5270) # What this PR does See [Slack discussion](https://raintank-corp.slack.com/archives/C06K1MQ07GS/p1732110700877869) for more context ## Checklist - [x] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes. 
--- dev/helm-local.yml | 2 +- .../alerts/models/alert_receive_channel.py | 18 +++----- ...rtbeat_actual_check_up_task_id_and_more.py | 23 ++++++++++ engine/apps/heartbeat/models.py | 10 ----- engine/apps/heartbeat/tasks.py | 6 --- engine/apps/heartbeat/tests/factories.py | 2 - engine/apps/integrations/tasks.py | 5 +-- .../public_api/tests/test_integrations.py | 1 - .../apps/slack/alert_group_slack_service.py | 5 +-- .../apps/slack/scenarios/distribute_alerts.py | 18 +------- engine/config_integrations/heartbeat.py | 29 ------------ engine/config_integrations/slack_channel.py | 44 ------------------- engine/settings/base.py | 2 - engine/settings/celery_task_routes.py | 1 - 14 files changed, 33 insertions(+), 133 deletions(-) create mode 100644 engine/apps/heartbeat/migrations/0003_remove_integrationheartbeat_actual_check_up_task_id_and_more.py delete mode 100644 engine/config_integrations/heartbeat.py delete mode 100644 engine/config_integrations/slack_channel.py diff --git a/dev/helm-local.yml b/dev/helm-local.yml index 8655df43..770a5dfb 100644 --- a/dev/helm-local.yml +++ b/dev/helm-local.yml @@ -39,7 +39,7 @@ engine: replicaCount: 1 celery: replicaCount: 1 - worker_beat_enabled: false + worker_beat_enabled: true externalGrafana: url: http://grafana:3000 diff --git a/engine/apps/alerts/models/alert_receive_channel.py b/engine/apps/alerts/models/alert_receive_channel.py index a8cb1494..7a351d2a 100644 --- a/engine/apps/alerts/models/alert_receive_channel.py +++ b/engine/apps/alerts/models/alert_receive_channel.py @@ -525,29 +525,21 @@ class AlertReceiveChannel(IntegrationOptionsMixin, MaintainableObject): ) @property - def short_name_with_maintenance_status(self): - if self.maintenance_mode is not None: - return ( - self.short_name + f" *[ on " - f"{AlertReceiveChannel.MAINTENANCE_MODE_CHOICES[self.maintenance_mode][1]}" - f" :construction: ]*" - ) - else: - return self.short_name - - @property - def created_name(self): + def created_name(self) -> str: return 
f"{self.get_integration_display()} {self.smile_code}" @property def web_link(self) -> str: return UIURLBuilder(self.organization).integration_detail(self.public_primary_key) + @property + def is_maintenace_integration(self) -> bool: + return self.integration == AlertReceiveChannel.INTEGRATION_MAINTENANCE + @property def integration_url(self) -> str | None: if self.integration in [ AlertReceiveChannel.INTEGRATION_MANUAL, - AlertReceiveChannel.INTEGRATION_SLACK_CHANNEL, AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, AlertReceiveChannel.INTEGRATION_MAINTENANCE, ]: diff --git a/engine/apps/heartbeat/migrations/0003_remove_integrationheartbeat_actual_check_up_task_id_and_more.py b/engine/apps/heartbeat/migrations/0003_remove_integrationheartbeat_actual_check_up_task_id_and_more.py new file mode 100644 index 00000000..e50d915e --- /dev/null +++ b/engine/apps/heartbeat/migrations/0003_remove_integrationheartbeat_actual_check_up_task_id_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.16 on 2024-11-20 15:39 + +from django.db import migrations +import django_migration_linter as linter + + +class Migration(migrations.Migration): + + dependencies = [ + ('heartbeat', '0002_delete_heartbeat'), + ] + + operations = [ + linter.IgnoreMigration(), + migrations.RemoveField( + model_name='integrationheartbeat', + name='actual_check_up_task_id', + ), + migrations.RemoveField( + model_name='integrationheartbeat', + name='last_checkup_task_time', + ), + ] diff --git a/engine/apps/heartbeat/models.py b/engine/apps/heartbeat/models.py index 0c0084bd..4688cc71 100644 --- a/engine/apps/heartbeat/models.py +++ b/engine/apps/heartbeat/models.py @@ -48,16 +48,6 @@ class IntegrationHeartBeat(models.Model): Stores the latest received heartbeat signal time """ - last_checkup_task_time = models.DateTimeField(default=None, null=True) - """ - Deprecated. This field is not used. TODO: remove it - """ - - actual_check_up_task_id = models.CharField(max_length=100) - """ - Deprecated. 
Stored the latest scheduled `integration_heartbeat_checkup` task id. TODO: remove it - """ - previous_alerted_state_was_life = models.BooleanField(default=True) """ Last status of the heartbeat. Determines if integration was alive on latest checkup diff --git a/engine/apps/heartbeat/tasks.py b/engine/apps/heartbeat/tasks.py index 7939290e..e9d26c57 100644 --- a/engine/apps/heartbeat/tasks.py +++ b/engine/apps/heartbeat/tasks.py @@ -105,12 +105,6 @@ def check_heartbeats() -> str: return f"Found {expired_count} expired and {restored_count} restored heartbeats" -@shared_dedicated_queue_retry_task() -def integration_heartbeat_checkup(heartbeat_id: int) -> None: - """Deprecated. TODO: Remove this task after this task cleared from queue""" - pass - - @shared_dedicated_queue_retry_task() def process_heartbeat_task(alert_receive_channel_pk): IntegrationHeartBeat.objects.filter( diff --git a/engine/apps/heartbeat/tests/factories.py b/engine/apps/heartbeat/tests/factories.py index 5e69db9d..40011255 100644 --- a/engine/apps/heartbeat/tests/factories.py +++ b/engine/apps/heartbeat/tests/factories.py @@ -4,7 +4,5 @@ from apps.heartbeat.models import IntegrationHeartBeat class IntegrationHeartBeatFactory(factory.DjangoModelFactory): - actual_check_up_task_id = "none" - class Meta: model = IntegrationHeartBeat diff --git a/engine/apps/integrations/tasks.py b/engine/apps/integrations/tasks.py index 45f3e04f..91f6a7d4 100644 --- a/engine/apps/integrations/tasks.py +++ b/engine/apps/integrations/tasks.py @@ -31,10 +31,7 @@ def create_alertmanager_alerts(alert_receive_channel_pk, alert, is_demo=False, r from apps.alerts.models import Alert, AlertReceiveChannel alert_receive_channel = AlertReceiveChannel.objects_with_deleted.get(pk=alert_receive_channel_pk) - if ( - alert_receive_channel.deleted_at is not None - or alert_receive_channel.integration == AlertReceiveChannel.INTEGRATION_MAINTENANCE - ): + if alert_receive_channel.deleted_at is not None or 
alert_receive_channel.is_maintenace_integration: logger.info("AlertReceiveChannel alert ignored if deleted/maintenance") return diff --git a/engine/apps/public_api/tests/test_integrations.py b/engine/apps/public_api/tests/test_integrations.py index 796942eb..9a4e29c6 100644 --- a/engine/apps/public_api/tests/test_integrations.py +++ b/engine/apps/public_api/tests/test_integrations.py @@ -903,7 +903,6 @@ def test_get_list_integrations_link_and_inbound_email( if integration_type in [ AlertReceiveChannel.INTEGRATION_MANUAL, - AlertReceiveChannel.INTEGRATION_SLACK_CHANNEL, AlertReceiveChannel.INTEGRATION_MAINTENANCE, ]: assert integration_link is None diff --git a/engine/apps/slack/alert_group_slack_service.py b/engine/apps/slack/alert_group_slack_service.py index 9bb9510b..ed614305 100644 --- a/engine/apps/slack/alert_group_slack_service.py +++ b/engine/apps/slack/alert_group_slack_service.py @@ -35,9 +35,8 @@ class AlertGroupSlackService: self._slack_client = SlackClient(slack_team_identity) def update_alert_group_slack_message(self, alert_group: "AlertGroup") -> None: - from apps.alerts.models import AlertReceiveChannel - logger.info(f"Update message for alert_group {alert_group.pk}") + try: self._slack_client.chat_update( channel=alert_group.slack_message.channel_id, @@ -47,7 +46,7 @@ class AlertGroupSlackService: ) logger.info(f"Message has been updated for alert_group {alert_group.pk}") except SlackAPIRatelimitError as e: - if alert_group.channel.integration != AlertReceiveChannel.INTEGRATION_MAINTENANCE: + if not alert_group.channel.is_maintenace_integration: if not alert_group.channel.is_rate_limited_in_slack: alert_group.channel.start_send_rate_limit_message_task(e.retry_after) logger.info( diff --git a/engine/apps/slack/scenarios/distribute_alerts.py b/engine/apps/slack/scenarios/distribute_alerts.py index 3a7090e3..3d3c1a60 100644 --- a/engine/apps/slack/scenarios/distribute_alerts.py +++ b/engine/apps/slack/scenarios/distribute_alerts.py @@ -141,22 +141,6 
@@ class AlertShootingStep(scenario_step.ScenarioStep): channel_id=channel_id, ) - # If alert was made out of a message: - if alert_group.channel.integration == AlertReceiveChannel.INTEGRATION_SLACK_CHANNEL: - channel = json.loads(alert.integration_unique_data)["channel"] - result = self._slack_client.chat_postMessage( - channel=channel, - thread_ts=json.loads(alert.integration_unique_data)["ts"], - text=":rocket: <{}|Incident registered!>".format(alert_group.slack_message.permalink), - team=slack_team_identity, - ) - alert_group.slack_messages.create( - slack_id=result["ts"], - organization=alert_group.channel.organization, - _slack_team_identity=self.slack_team_identity, - channel_id=channel, - ) - alert.delivered = True except SlackAPITokenError: alert_group.reason_to_skip_escalation = AlertGroup.ACCOUNT_INACTIVE @@ -172,7 +156,7 @@ class AlertShootingStep(scenario_step.ScenarioStep): logger.info("Not delivering alert due to channel is archived.") except SlackAPIRatelimitError as e: # don't rate limit maintenance alert - if alert_group.channel.integration != AlertReceiveChannel.INTEGRATION_MAINTENANCE: + if not alert_group.channel.is_maintenace_integration: alert_group.reason_to_skip_escalation = AlertGroup.RATE_LIMITED alert_group.save(update_fields=["reason_to_skip_escalation"]) alert_group.channel.start_send_rate_limit_message_task(e.retry_after) diff --git a/engine/config_integrations/heartbeat.py b/engine/config_integrations/heartbeat.py deleted file mode 100644 index 60699c45..00000000 --- a/engine/config_integrations/heartbeat.py +++ /dev/null @@ -1,29 +0,0 @@ -# Main -enabled = True -title = "Heartbeat" -slug = "heartbeat" -short_description = None -description = None -is_displayed_on_web = False -is_featured = False -is_able_to_autoresolve = True -is_demo_alert_enabled = False - -description = None - -# Default templates -slack_title = """\ -*<{{ grafana_oncall_link }}|#{{ grafana_oncall_incident_id }} {{ payload.get("title", "Title undefined (check 
Slack Title Template)") }}>* via {{ integration_name }} -{% if source_link %} - (*<{{ source_link }}|source>*) -{%- endif %}""" - -grouping_id = """\ -{{ payload.get("id", "") }}{{ payload.get("user_defined_id", "") }} -""" - -resolve_condition = '{{ payload.get("is_resolve", False) == True }}' - -acknowledge_condition = None - -example_payload = None diff --git a/engine/config_integrations/slack_channel.py b/engine/config_integrations/slack_channel.py deleted file mode 100644 index 05021935..00000000 --- a/engine/config_integrations/slack_channel.py +++ /dev/null @@ -1,44 +0,0 @@ -# Main -enabled = True -title = "Slack Channel" -slug = "slack_channel" -short_description = None -description = None -is_displayed_on_web = False -is_featured = False -is_able_to_autoresolve = False -is_demo_alert_enabled = False - -description = None - -# Default templates -slack_title = """\ -{% if source_link -%} -*<{{ source_link }}|<#{{ payload.get("channel", "") }}>>* -{%- else -%} -<#{{ payload.get("channel", "") }}> -{%- endif %}""" - -web_title = """\ -{% if source_link -%} -[#{{ grafana_oncall_incident_id }}]{{ source_link }}) <#{{ payload.get("channel", "") }}>>* -{%- else -%} -*#{{ grafana_oncall_incident_id }}* <#{{ payload.get("channel", "") }}> -{%- endif %}""" - -telegram_title = """\ -{% if source_link -%} -#{{ grafana_oncall_incident_id }} {{ payload.get("channel", "") }} -{%- else -%} -*#{{ grafana_oncall_incident_id }}* <#{{ payload.get("channel", "") }}> -{%- endif %}""" - -grouping_id = '{{ payload.get("ts", "") }}' - -resolve_condition = None - -acknowledge_condition = None - -source_link = '{{ payload.get("amixr_mixin", {}).get("permalink", "")}}' - -example_payload = None diff --git a/engine/settings/base.py b/engine/settings/base.py index 2b3cc971..0f73c8d5 100644 --- a/engine/settings/base.py +++ b/engine/settings/base.py @@ -878,11 +878,9 @@ INSTALLED_ONCALL_INTEGRATIONS = [ "config_integrations.formatted_webhook", "config_integrations.kapacitor", 
"config_integrations.elastalert", - "config_integrations.heartbeat", "config_integrations.inbound_email", "config_integrations.maintenance", "config_integrations.manual", - "config_integrations.slack_channel", "config_integrations.zabbix", "config_integrations.direct_paging", # Actually it's Grafana 8 integration. diff --git a/engine/settings/celery_task_routes.py b/engine/settings/celery_task_routes.py index 04a8ffa4..7ef62121 100644 --- a/engine/settings/celery_task_routes.py +++ b/engine/settings/celery_task_routes.py @@ -12,7 +12,6 @@ CELERY_TASK_ROUTES = { "common.oncall_gateway.tasks.delete_oncall_connector_async": {"queue": "default"}, "common.oncall_gateway.tasks.create_slack_connector_async_v2": {"queue": "default"}, "common.oncall_gateway.tasks.delete_slack_connector_async_v2": {"queue": "default"}, - "apps.heartbeat.tasks.integration_heartbeat_checkup": {"queue": "default"}, "apps.heartbeat.tasks.process_heartbeat_task": {"queue": "default"}, "apps.labels.tasks.update_labels_cache": {"queue": "default"}, "apps.labels.tasks.update_instances_labels_cache": {"queue": "default"},