From d0904ca405a2c83c179488bd17e261a3fe95727b Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Fri, 12 Jan 2024 15:11:22 +0000 Subject: [PATCH] Improve OpenAPI schema coverage (#3629) # What this PR does Improves OpenAPI schema coverage for internal API: - Fixes/Improves `alert group` and `feature` endpoints - Adds `integration` and `user` endpoints ## Which issue(s) this PR fixes https://github.com/grafana/oncall/issues/3444 ## Checklist - [x] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] `CHANGELOG.md` updated (or `pr:no changelog` PR label added if not required) --- engine/apps/alerts/models/alert_group.py | 2 +- .../alerts/models/alert_receive_channel.py | 8 +- .../apps/alerts/models/maintainable_object.py | 2 +- engine/apps/api/serializers/alert.py | 11 +- engine/apps/api/serializers/alert_group.py | 39 ++- .../api/serializers/alert_receive_channel.py | 35 ++- .../api/serializers/integration_heartbeat.py | 4 +- .../api/serializers/slack_user_identity.py | 2 +- engine/apps/api/serializers/user.py | 53 ++-- engine/apps/api/tests/test_features.py | 25 +- engine/apps/api/views/alert_group.py | 228 ++++++++++-------- .../apps/api/views/alert_receive_channel.py | 156 +++++++++++- engine/apps/api/views/features.py | 52 ++-- engine/apps/api/views/labels.py | 2 + engine/apps/api/views/user.py | 101 +++++++- engine/apps/auth_token/auth.py | 17 ++ engine/apps/heartbeat/models.py | 5 +- engine/apps/mobile_app/auth.py | 13 + engine/apps/oss_installation/constants.py | 12 +- engine/apps/oss_installation/utils.py | 10 +- .../oss_installation/views/cloud_users.py | 9 +- .../apps/slack/models/slack_user_identity.py | 2 +- engine/apps/user_management/models/user.py | 2 +- engine/common/api_helpers/custom_fields.py | 3 + engine/common/api_helpers/filters.py | 10 + engine/common/api_helpers/mixins.py | 18 +- engine/common/api_helpers/paginators.py | 53 ++-- engine/engine/schema.py | 47 ++++ engine/settings/base.py | 4 +- engine/settings/prod_without_db.py | 2 +- 30 files changed, 640 insertions(+), 287 deletions(-) create mode 100644 engine/engine/schema.py diff --git a/engine/apps/alerts/models/alert_group.py b/engine/apps/alerts/models/alert_group.py index f13c6c45..d3d2479d 100644 --- a/engine/apps/alerts/models/alert_group.py +++ b/engine/apps/alerts/models/alert_group.py @@ -348,7 +348,7 @@ class AlertGroup(AlertGroupSlackRenderingMixin, EscalationSnapshotMixin, models. return self.silenced and self.silenced_until is not None @property - def status(self): + def status(self) -> int: if self.resolved: return AlertGroup.RESOLVED elif self.acknowledged: diff --git a/engine/apps/alerts/models/alert_receive_channel.py b/engine/apps/alerts/models/alert_receive_channel.py index 3a7973f4..02705ea1 100644 --- a/engine/apps/alerts/models/alert_receive_channel.py +++ b/engine/apps/alerts/models/alert_receive_channel.py @@ -412,7 +412,7 @@ class AlertReceiveChannel(IntegrationOptionsMixin, MaintainableObject): return Alert.objects.filter(group__channel=self).count() @property - def is_able_to_autoresolve(self): + def is_able_to_autoresolve(self) -> bool: return self.config.is_able_to_autoresolve @property @@ -420,7 +420,7 @@ class AlertReceiveChannel(IntegrationOptionsMixin, MaintainableObject): return self.config.is_demo_alert_enabled @property - def description(self): + def description(self) -> str | None: # TODO: AMV2: Remove this check after legacy integrations are migrated. if self.integration == AlertReceiveChannel.INTEGRATION_LEGACY_GRAFANA_ALERTING: contact_points = self.contact_points.all() @@ -496,7 +496,7 @@ class AlertReceiveChannel(IntegrationOptionsMixin, MaintainableObject): return urljoin(self.organization.web_link, f"integrations/{self.public_primary_key}") @property - def integration_url(self): + def integration_url(self) -> str | None: if self.integration in [ AlertReceiveChannel.INTEGRATION_MANUAL, AlertReceiveChannel.INTEGRATION_SLACK_CHANNEL, @@ -595,7 +595,7 @@ class AlertReceiveChannel(IntegrationOptionsMixin, MaintainableObject): # Heartbeat @property - def is_available_for_integration_heartbeat(self): + def is_available_for_integration_heartbeat(self) -> bool: return self.heartbeat_module is not None @property diff --git a/engine/apps/alerts/models/maintainable_object.py b/engine/apps/alerts/models/maintainable_object.py index 05fa7d65..a5788b00 100644 --- a/engine/apps/alerts/models/maintainable_object.py +++ b/engine/apps/alerts/models/maintainable_object.py @@ -151,7 +151,7 @@ class MaintainableObject(models.Model): ) @property - def till_maintenance_timestamp(self): + def till_maintenance_timestamp(self) -> int | None: if self.maintenance_started_at is not None and self.maintenance_duration is not None: return int((self.maintenance_started_at + self.maintenance_duration).astimezone(pytz.UTC).timestamp()) return None diff --git a/engine/apps/api/serializers/alert.py b/engine/apps/api/serializers/alert.py index 7774b946..75dfb295 100644 --- a/engine/apps/api/serializers/alert.py +++ b/engine/apps/api/serializers/alert.py @@ -1,3 +1,5 @@ +import typing + from django.core.cache import cache from django.utils import timezone from rest_framework import serializers @@ -8,6 +10,13 @@ from apps.alerts.models import Alert from .alerts_field_cache_buster_mixin import AlertsFieldCacheBusterMixin +class RenderForWeb(typing.TypedDict): + title: str + message: str + image_url: str | None + source_link: str | None + + class AlertFieldsCacheSerializerMixin(AlertsFieldCacheBusterMixin): CACHE_KEY_FORMAT_TEMPLATE = "{field_name}_alert_{object_id}" @@ -51,7 +60,7 @@ class AlertSerializer(AlertFieldsCacheSerializerMixin, serializers.ModelSerializ "created_at", ] - def get_render_for_web(self, obj): + def get_render_for_web(self, obj) -> RenderForWeb: return AlertFieldsCacheSerializerMixin.get_or_set_web_template_field( obj, AlertFieldsCacheSerializerMixin.RENDER_FOR_WEB_FIELD_NAME, diff --git a/engine/apps/api/serializers/alert_group.py b/engine/apps/api/serializers/alert_group.py index 140276d6..71d7cbf6 100644 --- a/engine/apps/api/serializers/alert_group.py +++ b/engine/apps/api/serializers/alert_group.py @@ -4,7 +4,7 @@ import typing from django.core.cache import cache from django.utils import timezone -from drf_spectacular.utils import extend_schema_field, inline_serializer +from drf_spectacular.utils import extend_schema_field from rest_framework import serializers from apps.alerts.incident_appearance.renderers.web_renderer import AlertGroupWebRenderer @@ -22,6 +22,17 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) +class RenderForWeb(typing.TypedDict): + title: str + message: str + image_url: str | None + source_link: str | None + + +class EmptyRenderForWeb(typing.TypedDict): + pass + + class AlertGroupFieldsCacheSerializerMixin(AlertsFieldCacheBusterMixin): CACHE_KEY_FORMAT_TEMPLATE = "{field_name}_alert_group_{object_id}" @@ -80,18 +91,7 @@ class ShortAlertGroupSerializer(AlertGroupFieldsCacheSerializerMixin, serializer fields = ["pk", "render_for_web", "alert_receive_channel", "inside_organization_number"] read_only_fields = ["pk", "render_for_web", "alert_receive_channel", "inside_organization_number"] - @extend_schema_field( - inline_serializer( - name="render_for_web", - fields={ - "title": serializers.CharField(), - "message": serializers.CharField(), - "image_url": serializers.CharField(), - "source_link": serializers.CharField(), - }, - ) - ) - def get_render_for_web(self, obj: "AlertGroup"): + def get_render_for_web(self, obj: "AlertGroup") -> RenderForWeb | EmptyRenderForWeb: last_alert = obj.alerts.last() if last_alert is None: return {} @@ -170,18 +170,7 @@ class AlertGroupListSerializer( "labels", ] - @extend_schema_field( - inline_serializer( - name="render_for_web", - fields={ - "title": serializers.CharField(), - "message": serializers.CharField(), - "image_url": serializers.CharField(), - "source_link": serializers.CharField(), - }, - ) - ) - def get_render_for_web(self, obj: "AlertGroup"): + def get_render_for_web(self, obj: "AlertGroup") -> RenderForWeb | EmptyRenderForWeb: if not obj.last_alert: return {} return AlertGroupFieldsCacheSerializerMixin.get_or_set_web_template_field( diff --git a/engine/apps/api/serializers/alert_receive_channel.py b/engine/apps/api/serializers/alert_receive_channel.py index 492a237e..76fb3b75 100644 --- a/engine/apps/api/serializers/alert_receive_channel.py +++ b/engine/apps/api/serializers/alert_receive_channel.py @@ -2,7 +2,6 @@ import typing from collections import OrderedDict from django.conf import settings -from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ValidationError as DjangoValidationError from django.db.models import Q from jinja2 import TemplateSyntaxError @@ -53,17 +52,17 @@ class IntegrationAlertGroupLabels(typing.TypedDict): class CustomLabelSerializer(serializers.Serializer): """This serializer is consistent with apps.api.serializers.labels.LabelSerializer, but allows null for value ID.""" - class KeySerializer(serializers.Serializer): + class CustomLabelKeySerializer(serializers.Serializer): id = serializers.CharField() name = serializers.CharField() - class ValueSerializer(serializers.Serializer): + class CustomLabelValueSerializer(serializers.Serializer): # ID is null for templated labels. For such labels, the "name" value is a Jinja2 template. id = serializers.CharField(allow_null=True) name = serializers.CharField() - key = KeySerializer() - value = ValueSerializer() + key = CustomLabelKeySerializer() + value = CustomLabelValueSerializer() class IntegrationAlertGroupLabelsSerializer(serializers.Serializer): @@ -215,9 +214,9 @@ class AlertReceiveChannelSerializer( default_channel_filter = serializers.SerializerMethodField() instructions = serializers.SerializerMethodField() demo_alert_enabled = serializers.BooleanField(source="is_demo_alert_enabled", read_only=True) - is_based_on_alertmanager = serializers.BooleanField(source="has_alertmanager_payload_structure", read_only=True) + is_based_on_alertmanager = serializers.BooleanField(source="based_on_alertmanager", read_only=True) maintenance_till = serializers.ReadOnlyField(source="till_maintenance_timestamp") - heartbeat = serializers.SerializerMethodField() + heartbeat = IntegrationHeartBeatSerializer(read_only=True, allow_null=True, source="integration_heartbeat") allow_delete = serializers.SerializerMethodField() description_short = serializers.CharField(max_length=250, required=False, allow_null=True) demo_alert_payload = serializers.JSONField(source="config.example_payload", read_only=True) @@ -334,15 +333,16 @@ class AlertReceiveChannelSerializer( except AlertReceiveChannel.DuplicateDirectPagingError: raise BadRequest(detail=AlertReceiveChannel.DuplicateDirectPagingError.DETAIL) - def get_instructions(self, obj: "AlertReceiveChannel"): + def get_instructions(self, obj: "AlertReceiveChannel") -> str: # Deprecated, kept for api-backward compatibility return "" # MethodFields are used instead of relevant properties because of properties hit db on each instance in queryset - def get_default_channel_filter(self, obj: "AlertReceiveChannel"): + def get_default_channel_filter(self, obj: "AlertReceiveChannel") -> str | None: for filter in obj.channel_filters.all(): if filter.is_default: return filter.public_primary_key + return None @staticmethod def validate_integration(integration): @@ -367,21 +367,14 @@ class AlertReceiveChannelSerializer( else: raise serializers.ValidationError(detail="Integration with this name already exists") - def get_heartbeat(self, obj: "AlertReceiveChannel"): - try: - heartbeat = obj.integration_heartbeat - except ObjectDoesNotExist: - return None - return IntegrationHeartBeatSerializer(heartbeat).data - - def get_allow_delete(self, obj: "AlertReceiveChannel"): + def get_allow_delete(self, obj: "AlertReceiveChannel") -> bool: # don't allow deleting direct paging integrations return obj.integration != AlertReceiveChannel.INTEGRATION_DIRECT_PAGING - def get_alert_count(self, obj: "AlertReceiveChannel"): + def get_alert_count(self, obj: "AlertReceiveChannel") -> int: return 0 - def get_alert_groups_count(self, obj: "AlertReceiveChannel"): + def get_alert_groups_count(self, obj: "AlertReceiveChannel") -> int: return 0 def get_routes_count(self, obj: "AlertReceiveChannel") -> int: @@ -428,10 +421,10 @@ class FilterAlertReceiveChannelSerializer(serializers.ModelSerializer[AlertRecei model = AlertReceiveChannel fields = ["value", "display_name", "integration_url"] - def _get_value(self, obj: "AlertReceiveChannel"): + def _get_value(self, obj: "AlertReceiveChannel") -> str: return obj.public_primary_key - def get_display_name(self, obj: "AlertReceiveChannel"): + def get_display_name(self, obj: "AlertReceiveChannel") -> str: display_name = obj.verbal_name or AlertReceiveChannel.INTEGRATION_CHOICES[obj.integration][1] return display_name diff --git a/engine/apps/api/serializers/integration_heartbeat.py b/engine/apps/api/serializers/integration_heartbeat.py index 02cc294c..03f895eb 100644 --- a/engine/apps/api/serializers/integration_heartbeat.py +++ b/engine/apps/api/serializers/integration_heartbeat.py @@ -41,10 +41,10 @@ class IntegrationHeartBeatSerializer(EagerLoadingMixin, serializers.ModelSeriali {"alert_receive_channel": "Heartbeat is not available for this integration"} ) - def get_last_heartbeat_time_verbal(self, obj): + def get_last_heartbeat_time_verbal(self, obj) -> str | None: return self._last_heartbeat_time_verbal(obj) if obj.last_heartbeat_time else None - def get_instruction(self, obj): + def get_instruction(self, obj) -> str: # Deprecated. Kept for API backward compatibility. return "" diff --git a/engine/apps/api/serializers/slack_user_identity.py b/engine/apps/api/serializers/slack_user_identity.py index be184cef..7e9c842e 100644 --- a/engine/apps/api/serializers/slack_user_identity.py +++ b/engine/apps/api/serializers/slack_user_identity.py @@ -14,5 +14,5 @@ class SlackUserIdentitySerializer(serializers.ModelSerializer): fields = ["slack_login", "slack_id", "avatar", "name", "display_name"] read_only_fields = ["slack_login", "slack_id", "avatar", "name", "display_name"] - def get_display_name(self, obj): + def get_display_name(self, obj) -> str | None: return obj.profile_display_name or obj.slack_verbal diff --git a/engine/apps/api/serializers/user.py b/engine/apps/api/serializers/user.py index 2fd9fd91..c399fe26 100644 --- a/engine/apps/api/serializers/user.py +++ b/engine/apps/api/serializers/user.py @@ -9,10 +9,10 @@ from apps.api.serializers.telegram import TelegramToUserConnectorSerializer from apps.base.messaging import get_messaging_backends from apps.base.models import UserNotificationPolicy from apps.base.utils import live_settings +from apps.oss_installation.constants import CloudSyncStatus from apps.oss_installation.utils import cloud_user_identity_status from apps.schedules.ical_utils import SchedulesOnCallUsers from apps.user_management.models import User -from apps.user_management.models.user import default_working_hours from common.api_helpers.custom_fields import TeamPrimaryKeyRelatedField, TimeZoneField from common.api_helpers.mixins import EagerLoadingMixin from common.api_helpers.utils import check_phone_number_is_valid @@ -31,6 +31,26 @@ class UserPermissionSerializer(serializers.Serializer): action = serializers.CharField(read_only=True) +class NotificationChainVerbal(typing.TypedDict): + default: str + important: str + + +class WorkingHoursPeriodSerializer(serializers.Serializer): + start = serializers.CharField() + end = serializers.CharField() + + +class WorkingHoursSerializer(serializers.Serializer): + monday = serializers.ListField(child=WorkingHoursPeriodSerializer()) + tuesday = serializers.ListField(child=WorkingHoursPeriodSerializer()) + wednesday = serializers.ListField(child=WorkingHoursPeriodSerializer()) + thursday = serializers.ListField(child=WorkingHoursPeriodSerializer()) + friday = serializers.ListField(child=WorkingHoursPeriodSerializer()) + saturday = serializers.ListField(child=WorkingHoursPeriodSerializer()) + sunday = serializers.ListField(child=WorkingHoursPeriodSerializer()) + + class UserSerializer(DynamicFieldsModelSerializer, EagerLoadingMixin): pk = serializers.CharField(read_only=True, source="public_primary_key") slack_user_identity = SlackUserIdentitySerializer(read_only=True) @@ -47,6 +67,7 @@ class UserSerializer(DynamicFieldsModelSerializer, EagerLoadingMixin): avatar_full = serializers.URLField(source="avatar_full_url", read_only=True) notification_chain_verbal = serializers.SerializerMethodField() cloud_connection_status = serializers.SerializerMethodField() + working_hours = WorkingHoursSerializer(required=False) SELECT_RELATED = ["telegram_verification_code", "telegram_connection", "organization", "slack_user_identity"] @@ -82,29 +103,8 @@ class UserSerializer(DynamicFieldsModelSerializer, EagerLoadingMixin): ] def validate_working_hours(self, working_hours): - if not isinstance(working_hours, dict): - raise serializers.ValidationError("must be dict") - - # check that all days are present - if sorted(working_hours.keys()) != sorted(default_working_hours().keys()): - raise serializers.ValidationError("missing some days") - for day in working_hours: - periods = working_hours[day] - - if not isinstance(periods, list): - raise serializers.ValidationError("periods must be list") - - for period in periods: - if not isinstance(period, dict): - raise serializers.ValidationError("period must be dict") - - if sorted(period.keys()) != sorted(["start", "end"]): - raise serializers.ValidationError("'start' and 'end' fields must be present") - - if not isinstance(period["start"], str) or not isinstance(period["end"], str): - raise serializers.ValidationError("'start' and 'end' fields must be str") - + for period in working_hours[day]: try: start = time.strptime(period["start"], "%H:%M:%S") end = time.strptime(period["end"], "%H:%M:%S") @@ -113,7 +113,6 @@ class UserSerializer(DynamicFieldsModelSerializer, EagerLoadingMixin): if start >= end: raise serializers.ValidationError("'start' must be less than 'end'") - return working_hours def validate_unverified_phone_number(self, value): @@ -127,18 +126,18 @@ class UserSerializer(DynamicFieldsModelSerializer, EagerLoadingMixin): else: return None - def get_messaging_backends(self, obj: User): + def get_messaging_backends(self, obj: User) -> dict[str, dict]: serialized_data = {} supported_backends = get_messaging_backends() for backend_id, backend in supported_backends: serialized_data[backend_id] = backend.serialize_user(obj) return serialized_data - def get_notification_chain_verbal(self, obj: User): + def get_notification_chain_verbal(self, obj: User) -> NotificationChainVerbal: default, important = UserNotificationPolicy.get_short_verbals_for_user(user=obj) return {"default": " - ".join(default), "important": " - ".join(important)} - def get_cloud_connection_status(self, obj: User): + def get_cloud_connection_status(self, obj: User) -> CloudSyncStatus | None: if settings.IS_OPEN_SOURCE and live_settings.GRAFANA_CLOUD_NOTIFICATIONS_ENABLED: connector = self.context.get("connector", None) identities = self.context.get("cloud_identities", {}) diff --git a/engine/apps/api/tests/test_features.py b/engine/apps/api/tests/test_features.py index b7d529d8..33ef389f 100644 --- a/engine/apps/api/tests/test_features.py +++ b/engine/apps/api/tests/test_features.py @@ -4,14 +4,7 @@ from django.urls import reverse from rest_framework import status from rest_framework.test import APIClient -from apps.api.views.features import ( - FEATURE_GRAFANA_CLOUD_CONNECTION, - FEATURE_GRAFANA_CLOUD_NOTIFICATIONS, - FEATURE_LIVE_SETTINGS, - FEATURE_MSTEAMS, - FEATURE_SLACK, - FEATURE_TELEGRAM, -) +from apps.api.views.features import Feature @pytest.mark.django_db @@ -35,9 +28,9 @@ def test_features_view( @pytest.mark.parametrize( "feature_attr,expected_feature", [ - ("FEATURE_SLACK_INTEGRATION_ENABLED", FEATURE_SLACK), - ("FEATURE_TELEGRAM_INTEGRATION_ENABLED", FEATURE_TELEGRAM), - ("FEATURE_LIVE_SETTINGS_ENABLED", FEATURE_LIVE_SETTINGS), + ("FEATURE_SLACK_INTEGRATION_ENABLED", Feature.SLACK), + ("FEATURE_TELEGRAM_INTEGRATION_ENABLED", Feature.TELEGRAM), + ("FEATURE_LIVE_SETTINGS_ENABLED", Feature.LIVE_SETTINGS), ], ) def test_core_features_switch( @@ -76,9 +69,9 @@ def test_oss_features_enabled_in_oss_installation_by_default( response = client.get(url, format="json", **make_user_auth_headers(user, token)) assert response.status_code == status.HTTP_200_OK - assert FEATURE_GRAFANA_CLOUD_CONNECTION in response.json() - assert FEATURE_GRAFANA_CLOUD_NOTIFICATIONS in response.json() - assert FEATURE_MSTEAMS not in response.json() + assert Feature.GRAFANA_CLOUD_CONNECTION in response.json() + assert Feature.GRAFANA_CLOUD_NOTIFICATIONS in response.json() + assert Feature.MSTEAMS not in response.json() @pytest.mark.django_db @@ -93,14 +86,14 @@ def test_non_oss_features_enabled( response = client.get(url, format="json", **make_user_auth_headers(user, token)) assert response.status_code == status.HTTP_200_OK - assert FEATURE_MSTEAMS in response.json() + assert Feature.MSTEAMS in response.json() @pytest.mark.django_db @pytest.mark.parametrize( "feature_attr,expected_feature", [ - ("GRAFANA_CLOUD_NOTIFICATIONS_ENABLED", FEATURE_GRAFANA_CLOUD_NOTIFICATIONS), + ("GRAFANA_CLOUD_NOTIFICATIONS_ENABLED", Feature.GRAFANA_CLOUD_NOTIFICATIONS), ], ) def test_oss_features_switch( diff --git a/engine/apps/api/views/alert_group.py b/engine/apps/api/views/alert_group.py index 9a99d051..6955157f 100644 --- a/engine/apps/api/views/alert_group.py +++ b/engine/apps/api/views/alert_group.py @@ -4,8 +4,7 @@ from django.core.exceptions import ObjectDoesNotExist from django.db.models import Count, Max, Q from django.utils import timezone from django_filters import rest_framework as filters -from django_filters.widgets import RangeWidget -from drf_spectacular.utils import extend_schema, extend_schema_view, inline_serializer +from drf_spectacular.utils import extend_schema, inline_serializer from rest_framework import mixins, serializers, status, viewsets from rest_framework.decorators import action from rest_framework.exceptions import NotFound @@ -29,7 +28,12 @@ from apps.labels.utils import is_labels_feature_enabled from apps.mobile_app.auth import MobileAppAuthTokenAuthentication from apps.user_management.models import Team, User from common.api_helpers.exceptions import BadRequest -from common.api_helpers.filters import NO_TEAM_VALUE, DateRangeFilterMixin, ModelFieldFilterMixin +from common.api_helpers.filters import ( + NO_TEAM_VALUE, + DateRangeFilterMixin, + ModelFieldFilterMixin, + MultipleChoiceCharFilter, +) from common.api_helpers.mixins import PreviewTemplateMixin, PublicPrimaryKeyMixin, TeamFilteringMixin from common.api_helpers.paginators import AlertGroupCursorPaginator @@ -55,31 +59,6 @@ def get_user_queryset(request): return User.objects.filter(organization=request.user.organization).distinct() -class AlertGroupFilterBackend(filters.DjangoFilterBackend): - """ - See here for more context on how this works - - https://github.com/carltongibson/django-filter/discussions/1572 - https://youtu.be/e52S1SjuUeM?t=841 - """ - - def get_filterset(self, request, queryset, view): - filterset = super().get_filterset(request, queryset, view) - - filterset.form.fields["integration"].queryset = get_integration_queryset(request) - filterset.form.fields["escalation_chain"].queryset = get_escalation_chain_queryset(request) - - user_queryset = get_user_queryset(request) - - filterset.form.fields["silenced_by"].queryset = user_queryset - filterset.form.fields["acknowledged_by"].queryset = user_queryset - filterset.form.fields["resolved_by"].queryset = user_queryset - filterset.form.fields["invitees_are"].queryset = user_queryset - filterset.form.fields["involved_users_are"].queryset = user_queryset - - return filterset - - class AlertGroupFilter(DateRangeFilterMixin, ModelFieldFilterMixin, filters.FilterSet): """ Examples of possible date formats here https://docs.djangoproject.com/en/1.9/ref/settings/#datetime-input-formats @@ -87,69 +66,55 @@ class AlertGroupFilter(DateRangeFilterMixin, ModelFieldFilterMixin, filters.Filt FILTER_BY_INVOLVED_USERS_ALERT_GROUPS_CUTOFF = 1000 - started_at_gte = filters.DateTimeFilter(field_name="started_at", lookup_expr="gte") - started_at_lte = filters.DateTimeFilter(field_name="started_at", lookup_expr="lte") - resolved_at_lte = filters.DateTimeFilter(field_name="resolved_at", lookup_expr="lte") is_root = filters.BooleanFilter(field_name="root_alert_group", lookup_expr="isnull") - id__in = filters.BaseInFilter(field_name="public_primary_key", lookup_expr="in") status = filters.MultipleChoiceFilter(choices=AlertGroup.STATUS_CHOICES, method="filter_status") - started_at = filters.CharFilter(field_name="started_at", method=DateRangeFilterMixin.filter_date_range.__name__) - resolved_at = filters.CharFilter(field_name="resolved_at", method=DateRangeFilterMixin.filter_date_range.__name__) - silenced_at = filters.CharFilter(field_name="silenced_at", method=DateRangeFilterMixin.filter_date_range.__name__) - silenced_by = filters.ModelMultipleChoiceFilter( - field_name="silenced_by_user", - queryset=None, - to_field_name="public_primary_key", - method=ModelFieldFilterMixin.filter_model_field.__name__, + started_at = filters.CharFilter( + field_name="started_at", + method=DateRangeFilterMixin.filter_date_range.__name__, ) - integration = filters.ModelMultipleChoiceFilter( + resolved_at = filters.CharFilter( + field_name="resolved_at", + method=DateRangeFilterMixin.filter_date_range.__name__, + ) + integration = MultipleChoiceCharFilter( field_name="channel", - queryset=None, + queryset=get_integration_queryset, to_field_name="public_primary_key", method=ModelFieldFilterMixin.filter_model_field.__name__, ) - escalation_chain = filters.ModelMultipleChoiceFilter( + escalation_chain = MultipleChoiceCharFilter( field_name="channel_filter__escalation_chain", - queryset=None, + queryset=get_escalation_chain_queryset, to_field_name="public_primary_key", method=ModelFieldFilterMixin.filter_model_field.__name__, ) - started_at_range = filters.DateFromToRangeFilter( - field_name="started_at", widget=RangeWidget(attrs={"type": "date"}) - ) - resolved_by = filters.ModelMultipleChoiceFilter( + resolved_by = MultipleChoiceCharFilter( field_name="resolved_by_user", - queryset=None, + queryset=get_user_queryset, to_field_name="public_primary_key", method=ModelFieldFilterMixin.filter_model_field.__name__, ) - acknowledged_by = filters.ModelMultipleChoiceFilter( + acknowledged_by = MultipleChoiceCharFilter( field_name="acknowledged_by_user", - queryset=None, + queryset=get_user_queryset, to_field_name="public_primary_key", method=ModelFieldFilterMixin.filter_model_field.__name__, ) - invitees_are = filters.ModelMultipleChoiceFilter( - queryset=None, to_field_name="public_primary_key", method="filter_invitees_are" + silenced_by = MultipleChoiceCharFilter( + field_name="silenced_by_user", + queryset=get_user_queryset, + to_field_name="public_primary_key", + method=ModelFieldFilterMixin.filter_model_field.__name__, ) - involved_users_are = filters.ModelMultipleChoiceFilter( - queryset=None, to_field_name="public_primary_key", method="filter_by_involved_users" + invitees_are = MultipleChoiceCharFilter( + queryset=get_user_queryset, to_field_name="public_primary_key", method="filter_invitees_are" + ) + involved_users_are = MultipleChoiceCharFilter( + queryset=get_user_queryset, to_field_name="public_primary_key", method="filter_by_involved_users" ) with_resolution_note = filters.BooleanFilter(method="filter_with_resolution_note") mine = filters.BooleanFilter(method="filter_mine") - class Meta: - model = AlertGroup - fields = [ - "id__in", - "started_at_gte", - "started_at_lte", - "resolved_at_lte", - "is_root", - "resolved_by", - "acknowledged_by", - ] - def filter_status(self, queryset, name, value): if not value: return queryset @@ -263,12 +228,6 @@ class AlertGroupTeamFilteringMixin(TeamFilteringMixin): return Response(data={"error_code": "wrong_team"}, status=status.HTTP_403_FORBIDDEN) -@extend_schema_view( - list=extend_schema(description="Fetch a list of alert groups"), - retrieve=extend_schema(description="Fetch a single alert group"), - destroy=extend_schema(description="Delete an alert group"), - preview_template=extend_schema(description="Preview a template for an alert group"), -) class AlertGroupView( PreviewTemplateMixin, AlertGroupTeamFilteringMixin, @@ -278,6 +237,10 @@ class AlertGroupView( mixins.DestroyModelMixin, viewsets.GenericViewSet, ): + """ + Internal API endpoints for alert groups. + """ + authentication_classes = ( MobileAppAuthTokenAuthentication, PluginAuthentication, @@ -307,15 +270,12 @@ class AlertGroupView( "escalation_snapshot": [RBACPermission.Permissions.ALERT_GROUPS_READ], } - http_method_names = ["get", "post", "delete"] - + queryset = AlertGroup.objects.none() # needed for drf-spectacular introspection serializer_class = AlertGroupSerializer pagination_class = AlertGroupCursorPaginator - filter_backends = [SearchFilter, AlertGroupFilterBackend] - # search_fields = ["=public_primary_key", "=inside_organization_number", "web_title_cache"] - + filter_backends = [SearchFilter, filters.DjangoFilterBackend] filterset_class = AlertGroupFilter def get_serializer_class(self): @@ -375,7 +335,7 @@ class AlertGroupView( obj = self.enrich([obj])[0] return obj - def retrieve(self, request, pk, *args, **kwargs): + def retrieve(self, request, *args, **kwargs): """Return alert group details. It is worth mentioning that `render_after_resolve_report_json` property will return a list @@ -433,7 +393,7 @@ class AlertGroupView( - 1: web """ - return super().retrieve(request, pk, *args, **kwargs) + return super().retrieve(request, *args, **kwargs) def enrich(self, alert_groups): """ @@ -482,9 +442,12 @@ class AlertGroupView( delete_alert_group.apply_async((instance.pk, request.user.pk)) return Response(status=status.HTTP_204_NO_CONTENT) - @extend_schema(responses=inline_serializer(name="AlertGroupStats", fields={"count": serializers.IntegerField()})) - @action(detail=False) - def stats(self, *args, **kwargs): + @extend_schema( + filters=True, # filter alert groups before counting them + responses=inline_serializer(name="AlertGroupStats", fields={"count": serializers.IntegerField()}), + ) + @action(methods=["get"], detail=False) + def stats(self, request): """ Return number of alert groups capped at 100001 """ @@ -492,12 +455,9 @@ class AlertGroupView( alert_groups = self.filter_queryset(self.get_queryset())[:MAX_COUNT] count = alert_groups.count() count = f"{MAX_COUNT-1}+" if count == MAX_COUNT else str(count) - return Response( - { - "count": count, - } - ) + return Response({"count": count}) + @extend_schema(responses=AlertGroupSerializer) @action(methods=["post"], detail=True) def acknowledge(self, request, pk): """ @@ -512,6 +472,7 @@ class AlertGroupView( return Response(AlertGroupSerializer(alert_group, context={"request": self.request}).data) + @extend_schema(responses=AlertGroupSerializer) @action(methods=["post"], detail=True) def unacknowledge(self, request, pk): """ @@ -534,6 +495,12 @@ class AlertGroupView( return Response(AlertGroupSerializer(alert_group, context={"request": self.request}).data) + @extend_schema( + request=inline_serializer( + name="AlertGroupResolve", fields={"resolution_note": serializers.CharField(required=False, allow_null=True)} + ), + responses=AlertGroupSerializer, + ) @action(methods=["post"], detail=True) def resolve(self, request, pk): """ @@ -579,6 +546,7 @@ class AlertGroupView( alert_group.resolve_by_user(self.request.user, action_source=ActionSource.WEB) return Response(AlertGroupSerializer(alert_group, context={"request": self.request}).data) + @extend_schema(responses=AlertGroupSerializer) @action(methods=["post"], detail=True) def unresolve(self, request, pk): """ @@ -597,6 +565,10 @@ class AlertGroupView( alert_group.un_resolve_by_user(self.request.user, action_source=ActionSource.WEB) return Response(AlertGroupSerializer(alert_group, context={"request": self.request}).data) + @extend_schema( + request=inline_serializer(name="AlertGroupAttach", fields={"root_alert_group_pk": serializers.CharField()}), + responses=AlertGroupSerializer, + ) @action(methods=["post"], detail=True) def attach(self, request, pk=None): """ @@ -622,6 +594,7 @@ class AlertGroupView( alert_group.attach_by_user(self.request.user, root_alert_group, action_source=ActionSource.WEB) return Response(AlertGroupSerializer(alert_group, context={"request": self.request}).data) + @extend_schema(responses=AlertGroupSerializer) @action(methods=["post"], detail=True) def unattach(self, request, pk=None): """ @@ -636,6 +609,10 @@ class AlertGroupView( alert_group.un_attach_by_user(self.request.user, action_source=ActionSource.WEB) return Response(AlertGroupSerializer(alert_group, context={"request": self.request}).data) + @extend_schema( + request=inline_serializer(name="AlertGroupSilence", fields={"delay": serializers.IntegerField()}), + responses=AlertGroupSerializer, + ) @action(methods=["post"], detail=True) def silence(self, request, pk=None): """ @@ -655,9 +632,13 @@ class AlertGroupView( @extend_schema( responses=inline_serializer( - name="silence_options", - fields={"value": serializers.CharField(), "display_name": serializers.CharField()}, - many=True, + name="AlertGroupSilenceOptions", + fields={ + "value": serializers.ChoiceField(choices=[value for value, _ in AlertGroup.SILENCE_DELAY_OPTIONS]), + "display_name": serializers.ChoiceField( + choices=[display_name for _, display_name in AlertGroup.SILENCE_DELAY_OPTIONS] + ), + }, ) ) @action(methods=["get"], detail=False) @@ -670,6 +651,7 @@ class AlertGroupView( ] return Response(data) + @extend_schema(responses=AlertGroupSerializer) @action(methods=["post"], detail=True) def unsilence(self, request, pk=None): """ @@ -693,6 +675,10 @@ class AlertGroupView( return Response(AlertGroupSerializer(alert_group, context={"request": request}).data) + @extend_schema( + request=inline_serializer(name="AlertGroupUnpageUser", fields={"user_id": serializers.CharField()}), + responses=AlertGroupSerializer, + ) @action(methods=["post"], detail=True) def unpage_user(self, request, pk=None): """ @@ -715,12 +701,32 @@ class AlertGroupView( unpage_user(alert_group=alert_group, user=user, from_user=from_user) return Response(status=status.HTTP_200_OK) + @extend_schema( + responses=inline_serializer( + name="AlertGroupFilters", + fields={ + "name": serializers.CharField(), + "type": serializers.CharField(), + "href": serializers.CharField(required=False), + "global": serializers.BooleanField(required=False), + "default": serializers.JSONField(required=False), + "description": serializers.CharField(required=False), + "options": inline_serializer( + name="AlertGroupFiltersOptions", + fields={ + "value": serializers.CharField(), + "display_name": serializers.IntegerField(), + }, + ), + }, + many=True, + ) + ) @action(methods=["get"], detail=False) def filters(self, request): """ Retrieve a list of valid filter options that can be used to filter alert groups """ - filter_name = request.query_params.get("search", None) api_root = "/api/internal/v1/" now = timezone.now() @@ -779,7 +785,6 @@ class AlertGroupView( {"display_name": "silenced", "value": AlertGroup.SILENCED}, ], }, - # {'name': 'is_root', 'type': 'boolean', 'default': True}, { "name": "started_at", "type": "daterange", @@ -812,41 +817,60 @@ class AlertGroupView( } ) - if filter_name is not None: - filter_options = list(filter(lambda f: filter_name in f["name"], filter_options)) - return Response(filter_options) + @extend_schema( + request=inline_serializer( + name="AlertGroupBulkActionRequest", + fields={ + "alert_group_pks": serializers.ListField(child=serializers.CharField()), + "action": serializers.ChoiceField(choices=AlertGroup.BULK_ACTIONS), + "delay": serializers.IntegerField( + required=False, allow_null=True, help_text="only applicable for silence" + ), + }, + ) + ) @action(methods=["post"], detail=False) def bulk_action(self, request): """ Perform a bulk action on a list of alert groups """ - alert_group_public_pks = self.request.data.get("alert_group_pks", []) - action_with_incidents = self.request.data.get("action", None) + alert_group_pks = self.request.data.get("alert_group_pks", []) + action_name = self.request.data.get("action", None) delay = self.request.data.get("delay") kwargs = {} - if action_with_incidents not in AlertGroup.BULK_ACTIONS: + if action_name not in AlertGroup.BULK_ACTIONS: return Response("Unknown action", status=status.HTTP_400_BAD_REQUEST) - if action_with_incidents == AlertGroup.SILENCE: + if action_name == AlertGroup.SILENCE: if delay is None: raise BadRequest(detail="Please specify a delay for silence") kwargs["silence_delay"] = delay alert_groups = AlertGroup.objects.filter( - channel__organization=self.request.auth.organization, public_primary_key__in=alert_group_public_pks + channel__organization=self.request.auth.organization, public_primary_key__in=alert_group_pks ) kwargs["user"] = self.request.user kwargs["alert_groups"] = alert_groups - method = getattr(AlertGroup, f"bulk_{action_with_incidents}") + method = getattr(AlertGroup, f"bulk_{action_name}") method(**kwargs) return Response(status=status.HTTP_200_OK) + @extend_schema( + responses=inline_serializer( + name="AlertGroupBulkActionOptions", + fields={ + "value": serializers.ChoiceField(choices=AlertGroup.BULK_ACTIONS), + "display_name": serializers.ChoiceField(choices=AlertGroup.BULK_ACTIONS), + }, + many=True, + ) + ) @action(methods=["get"], detail=False) def bulk_action_options(self, request): """ diff --git a/engine/apps/api/views/alert_receive_channel.py b/engine/apps/api/views/alert_receive_channel.py index b0309a7a..dd8ab96b 100644 --- a/engine/apps/api/views/alert_receive_channel.py +++ b/engine/apps/api/views/alert_receive_channel.py @@ -1,7 +1,11 @@ +import typing + from django.db.models import Q from django_filters import rest_framework as filters from django_filters.rest_framework import DjangoFilterBackend -from rest_framework import status +from drf_spectacular.plumbing import resolve_type_hint +from drf_spectacular.utils import PolymorphicProxySerializer, extend_schema, extend_schema_view, inline_serializer +from rest_framework import serializers, status from rest_framework.decorators import action from rest_framework.filters import SearchFilter from rest_framework.permissions import IsAuthenticated @@ -39,6 +43,14 @@ from common.exceptions import MaintenanceCouldNotBeStartedError, TeamCanNotBeCha from common.insight_log import EntityEvent, write_resource_insight_log +class AlertReceiveChannelCounter(typing.TypedDict): + alerts_count: int + alert_groups_count: int + + +AlertReceiveChannelCounters = dict[str, AlertReceiveChannelCounter] + + class AlertReceiveChannelFilter(ByTeamModelFieldFilterMixin, filters.FilterSet): maintenance_mode = filters.MultipleChoiceFilter( choices=AlertReceiveChannel.MAINTENANCE_MODE_CHOICES, method="filter_maintenance_mode" @@ -71,6 +83,17 @@ class AlertReceiveChannelFilter(ByTeamModelFieldFilterMixin, filters.FilterSet): return queryset +@extend_schema_view( + list=extend_schema( + responses=PolymorphicProxySerializer( + component_name="AlertReceiveChannelPolymorphic", + serializers=[AlertReceiveChannelSerializer, FilterAlertReceiveChannelSerializer], + resource_type_field_name=None, + ) + ), + update=extend_schema(responses=AlertReceiveChannelUpdateSerializer), + partial_update=extend_schema(responses=AlertReceiveChannelUpdateSerializer), +) class AlertReceiveChannelView( PreviewTemplateMixin, TeamFilteringMixin, @@ -79,6 +102,10 @@ class AlertReceiveChannelView( UpdateSerializerMixin, ModelViewSet, ): + """ + Internal API endpoints for alert receive channels (integrations). + """ + authentication_classes = ( MobileAppAuthTokenAuthentication, PluginAuthentication, @@ -86,6 +113,7 @@ class AlertReceiveChannelView( permission_classes = (IsAuthenticated, RBACPermission) model = AlertReceiveChannel + queryset = AlertReceiveChannel.objects.none() # needed for drf-spectacular introspection serializer_class = AlertReceiveChannelSerializer filter_serializer_class = FilterAlertReceiveChannelSerializer update_serializer_class = AlertReceiveChannelUpdateSerializer @@ -194,6 +222,14 @@ class AlertReceiveChannelView( schedule_update_label_cache(self.model.__name__, self.request.auth.organization, ids) return page + @extend_schema( + request=inline_serializer( + name="AlertReceiveChannelSendDemoAlert", + fields={ + "demo_alert_payload": serializers.DictField(required=False, allow_null=True), + }, + ), + ) @action(detail=True, methods=["post"], throttle_classes=[DemoAlertThrottler]) def send_demo_alert(self, request, pk): instance = self.get_object() @@ -209,6 +245,19 @@ class AlertReceiveChannelView( return Response(status=status.HTTP_200_OK) + @extend_schema( + responses=inline_serializer( + name="AlertReceiveChannelIntegrationOptions", + fields={ + "value": serializers.CharField(), + "display_name": serializers.CharField(), + "short_description": serializers.CharField(), + "featured": serializers.BooleanField(), + "featured_tag_name": serializers.CharField(allow_null=True), + }, + many=True, + ) + ) @action(detail=False, methods=["get"]) def integration_options(self, request): choices = [] @@ -231,6 +280,11 @@ class AlertReceiveChannelView( choices.append(choice) return Response(featured_choices + choices) + @extend_schema( + parameters=[ + inline_serializer(name="AlertReceiveChannelChangeTeam", fields={"team_id": serializers.CharField()}) + ] + ) @action(detail=True, methods=["put"]) def change_team(self, request, pk): instance = self.get_object() @@ -249,6 +303,7 @@ class AlertReceiveChannelView( return Response() + @extend_schema(responses={status.HTTP_200_OK: resolve_type_hint(AlertReceiveChannelCounters)}) @action(methods=["get"], detail=False) def counters(self, request): queryset = self.filter_queryset(self.get_queryset(eager=False)) @@ -260,6 +315,11 @@ class AlertReceiveChannelView( } return Response(response) + @extend_schema( + # make operation_id unique, otherwise drf-spectacular will issue a warning + operation_id="alert_receive_channels_counters_per_integration_retrieve", + responses={status.HTTP_200_OK: resolve_type_hint(AlertReceiveChannelCounters)}, + ) @action(methods=["get"], detail=True, url_path="counters") def counters_per_integration(self, request, pk): alert_receive_channel = self.get_object() @@ -287,10 +347,22 @@ class AlertReceiveChannelView( except AttributeError: return None + @extend_schema( + responses=inline_serializer( + name="AlertReceiveChannelFilters", + fields={ + "name": serializers.CharField(), + "display_name": serializers.CharField(required=False), + "type": serializers.CharField(), + "href": serializers.CharField(), + "global": serializers.BooleanField(required=False), + }, + many=True, + ) + ) @action(methods=["get"], detail=False) def filters(self, request): organization = self.request.auth.organization - filter_name = request.query_params.get("search", None) api_root = "/api/internal/v1/" filter_options = [ @@ -317,11 +389,19 @@ class AlertReceiveChannelView( } ) - if filter_name is not None: - filter_options = list(filter(lambda f: filter_name in f["name"], filter_options)) - return Response(filter_options) + @extend_schema( + request=inline_serializer( + name="AlertReceiveChannelStartMaintenance", + fields={ + "mode": serializers.ChoiceField(choices=MaintainableObject.MAINTENANCE_MODE_CHOICES), + "duration": serializers.ChoiceField( + choices=MaintainableObject.maintenance_duration_options_in_seconds() + ), + }, + ), + ) @action(detail=True, methods=["post"]) def start_maintenance(self, request, pk): instance = self.get_object() @@ -355,8 +435,7 @@ class AlertReceiveChannelView( @action(detail=True, methods=["post"]) def stop_maintenance(self, request, pk): instance = self.get_object() - user = request.user - instance.force_disable_maintenance(user) + instance.force_disable_maintenance(request.user) return Response(status=status.HTTP_200_OK) @action(detail=True, methods=["post"]) @@ -394,6 +473,20 @@ class AlertReceiveChannelView( instance.save() return Response(status=status.HTTP_200_OK) + @extend_schema( + parameters=[ + inline_serializer( + name="AlertReceiveChannelValidateName", + fields={ + "verbal_name": serializers.CharField(), + }, + ) + ], + responses={ + status.HTTP_200_OK: None, + status.HTTP_409_CONFLICT: None, + }, + ) @action(detail=False, methods=["get"]) def validate_name(self, request): """ @@ -412,6 +505,17 @@ class AlertReceiveChannelView( return r + @extend_schema( + responses=inline_serializer( + name="AlertReceiveChannelConnectedContactPoints", + fields={ + "uid": serializers.CharField(), + "name": serializers.CharField(), + "contact_points": serializers.ListField(child=serializers.CharField()), + }, + many=True, + ) + ) @action(detail=True, methods=["get"]) def connected_contact_points(self, request, pk): instance = self.get_object() @@ -420,12 +524,32 @@ class AlertReceiveChannelView( contact_points = instance.grafana_alerting_sync_manager.get_connected_contact_points() return Response(contact_points) + @extend_schema( + responses=inline_serializer( + name="AlertReceiveChannelContactPoints", + fields={ + "uid": serializers.CharField(), + "name": serializers.CharField(), + "contact_points": serializers.ListField(child=serializers.CharField()), + }, + many=True, + ) + ) @action(detail=False, methods=["get"]) def contact_points(self, request): organization = request.auth.organization contact_points = GrafanaAlertingSyncManager.get_contact_points(organization) return Response(contact_points) + @extend_schema( + request=inline_serializer( + name="AlertReceiveChannelConnectContactPoint", + fields={ + "datasource_uid": serializers.CharField(), + "contact_point_name": serializers.CharField(), + }, + ), + ) @action(detail=True, methods=["post"]) def connect_contact_point(self, request, pk): instance = self.get_object() @@ -443,6 +567,15 @@ class AlertReceiveChannelView( raise BadRequest(detail=error) return Response(status=status.HTTP_200_OK) + @extend_schema( + request=inline_serializer( + name="AlertReceiveChannelCreateContactPoint", + fields={ + "datasource_uid": serializers.CharField(), + "contact_point_name": serializers.CharField(), + }, + ), + ) @action(detail=True, methods=["post"]) def create_contact_point(self, request, pk): instance = self.get_object() @@ -460,6 +593,15 @@ class AlertReceiveChannelView( raise BadRequest(detail=error) return Response(status=status.HTTP_201_CREATED) + @extend_schema( + request=inline_serializer( + name="AlertReceiveChannelDisconnectContactPoint", + fields={ + "datasource_uid": serializers.CharField(), + "contact_point_name": serializers.CharField(), + }, + ), + ) @action(detail=True, methods=["post"]) def disconnect_contact_point(self, request, pk): instance = self.get_object() diff --git a/engine/apps/api/views/features.py b/engine/apps/api/views/features.py index 1eabfb8f..0d65a560 100644 --- a/engine/apps/api/views/features.py +++ b/engine/apps/api/views/features.py @@ -1,6 +1,9 @@ +import enum + from django.conf import settings -from drf_spectacular.utils import OpenApiExample, extend_schema -from rest_framework import serializers +from drf_spectacular.plumbing import resolve_type_hint +from drf_spectacular.utils import extend_schema +from rest_framework import status from rest_framework.response import Response from rest_framework.views import APIView @@ -8,14 +11,16 @@ from apps.auth_token.auth import PluginAuthentication from apps.base.utils import live_settings from apps.labels.utils import is_labels_feature_enabled -FEATURE_MSTEAMS = "msteams" -FEATURE_SLACK = "slack" -FEATURE_TELEGRAM = "telegram" -FEATURE_LIVE_SETTINGS = "live_settings" -FEATURE_GRAFANA_CLOUD_NOTIFICATIONS = "grafana_cloud_notifications" -FEATURE_GRAFANA_CLOUD_CONNECTION = "grafana_cloud_connection" -FEATURE_GRAFANA_ALERTING_V2 = "grafana_alerting_v2" -FEATURE_LABELS = "labels" + +class Feature(enum.StrEnum): + MSTEAMS = "msteams" + SLACK = "slack" + TELEGRAM = "telegram" + LIVE_SETTINGS = "live_settings" + GRAFANA_CLOUD_NOTIFICATIONS = "grafana_cloud_notifications" + GRAFANA_CLOUD_CONNECTION = "grafana_cloud_connection" + GRAFANA_ALERTING_V2 = "grafana_alerting_v2" + LABELS = "labels" class FeaturesAPIView(APIView): @@ -26,16 +31,7 @@ class FeaturesAPIView(APIView): authentication_classes = (PluginAuthentication,) - @extend_schema( - request=None, - responses=serializers.ListField(child=serializers.CharField()), - examples=[ - OpenApiExample( - name="Example response", - value=["slack", "telegram", "grafana_cloud_connection", "live_settings", "grafana_cloud_notifications"], - ) - ], - ) + @extend_schema(responses={status.HTTP_200_OK: resolve_type_hint(list[Feature])}) def get(self, request): data = self._get_enabled_features(request) return Response(data) @@ -44,25 +40,25 @@ class FeaturesAPIView(APIView): enabled_features = [] if settings.FEATURE_SLACK_INTEGRATION_ENABLED: - enabled_features.append(FEATURE_SLACK) + enabled_features.append(Feature.SLACK) if settings.FEATURE_TELEGRAM_INTEGRATION_ENABLED: - enabled_features.append(FEATURE_TELEGRAM) + enabled_features.append(Feature.TELEGRAM) if settings.IS_OPEN_SOURCE: # Features below should be enabled only in OSS - enabled_features.append(FEATURE_GRAFANA_CLOUD_CONNECTION) + enabled_features.append(Feature.GRAFANA_CLOUD_CONNECTION) if settings.FEATURE_LIVE_SETTINGS_ENABLED: - enabled_features.append(FEATURE_LIVE_SETTINGS) + enabled_features.append(Feature.LIVE_SETTINGS) if live_settings.GRAFANA_CLOUD_NOTIFICATIONS_ENABLED: - enabled_features.append(FEATURE_GRAFANA_CLOUD_NOTIFICATIONS) + enabled_features.append(Feature.GRAFANA_CLOUD_NOTIFICATIONS) else: - enabled_features.append(FEATURE_MSTEAMS) + enabled_features.append(Feature.MSTEAMS) if settings.FEATURE_GRAFANA_ALERTING_V2_ENABLED: - enabled_features.append(FEATURE_GRAFANA_ALERTING_V2) + enabled_features.append(Feature.GRAFANA_ALERTING_V2) if is_labels_feature_enabled(self.request.auth.organization): - enabled_features.append(FEATURE_LABELS) + enabled_features.append(Feature.LABELS) return enabled_features diff --git a/engine/apps/api/views/labels.py b/engine/apps/api/views/labels.py index e2d7a559..5c0fa271 100644 --- a/engine/apps/api/views/labels.py +++ b/engine/apps/api/views/labels.py @@ -145,6 +145,8 @@ class LabelsViewSet(LabelsFeatureFlagViewSet): return super().handle_exception(exc) +# specifying a tag explicitly to avoid these endpoints being grouped with alert group endpoints +@extend_schema(tags=["alert group labels"]) class AlertGroupLabelsViewSet(LabelsFeatureFlagViewSet): """ This viewset is similar to LabelsViewSet, but it works with alert group labels. diff --git a/engine/apps/api/views/user.py b/engine/apps/api/views/user.py index effeb313..8cd517c1 100644 --- a/engine/apps/api/views/user.py +++ b/engine/apps/api/views/user.py @@ -1,4 +1,5 @@ import logging +import typing import pytz from django.conf import settings @@ -8,7 +9,9 @@ from django.urls import reverse from django.utils import timezone from django.utils.functional import cached_property from django_filters import rest_framework as filters -from rest_framework import mixins, status, viewsets +from drf_spectacular.plumbing import resolve_type_hint +from drf_spectacular.utils import PolymorphicProxySerializer, extend_schema, inline_serializer +from rest_framework import mixins, serializers, status, viewsets from rest_framework.decorators import action from rest_framework.exceptions import NotFound from rest_framework.filters import SearchFilter @@ -60,12 +63,13 @@ from apps.phone_notifications.exceptions import ( from apps.phone_notifications.phone_backend import PhoneBackend from apps.schedules.ical_utils import get_cached_oncall_users_for_multiple_schedules from apps.schedules.models import OnCallSchedule +from apps.schedules.models.on_call_schedule import ScheduleEvent from apps.telegram.client import TelegramClient from apps.telegram.models import TelegramVerificationCode from apps.user_management.models import Team, User from common.api_helpers.exceptions import Conflict from common.api_helpers.filters import ByTeamModelFieldFilterMixin, TeamModelMultipleChoiceFilter -from common.api_helpers.mixins import FilterSerializerMixin, PublicPrimaryKeyMixin +from common.api_helpers.mixins import PublicPrimaryKeyMixin from common.api_helpers.paginators import HundredPageSizePaginator from common.api_helpers.utils import create_engine_url from common.insight_log import ( @@ -86,6 +90,17 @@ UPCOMING_SHIFTS_DEFAULT_DAYS = 7 UPCOMING_SHIFTS_MAX_DAYS = 65 +class UpcomingShift(typing.TypedDict): + schedule_id: str + schedule_name: str + is_oncall: bool + current_shift: ScheduleEvent | None + next_shift: ScheduleEvent | None + + +UpcomingShifts = list[UpcomingShift] + + class CurrentUserView(APIView): authentication_classes = (MobileAppAuthTokenAuthentication, PluginAuthentication) permission_classes = (IsAuthenticated,) @@ -143,12 +158,15 @@ class UserFilter(ByTeamModelFieldFilterMixin, filters.FilterSet): class UserView( PublicPrimaryKeyMixin, - FilterSerializerMixin, mixins.RetrieveModelMixin, mixins.UpdateModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet, ): + """ + Internal API endpoints for users. + """ + authentication_classes = ( MobileAppAuthTokenAuthentication, PluginAuthentication, @@ -208,7 +226,7 @@ class UserView( ], } - filter_serializer_class = FilterUserSerializer + queryset = User.objects.none() # needed for drf-spectacular introspection pagination_class = HundredPageSizePaginator @@ -264,7 +282,7 @@ class UserView( is_filters_request = query_params.get("filters", "false") == "true" if is_list_request and is_filters_request: - return self.get_filter_serializer_class() + return FilterUserSerializer elif is_list_request and self._is_currently_oncall_request(): return UserIsCurrentlyOnCallSerializer @@ -287,6 +305,13 @@ class UserView( return queryset.order_by("id") + @extend_schema( + responses=PolymorphicProxySerializer( + component_name="UserPolymorphic", + serializers=[FilterUserSerializer, UserIsCurrentlyOnCallSerializer, UserSerializer], + resource_type_field_name=None, + ) + ) def list(self, request, *args, **kwargs) -> Response: queryset = self.filter_queryset(self.get_queryset()) @@ -324,6 +349,7 @@ class UserView( serializer = self.get_serializer(queryset, many=True, context=context) return Response(serializer.data) + @extend_schema(responses=UserSerializer) def retrieve(self, request, *args, **kwargs) -> Response: context = self.get_serializer_context() @@ -345,6 +371,14 @@ class UserView( serializer = self.get_serializer(instance, context=context) return Response(serializer.data) + @extend_schema(request=UserSerializer, responses=UserSerializer) + def update(self, request, *args, **kwargs): + return super().update(request, *args, **kwargs) + + @extend_schema(request=UserSerializer, responses=UserSerializer) + def partial_update(self, request, *args, **kwargs): + return super().partial_update(request, *args, **kwargs) + def wrong_team_response(self) -> Response: """ This method returns 403 and {"error_code": "wrong_team", "owner_team": {"name", "id", "email", "avatar_url"}}. @@ -371,6 +405,7 @@ class UserView( serializer = UserSerializer(self.get_queryset().get(pk=self.request.user.pk)) return Response(serializer.data) + @extend_schema(responses={status.HTTP_200_OK: resolve_type_hint(typing.List[str])}) @action(detail=False, methods=["get"]) def timezone_options(self, request) -> Response: return Response(pytz.common_timezones) @@ -429,6 +464,7 @@ class UserView( ) return Response(status=status.HTTP_200_OK) + @extend_schema(parameters=[inline_serializer(name="UserVerifyNumber", fields={"token": serializers.CharField()})]) @action( detail=True, methods=["put"], @@ -508,6 +544,13 @@ class UserView( return Response(status=status.HTTP_200_OK) + @extend_schema( + parameters=[ + inline_serializer( + name="UserSendTestPush", fields={"critical": serializers.BooleanField(required=False, default=False)} + ) + ] + ) @action(detail=True, methods=["post"], throttle_classes=[TestPushThrottler]) def send_test_push(self, request, pk) -> Response: user = self.get_object() @@ -527,6 +570,11 @@ class UserView( ) return Response(status=status.HTTP_200_OK) + @extend_schema( + parameters=[ + inline_serializer(name="UserGetBackendVerificationCode", fields={"backend": serializers.CharField()}) + ] + ) @action(detail=True, methods=["get"]) def get_backend_verification_code(self, request, pk) -> Response: user = self.get_object() @@ -539,6 +587,15 @@ class UserView( code = backend.generate_user_verification_code(user) return Response(code) + @extend_schema( + responses=inline_serializer( + name="UserGetTelegramVerificationCode", + fields={ + "telegram_code": serializers.CharField(), + "bot_link": serializers.CharField(), + }, + ) + ) @action(detail=True, methods=["get"]) def get_telegram_verification_code(self, request, pk) -> Response: user = self.get_object() @@ -596,6 +653,9 @@ class UserView( return Response(status=status.HTTP_400_BAD_REQUEST) return Response(status=status.HTTP_200_OK) + @extend_schema( + parameters=[inline_serializer(name="UserUnlinkBackend", fields={"backend": serializers.CharField()})] + ) @action(detail=True, methods=["post"]) def unlink_backend(self, request, pk) -> Response: # TODO: insight logs support @@ -619,6 +679,15 @@ class UserView( return Response(status=status.HTTP_400_BAD_REQUEST) return Response(status=status.HTTP_200_OK) + @extend_schema( + parameters=[ + inline_serializer( + name="UserUpcomingShiftsParams", + fields={"days": serializers.IntegerField(required=False, default=UPCOMING_SHIFTS_DEFAULT_DAYS)}, + ) + ], + responses={status.HTTP_200_OK: resolve_type_hint(UpcomingShifts)}, + ) @action(detail=True, methods=["get"]) def upcoming_shifts(self, request, pk) -> Response: user = self.get_object() @@ -658,6 +727,28 @@ class UserView( return Response(upcoming, status=status.HTTP_200_OK) + @extend_schema( + methods=["get"], + responses=inline_serializer( + name="UserExportTokenGetResponse", + fields={ + "created_at": serializers.DateTimeField(), + "revoked_at": serializers.DateTimeField(allow_null=True), + "active": serializers.BooleanField(), + }, + ), + ) + @extend_schema( + methods=["post"], + responses=inline_serializer( + name="UserExportTokenPostResponse", + fields={ + "token": serializers.CharField(), + "created_at": serializers.DateTimeField(), + "export_url": serializers.CharField(), + }, + ), + ) @action(detail=True, methods=["get", "post", "delete"]) def export_token(self, request, pk) -> Response: user = self.get_object() diff --git a/engine/apps/auth_token/auth.py b/engine/apps/auth_token/auth.py index 6d6907c6..89a12560 100644 --- a/engine/apps/auth_token/auth.py +++ b/engine/apps/auth_token/auth.py @@ -4,6 +4,7 @@ from typing import Tuple from django.conf import settings from django.contrib.auth.models import AnonymousUser +from drf_spectacular.extensions import OpenApiAuthenticationExtension from rest_framework import exceptions from rest_framework.authentication import BaseAuthentication, get_authorization_header from rest_framework.request import Request @@ -142,6 +143,22 @@ class PluginAuthentication(BasePluginAuthentication): raise exceptions.AuthenticationFailed("Non-existent or anonymous user.") +class PluginAuthenticationSchema(OpenApiAuthenticationExtension): + target_class = PluginAuthentication + name = "PluginAuthentication" + + def get_security_definition(self, auto_schema): + return { + "type": "apiKey", + "in": "header", + "name": "Authorization", + "description": ( + "Additional X-Instance-Context and X-Grafana-Context headers must be set. " + "THIS WILL NOT WORK IN SWAGGER UI." + ), + } + + class GrafanaIncidentUser(AnonymousUser): @property def is_authenticated(self): diff --git a/engine/apps/heartbeat/models.py b/engine/apps/heartbeat/models.py index 9d8018df..0c0084bd 100644 --- a/engine/apps/heartbeat/models.py +++ b/engine/apps/heartbeat/models.py @@ -95,7 +95,10 @@ class IntegrationHeartBeat(models.Model): return not self.is_expired @property - def link(self) -> str: + def link(self) -> str | None: + if not self.alert_receive_channel.integration_url: + return None + return urljoin(self.alert_receive_channel.integration_url, "heartbeat/") # Insight logs diff --git a/engine/apps/mobile_app/auth.py b/engine/apps/mobile_app/auth.py index 72d0646a..6363f9f2 100644 --- a/engine/apps/mobile_app/auth.py +++ b/engine/apps/mobile_app/auth.py @@ -1,5 +1,6 @@ from typing import Optional, Tuple +from drf_spectacular.extensions import OpenApiAuthenticationExtension from rest_framework import exceptions from rest_framework.authentication import BaseAuthentication, get_authorization_header @@ -43,3 +44,15 @@ class MobileAppAuthTokenAuthentication(BaseAuthentication): return None, None return auth_token.user, auth_token + + +class MobileAppAuthTokenAuthenticationSchema(OpenApiAuthenticationExtension): + target_class = MobileAppAuthTokenAuthentication + name = "MobileAppAuthTokenAuthentication" + + def get_security_definition(self, auto_schema): + return { + "type": "apiKey", + "in": "header", + "name": "Authorization", + } diff --git a/engine/apps/oss_installation/constants.py b/engine/apps/oss_installation/constants.py index 11f3dc48..a20e86b0 100644 --- a/engine/apps/oss_installation/constants.py +++ b/engine/apps/oss_installation/constants.py @@ -1,4 +1,8 @@ -CLOUD_NOT_SYNCED = 0 -CLOUD_SYNCED_USER_NOT_FOUND = 1 -CLOUD_SYNCED_PHONE_NOT_VERIFIED = 2 -CLOUD_SYNCED_PHONE_VERIFIED = 3 +from enum import IntEnum + + +class CloudSyncStatus(IntEnum): + NOT_SYNCED = 0 + SYNCED_USER_NOT_FOUND = 1 + SYNCED_PHONE_NOT_VERIFIED = 2 + SYNCED_PHONE_VERIFIED = 3 diff --git a/engine/apps/oss_installation/utils.py b/engine/apps/oss_installation/utils.py index 7c58fe93..f2f11606 100644 --- a/engine/apps/oss_installation/utils.py +++ b/engine/apps/oss_installation/utils.py @@ -3,7 +3,7 @@ from urllib.parse import urljoin from django.utils import timezone -from apps.oss_installation import constants as oss_constants +from apps.oss_installation.constants import CloudSyncStatus from apps.schedules.ical_utils import list_users_to_notify_from_ical_for_period logger = logging.getLogger(__name__) @@ -70,15 +70,15 @@ def active_oss_users_count(): def cloud_user_identity_status(connector, identity): link = None if connector is None: - status = oss_constants.CLOUD_NOT_SYNCED + status = CloudSyncStatus.NOT_SYNCED elif identity is None: - status = oss_constants.CLOUD_SYNCED_USER_NOT_FOUND + status = CloudSyncStatus.SYNCED_USER_NOT_FOUND link = connector.cloud_url else: if identity.phone_number_verified: - status = oss_constants.CLOUD_SYNCED_PHONE_VERIFIED + status = CloudSyncStatus.SYNCED_PHONE_VERIFIED else: - status = oss_constants.CLOUD_SYNCED_PHONE_NOT_VERIFIED + status = CloudSyncStatus.SYNCED_PHONE_NOT_VERIFIED link = urljoin(connector.cloud_url, f"a/grafana-oncall-app/?page=users&p=1&id={identity.cloud_id}") return status, link diff --git a/engine/apps/oss_installation/views/cloud_users.py b/engine/apps/oss_installation/views/cloud_users.py index d62fd612..471d3a0c 100644 --- a/engine/apps/oss_installation/views/cloud_users.py +++ b/engine/apps/oss_installation/views/cloud_users.py @@ -20,12 +20,9 @@ class CloudUsersPagination(HundredPageSizePaginator): # the override ignore here is expected. The parent classes' get_paginated_response method does not # take a matched_users_count argument. This is fine in this case def get_paginated_response(self, data: PaginatedData, matched_users_count: int) -> Response: # type: ignore[override] - return Response( - { - **self._get_paginated_response_data(data), - "matched_users_count": matched_users_count, - } - ) + response = super().get_paginated_response(data) + response.data["matched_users_count"] = matched_users_count + return response class CloudUsersView(CloudUsersPagination, APIView): diff --git a/engine/apps/slack/models/slack_user_identity.py b/engine/apps/slack/models/slack_user_identity.py index ee76b32f..9f16ca5c 100644 --- a/engine/apps/slack/models/slack_user_identity.py +++ b/engine/apps/slack/models/slack_user_identity.py @@ -147,7 +147,7 @@ class SlackUserIdentity(models.Model): ) @property - def slack_verbal(self): + def slack_verbal(self) -> str | None: return ( self.profile_real_name_normalized or self.profile_real_name diff --git a/engine/apps/user_management/models/user.py b/engine/apps/user_management/models/user.py index 42a942b5..9aa70719 100644 --- a/engine/apps/user_management/models/user.py +++ b/engine/apps/user_management/models/user.py @@ -257,7 +257,7 @@ class User(models.Model): return urljoin(self.organization.grafana_url, self.avatar_url) @property - def verified_phone_number(self): + def verified_phone_number(self) -> str | None: """ Use property to highlight that _verified_phone_number should not be modified directly """ diff --git a/engine/common/api_helpers/custom_fields.py b/engine/common/api_helpers/custom_fields.py index dcfc1bb8..42ec2759 100644 --- a/engine/common/api_helpers/custom_fields.py +++ b/engine/common/api_helpers/custom_fields.py @@ -1,4 +1,5 @@ from django.core.exceptions import ObjectDoesNotExist +from drf_spectacular.utils import extend_schema_field from rest_framework import fields, serializers from rest_framework.exceptions import ValidationError from rest_framework.relations import RelatedField @@ -9,6 +10,7 @@ from common.api_helpers.exceptions import BadRequest from common.timezones import raise_exception_if_not_valid_timezone +@extend_schema_field(serializers.CharField) class OrganizationFilteredPrimaryKeyRelatedField(RelatedField): """ This field is used to filter entities by organization @@ -42,6 +44,7 @@ class OrganizationFilteredPrimaryKeyRelatedField(RelatedField): return self.display_func(instance) +@extend_schema_field(serializers.CharField) class TeamPrimaryKeyRelatedField(RelatedField): """ This field is used to get user teams diff --git a/engine/common/api_helpers/filters.py b/engine/common/api_helpers/filters.py index 7e174ae8..12e1fe41 100644 --- a/engine/common/api_helpers/filters.py +++ b/engine/common/api_helpers/filters.py @@ -3,6 +3,8 @@ from datetime import datetime from django.db.models import Q from django_filters import rest_framework as filters from django_filters.utils import handle_timezone +from drf_spectacular.utils import extend_schema_field +from rest_framework import serializers from apps.user_management.models import Team from common.api_helpers.exceptions import BadRequest @@ -48,6 +50,13 @@ class DateRangeFilterMixin: return start_date, end_date +@extend_schema_field(serializers.CharField) +class MultipleChoiceCharFilter(filters.ModelMultipleChoiceFilter): + """MultipleChoiceCharFilter with an explicit schema. Otherwise, drf-specacular may generate a wrong schema.""" + + pass + + class ModelFieldFilterMixin: def filter_model_field(self, queryset, name, value): if not value: @@ -106,6 +115,7 @@ class ByTeamFilter(ByTeamModelFieldFilterMixin, filters.FilterSet): ) +@extend_schema_field(serializers.CharField) class TeamModelMultipleChoiceFilter(filters.ModelMultipleChoiceFilter): def __init__( self, diff --git a/engine/common/api_helpers/mixins.py b/engine/common/api_helpers/mixins.py index 5e2fb8c2..7bce174e 100644 --- a/engine/common/api_helpers/mixins.py +++ b/engine/common/api_helpers/mixins.py @@ -6,7 +6,8 @@ from django.core.exceptions import ObjectDoesNotExist from django.db import models from django.db.models import Q from django.utils.functional import cached_property -from rest_framework import status +from drf_spectacular.utils import extend_schema, inline_serializer +from rest_framework import serializers, status from rest_framework.decorators import action from rest_framework.exceptions import NotFound, Throttled from rest_framework.request import Request @@ -281,6 +282,21 @@ class PreviewTemplateException(Exception): class PreviewTemplateMixin: + @extend_schema( + description="Preview template", + request=inline_serializer( + name="PreviewTemplateRequest", + fields={ + "template_body": serializers.CharField(required=False, allow_null=True), + "template_name": serializers.CharField(required=False, allow_null=True), + "payload": serializers.DictField(required=False, allow_null=True), + }, + ), + responses=inline_serializer( + name="PreviewTemplateResponse", + fields={"preview": serializers.CharField(allow_null=True)}, + ), + ) @action(methods=["post"], detail=True) def preview_template(self, request, pk): template_body = request.data.get("template_body", None) diff --git a/engine/common/api_helpers/paginators.py b/engine/common/api_helpers/paginators.py index 554c3497..3020b9d0 100644 --- a/engine/common/api_helpers/paginators.py +++ b/engine/common/api_helpers/paginators.py @@ -28,40 +28,43 @@ class BasePathPrefixedPagination(BasePagination): def paginate_queryset(self, queryset, request, view=None): request.build_absolute_uri = lambda: create_engine_url(request.get_full_path()) - - # we're setting the request object explicitly here because the way the paginate_quersey works - # between PageNumberPagination and CursorPagination is slightly different. In the latter class, - # it does not set self.request in the paginate_queryset method, whereas in the former it does. - # this leads to an issue in _get_base_paginated_response_data where the self.request would not be set - self.request = request - return super().paginate_queryset(queryset, request, view) - def _get_base_paginated_response_data(self, data: PaginatedData) -> BasePaginatedResponseData: - return { - "next": self.get_next_link(), - "previous": self.get_previous_link(), - "results": data, - "page_size": self.get_page_size(self.request), - } - class PathPrefixedPagePagination(BasePathPrefixedPagination, PageNumberPagination): - def _get_paginated_response_data(self, data: PaginatedData) -> PageBasedPaginationResponseData: - return { - **self._get_base_paginated_response_data(data), - "count": self.page.paginator.count, - "current_page_number": self.page.number, - "total_pages": self.page.paginator.num_pages, - } - def get_paginated_response(self, data: PaginatedData) -> Response: - return Response(self._get_paginated_response_data(data)) + response = super().get_paginated_response(data) + response.data.update( + { + "page_size": self.get_page_size(self.request), + "current_page_number": self.page.number, + "total_pages": self.page.paginator.num_pages, + } + ) + return response + + def get_paginated_response_schema(self, schema): + paginated_schema = super().get_paginated_response_schema(schema) + paginated_schema["properties"].update( + { + "page_size": {"type": "integer"}, + "current_page_number": {"type": "integer"}, + "total_pages": {"type": "integer"}, + } + ) + return paginated_schema class PathPrefixedCursorPagination(BasePathPrefixedPagination, CursorPagination): def get_paginated_response(self, data: PaginatedData) -> Response: - return Response(self._get_base_paginated_response_data(data)) + response = super().get_paginated_response(data) + response.data.update({"page_size": self.page_size}) + return response + + def get_paginated_response_schema(self, schema): + paginated_schema = super().get_paginated_response_schema(schema) + paginated_schema["properties"].update({"page_size": {"type": "integer"}}) + return paginated_schema class HundredPageSizePaginator(PathPrefixedPagePagination): diff --git a/engine/engine/schema.py b/engine/engine/schema.py new file mode 100644 index 00000000..e1cee88c --- /dev/null +++ b/engine/engine/schema.py @@ -0,0 +1,47 @@ +from drf_spectacular.openapi import AutoSchema +from drf_spectacular.plumbing import get_view_model + +from common.api_helpers.mixins import PublicPrimaryKeyMixin + + +class CustomAutoSchema(AutoSchema): + def _get_serializer(self): + """Makes so that extra actions (@action on viewset) don't inherit the serializer from the viewset.""" + if self._is_extra_action: + return None + return super()._get_serializer() + + def _get_paginator(self): + """Makes so that extra actions (@action on viewset) don't inherit the paginator from the viewset.""" + if self._is_extra_action: + return None + return super()._get_paginator() + + def get_filter_backends(self): + """Makes so that extra actions (@action on viewset) don't inherit the filter backends from the viewset.""" + if self._is_extra_action: + return [] + return super().get_filter_backends() + + def _resolve_path_parameters(self, variables): + """A workaround to make public primary keys appear as strings in the OpenAPI schema.""" + + parameters = super()._resolve_path_parameters(variables) + if not isinstance(self.view, PublicPrimaryKeyMixin): + return parameters + + for parameter in parameters: + if parameter["name"] == "id" and parameter["in"] == "path": + parameter["schema"]["type"] = "string" + model = get_view_model(self.view, emit_warnings=False) + model_name = model._meta.verbose_name if model else "resource" + parameter["description"] = f"A string identifying this {model_name}." + + return parameters + + @property + def _is_extra_action(self) -> bool: + try: + return self.view.action in [action.__name__ for action in self.view.get_extra_actions()] + except AttributeError: + return False diff --git a/engine/settings/base.py b/engine/settings/base.py index a80c1166..81f1bb3f 100644 --- a/engine/settings/base.py +++ b/engine/settings/base.py @@ -287,7 +287,7 @@ REST_FRAMEWORK = { "rest_framework.parsers.MultiPartParser", ), "DEFAULT_AUTHENTICATION_CLASSES": [], - "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", + "DEFAULT_SCHEMA_CLASS": "engine.schema.CustomAutoSchema", } @@ -315,6 +315,8 @@ if SWAGGER_UI_SETTINGS_URL: SPECTACULAR_INCLUDED_PATHS = [ "/features", "/alertgroups", + "/alert_receive_channels", + "/users", "/labels", ] diff --git a/engine/settings/prod_without_db.py b/engine/settings/prod_without_db.py index 9b4e7682..f505c8f9 100644 --- a/engine/settings/prod_without_db.py +++ b/engine/settings/prod_without_db.py @@ -55,5 +55,5 @@ REST_FRAMEWORK = { ), "DEFAULT_AUTHENTICATION_CLASSES": [], "DEFAULT_RENDERER_CLASSES": ("rest_framework.renderers.JSONRenderer",), - "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", + "DEFAULT_SCHEMA_CLASS": "engine.schema.CustomAutoSchema", }