Fix broken openapi schema + add integration test (#3364)

# Which issue(s) this PR fixes

- Fix issue that was causing our openapi schema to return HTTP 500 + add
an integration test which fetches the `.yaml` schema and validates that
the endpoint returns HTTP 200 (should hopefully prevent this from
happening again).
- add a few more type hints along the way

## Checklist

- [x] Unit, integration, and e2e (if applicable) tests updated
- [x] Documentation added (or `pr:no public docs` PR label added if not
required)
- [x] `CHANGELOG.md` updated (or `pr:no changelog` PR label added if not
required)
This commit is contained in:
Joey Orlando 2023-11-16 07:15:05 -05:00 committed by GitHub
parent 607e87c6c2
commit 77cb381366
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 125 additions and 73 deletions

View file

@ -179,6 +179,7 @@ class AlertGroupSlackRenderingMixin:
class AlertGroup(AlertGroupSlackRenderingMixin, EscalationSnapshotMixin, models.Model):
acknowledged_by_user: typing.Optional["User"]
alerts: "RelatedManager['Alert']"
dependent_alert_groups: "RelatedManager['AlertGroup']"
channel: "AlertReceiveChannel"
@ -187,7 +188,9 @@ class AlertGroup(AlertGroupSlackRenderingMixin, EscalationSnapshotMixin, models.
resolution_notes: "RelatedManager['ResolutionNote']"
resolution_note_slack_messages: "RelatedManager['ResolutionNoteSlackMessage']"
resolved_by_alert: typing.Optional["Alert"]
resolved_by_user: typing.Optional["User"]
root_alert_group: typing.Optional["AlertGroup"]
silenced_by_user: typing.Optional["User"]
slack_log_message: typing.Optional["SlackMessage"]
slack_messages: "RelatedManager['SlackMessage']"
users: "RelatedManager['User']"

View file

@ -1,21 +1,22 @@
import datetime
import logging
import typing
from django.core.cache import cache
from django.utils import timezone
from drf_spectacular.utils import extend_schema_field, inline_serializer
from rest_framework import serializers
from apps.alerts.incident_appearance.renderers.classic_markdown_renderer import AlertGroupClassicMarkdownRenderer
from apps.alerts.incident_appearance.renderers.web_renderer import AlertGroupWebRenderer
from apps.alerts.models import AlertGroup
from apps.alerts.models.alert_group import PagedUser
from common.api_helpers.custom_fields import TeamPrimaryKeyRelatedField
from common.api_helpers.mixins import EagerLoadingMixin
from .alert import AlertSerializer
from .alert_receive_channel import FastAlertReceiveChannelSerializer
from .alerts_field_cache_buster_mixin import AlertsFieldCacheBusterMixin
from .user import FastUserSerializer, PagedUserSerializer, UserShortSerializer
from .user import FastUserSerializer, UserShortSerializer
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@ -90,7 +91,7 @@ class ShortAlertGroupSerializer(AlertGroupFieldsCacheSerializerMixin, serializer
},
)
)
def get_render_for_web(self, obj):
def get_render_for_web(self, obj: "AlertGroup"):
last_alert = obj.alerts.last()
if last_alert is None:
return {}
@ -102,7 +103,9 @@ class ShortAlertGroupSerializer(AlertGroupFieldsCacheSerializerMixin, serializer
)
class AlertGroupListSerializer(EagerLoadingMixin, AlertGroupFieldsCacheSerializerMixin, serializers.ModelSerializer):
class AlertGroupListSerializer(
EagerLoadingMixin, AlertGroupFieldsCacheSerializerMixin, serializers.ModelSerializer[AlertGroup]
):
pk = serializers.CharField(read_only=True, source="public_primary_key")
alert_receive_channel = FastAlertReceiveChannelSerializer(source="channel")
status = serializers.ReadOnlyField()
@ -116,7 +119,6 @@ class AlertGroupListSerializer(EagerLoadingMixin, AlertGroupFieldsCacheSerialize
alerts_count = serializers.IntegerField(read_only=True)
render_for_web = serializers.SerializerMethodField()
render_for_classic_markdown = serializers.SerializerMethodField()
labels = AlertGroupLabelSerializer(many=True, read_only=True)
@ -158,7 +160,6 @@ class AlertGroupListSerializer(EagerLoadingMixin, AlertGroupFieldsCacheSerialize
"silenced_until",
"related_users",
"render_for_web",
"render_for_classic_markdown",
"dependent_alert_groups",
"root_alert_group",
"status",
@ -179,7 +180,7 @@ class AlertGroupListSerializer(EagerLoadingMixin, AlertGroupFieldsCacheSerialize
},
)
)
def get_render_for_web(self, obj):
def get_render_for_web(self, obj: "AlertGroup"):
if not obj.last_alert:
return {}
return AlertGroupFieldsCacheSerializerMixin.get_or_set_web_template_field(
@ -189,21 +190,12 @@ class AlertGroupListSerializer(EagerLoadingMixin, AlertGroupFieldsCacheSerialize
AlertGroupWebRenderer,
)
def get_render_for_classic_markdown(self, obj):
"""Deprecated. TODO: remove"""
if not obj.last_alert:
return {}
return AlertGroupFieldsCacheSerializerMixin.get_or_set_web_template_field(
obj,
obj.last_alert,
AlertGroupFieldsCacheSerializerMixin.RENDER_FOR_CLASSIC_MARKDOWN_FIELD_NAME,
AlertGroupClassicMarkdownRenderer,
)
@extend_schema_field(UserShortSerializer(many=True))
def get_related_users(self, obj):
users_ids = set()
users = []
def get_related_users(self, obj: "AlertGroup"):
from apps.user_management.models import User
users_ids: typing.Set[str] = set()
users: typing.List[User] = []
# add resolved and acknowledged by_user explicitly because logs are already prefetched
# when def acknowledge/resolve are called in view.
@ -241,7 +233,7 @@ class AlertGroupSerializer(AlertGroupListSerializer):
"paged_users",
]
def get_last_alert_at(self, obj) -> datetime.datetime:
def get_last_alert_at(self, obj: "AlertGroup") -> datetime.datetime:
last_alert = obj.alerts.last()
if not last_alert:
@ -250,7 +242,7 @@ class AlertGroupSerializer(AlertGroupListSerializer):
return last_alert.created_at
@extend_schema_field(AlertSerializer(many=True))
def get_limited_alerts(self, obj):
def get_limited_alerts(self, obj: "AlertGroup"):
"""
Overriding default alerts because there are alert_groups with thousands of them.
It's just too slow, we need to cut here.
@ -258,6 +250,5 @@ class AlertGroupSerializer(AlertGroupListSerializer):
alerts = obj.alerts.order_by("-pk")[:100]
return AlertSerializer(alerts, many=True).data
@extend_schema_field(PagedUserSerializer(many=True))
def get_paged_users(self, obj):
def get_paged_users(self, obj: "AlertGroup") -> typing.List[PagedUser]:
return obj.get_paged_users()

View file

@ -1,3 +1,4 @@
import typing
from collections import OrderedDict
from django.conf import settings
@ -39,7 +40,9 @@ class IntegrationAlertGroupLabelsSerializer(serializers.Serializer):
inheritable = serializers.DictField(child=serializers.BooleanField())
class AlertReceiveChannelSerializer(EagerLoadingMixin, LabelsSerializerMixin, serializers.ModelSerializer):
class AlertReceiveChannelSerializer(
EagerLoadingMixin, LabelsSerializerMixin, serializers.ModelSerializer[AlertReceiveChannel]
):
id = serializers.CharField(read_only=True, source="public_primary_key")
integration_url = serializers.ReadOnlyField()
alert_count = serializers.SerializerMethodField()
@ -163,12 +166,12 @@ class AlertReceiveChannelSerializer(EagerLoadingMixin, LabelsSerializerMixin, se
except AlertReceiveChannel.DuplicateDirectPagingError:
raise BadRequest(detail=AlertReceiveChannel.DuplicateDirectPagingError.DETAIL)
def get_instructions(self, obj):
def get_instructions(self, obj: "AlertReceiveChannel"):
# Deprecated, kept for api-backward compatibility
return ""
# MethodFields are used instead of relevant properties because of properties hit db on each instance in queryset
def get_default_channel_filter(self, obj):
def get_default_channel_filter(self, obj: "AlertReceiveChannel"):
for filter in obj.channel_filters.all():
if filter.is_default:
return filter.public_primary_key
@ -192,29 +195,29 @@ class AlertReceiveChannelSerializer(EagerLoadingMixin, LabelsSerializerMixin, se
else:
raise serializers.ValidationError(detail="Integration with this name already exists")
def get_heartbeat(self, obj):
def get_heartbeat(self, obj: "AlertReceiveChannel"):
try:
heartbeat = obj.integration_heartbeat
except ObjectDoesNotExist:
return None
return IntegrationHeartBeatSerializer(heartbeat).data
def get_allow_delete(self, obj):
def get_allow_delete(self, obj: "AlertReceiveChannel"):
return True
def get_alert_count(self, obj):
def get_alert_count(self, obj: "AlertReceiveChannel"):
return 0
def get_alert_groups_count(self, obj):
def get_alert_groups_count(self, obj: "AlertReceiveChannel"):
return 0
def get_routes_count(self, obj) -> int:
def get_routes_count(self, obj: "AlertReceiveChannel") -> int:
return obj.channel_filters.count()
def get_is_legacy(self, obj) -> bool:
def get_is_legacy(self, obj: "AlertReceiveChannel") -> bool:
return has_legacy_prefix(obj.integration)
def get_connected_escalations_chains_count(self, obj) -> int:
def get_connected_escalations_chains_count(self, obj: "AlertReceiveChannel") -> int:
return (
ChannelFilter.objects.filter(alert_receive_channel=obj, escalation_chain__isnull=False)
.values("escalation_chain")
@ -228,7 +231,7 @@ class AlertReceiveChannelUpdateSerializer(AlertReceiveChannelSerializer):
read_only_fields = [*AlertReceiveChannelSerializer.Meta.read_only_fields, "integration"]
class FastAlertReceiveChannelSerializer(serializers.ModelSerializer):
class FastAlertReceiveChannelSerializer(serializers.ModelSerializer[AlertReceiveChannel]):
id = serializers.CharField(read_only=True, source="public_primary_key")
integration = serializers.CharField(read_only=True)
deleted = serializers.SerializerMethodField()
@ -237,27 +240,30 @@ class FastAlertReceiveChannelSerializer(serializers.ModelSerializer):
model = AlertReceiveChannel
fields = ["id", "integration", "verbal_name", "deleted"]
def get_deleted(self, obj):
def get_deleted(self, obj: "AlertReceiveChannel") -> bool:
return obj.deleted_at is not None
class FilterAlertReceiveChannelSerializer(serializers.ModelSerializer):
value = serializers.SerializerMethodField()
class FilterAlertReceiveChannelSerializer(serializers.ModelSerializer[AlertReceiveChannel]):
# don't use get_value as the method name, otherwise this will override the get_value method on
# serializers.ModelSerializer, which may cause unexpected behavior (+ this violates the "Lisov substition
# principle" which mypy complains about)
value = serializers.SerializerMethodField(method_name="_get_value")
display_name = serializers.SerializerMethodField()
class Meta:
model = AlertReceiveChannel
fields = ["value", "display_name", "integration_url"]
def get_value(self, obj):
def _get_value(self, obj: "AlertReceiveChannel"):
return obj.public_primary_key
def get_display_name(self, obj):
def get_display_name(self, obj: "AlertReceiveChannel"):
display_name = obj.verbal_name or AlertReceiveChannel.INTEGRATION_CHOICES[obj.integration][1]
return display_name
class AlertReceiveChannelTemplatesSerializer(EagerLoadingMixin, serializers.ModelSerializer):
class AlertReceiveChannelTemplatesSerializer(EagerLoadingMixin, serializers.ModelSerializer[AlertReceiveChannel]):
id = serializers.CharField(read_only=True, source="public_primary_key")
payload_example = SerializerMethodField()
@ -273,7 +279,7 @@ class AlertReceiveChannelTemplatesSerializer(EagerLoadingMixin, serializers.Mode
]
extra_kwargs = {"integration": {"required": True}}
def get_payload_example(self, obj):
def get_payload_example(self, obj: "AlertReceiveChannel"):
from apps.alerts.models import AlertGroup
if "alert_group_id" in self.context["request"].query_params:
@ -290,7 +296,7 @@ class AlertReceiveChannelTemplatesSerializer(EagerLoadingMixin, serializers.Mode
except AttributeError:
return None
def get_is_based_on_alertmanager(self, obj):
def get_is_based_on_alertmanager(self, obj: "AlertReceiveChannel"):
return obj.based_on_alertmanager
# Override method to pass field_name directly in set_value to handle None values for WritableSerializerField
@ -366,7 +372,7 @@ class AlertReceiveChannelTemplatesSerializer(EagerLoadingMixin, serializers.Mode
set_value(ret, [field_name], value)
return errors
def to_representation(self, obj):
def to_representation(self, obj: "AlertReceiveChannel"):
ret = super().to_representation(obj)
core_templates = self._get_core_templates(obj)
@ -378,7 +384,7 @@ class AlertReceiveChannelTemplatesSerializer(EagerLoadingMixin, serializers.Mode
return ret
def _get_messaging_backend_templates(self, obj):
def _get_messaging_backend_templates(self, obj: "AlertReceiveChannel"):
"""Return additional messaging backend templates if any."""
templates = {}
for backend_id, backend in get_messaging_backends():
@ -397,7 +403,7 @@ class AlertReceiveChannelTemplatesSerializer(EagerLoadingMixin, serializers.Mode
templates[f"{field_name}_is_default"] = is_default
return templates
def _get_core_templates(self, obj):
def _get_core_templates(self, obj: "AlertReceiveChannel"):
core_templates = {}
for template_name in self.core_templates_names:
@ -410,10 +416,9 @@ class AlertReceiveChannelTemplatesSerializer(EagerLoadingMixin, serializers.Mode
return core_templates
@property
def core_templates_names(self):
def core_templates_names(self) -> typing.List[str]:
"""
core_templates_names returns names of templates introduced before messaging backends system with respect to
enabled integrations.
returns names of templates introduced before messaging backends system with respect to enabled integrations.
"""
core_templates = [
"web_title_template",
@ -427,21 +432,16 @@ class AlertReceiveChannelTemplatesSerializer(EagerLoadingMixin, serializers.Mode
"acknowledge_condition_template",
]
slack_integration_required_templates = [
"slack_title_template",
"slack_message_template",
"slack_image_url_template",
]
telegram_integration_required_templates = [
"telegram_title_template",
"telegram_message_template",
"telegram_image_url_template",
]
apppend = []
if settings.FEATURE_SLACK_INTEGRATION_ENABLED:
core_templates += slack_integration_required_templates
core_templates += [
"slack_title_template",
"slack_message_template",
"slack_image_url_template",
]
if settings.FEATURE_TELEGRAM_INTEGRATION_ENABLED:
core_templates += telegram_integration_required_templates
return apppend + core_templates
core_templates += [
"telegram_title_template",
"telegram_message_template",
"telegram_image_url_template",
]
return core_templates

View file

@ -5,8 +5,7 @@ from django.core.cache import cache
class AlertsFieldCacheBusterMixin:
RENDER_FOR_WEB_FIELD_NAME = "render_for_web"
RENDER_FOR_CLASSIC_MARKDOWN_FIELD_NAME = "render_for_classic_markdown"
ALL_FIELD_NAMES = [RENDER_FOR_WEB_FIELD_NAME, RENDER_FOR_CLASSIC_MARKDOWN_FIELD_NAME]
ALL_FIELD_NAMES = [RENDER_FOR_WEB_FIELD_NAME]
@classmethod
def calculate_cache_key(cls, field_name: str, obj: typing.Any) -> str:

View file

@ -0,0 +1,17 @@
import pytest
import yaml
from django.urls import reverse
from rest_framework import status
from rest_framework.test import APIClient
@pytest.mark.django_db
def test_fetching_the_openapi_schema_works(settings, reload_urls):
settings.DRF_SPECTACULAR_ENABLED = True
reload_urls()
client = APIClient()
response = client.get(reverse("schema"))
assert response.status_code == status.HTTP_200_OK
assert yaml.safe_load(response.content)["info"]["title"] == settings.SPECTACULAR_SETTINGS["TITLE"]

View file

@ -5,7 +5,7 @@ from django.db.models import Count, Max, Q
from django.utils import timezone
from django_filters import rest_framework as filters
from django_filters.widgets import RangeWidget
from drf_spectacular.utils import extend_schema, inline_serializer
from drf_spectacular.utils import extend_schema, extend_schema_view, inline_serializer
from rest_framework import mixins, serializers, status, viewsets
from rest_framework.decorators import action
from rest_framework.exceptions import NotFound
@ -267,6 +267,12 @@ class AlertGroupTeamFilteringMixin(TeamFilteringMixin):
return Response(data={"error_code": "wrong_team"}, status=status.HTTP_403_FORBIDDEN)
@extend_schema_view(
list=extend_schema(description="Fetch a list of alert groups"),
retrieve=extend_schema(description="Fetch a single alert group"),
destroy=extend_schema(description="Delete an alert group"),
preview_template=extend_schema(description="Preview a template for an alert group"),
)
class AlertGroupView(
PreviewTemplateMixin,
AlertGroupTeamFilteringMixin,
@ -290,8 +296,6 @@ class AlertGroupView(
"filters": [RBACPermission.Permissions.ALERT_GROUPS_READ],
"silence_options": [RBACPermission.Permissions.ALERT_GROUPS_READ],
"bulk_action_options": [RBACPermission.Permissions.ALERT_GROUPS_READ],
"create": [RBACPermission.Permissions.ALERT_GROUPS_WRITE],
"update": [RBACPermission.Permissions.ALERT_GROUPS_WRITE],
"destroy": [RBACPermission.Permissions.ALERT_GROUPS_WRITE],
"acknowledge": [RBACPermission.Permissions.ALERT_GROUPS_WRITE],
"unacknowledge": [RBACPermission.Permissions.ALERT_GROUPS_WRITE],
@ -478,7 +482,9 @@ class AlertGroupView(
@extend_schema(responses=inline_serializer(name="AlertGroupStats", fields={"count": serializers.IntegerField()}))
@action(detail=False)
def stats(self, *args, **kwargs):
"""Return number of alert groups capped at 100001"""
"""
Return number of alert groups capped at 100001
"""
MAX_COUNT = 100001
alert_groups = self.filter_queryset(self.get_queryset())[:MAX_COUNT]
count = alert_groups.count()
@ -491,6 +497,9 @@ class AlertGroupView(
@action(methods=["post"], detail=True)
def acknowledge(self, request, pk):
"""
Acknowledge an alert group
"""
alert_group = self.get_object()
if alert_group.is_maintenance_incident:
raise BadRequest(detail="Can't acknowledge maintenance alert group")
@ -502,6 +511,9 @@ class AlertGroupView(
@action(methods=["post"], detail=True)
def unacknowledge(self, request, pk):
"""
Unacknowledge an alert group
"""
alert_group = self.get_object()
if alert_group.is_maintenance_incident:
raise BadRequest(detail="Can't unacknowledge maintenance alert group")
@ -521,6 +533,9 @@ class AlertGroupView(
@action(methods=["post"], detail=True)
def resolve(self, request, pk):
"""
Resolve an alert group
"""
alert_group = self.get_object()
organization = self.request.user.organization
@ -563,6 +578,9 @@ class AlertGroupView(
@action(methods=["post"], detail=True)
def unresolve(self, request, pk):
"""
Unresolve an alert group
"""
alert_group = self.get_object()
if alert_group.is_maintenance_incident:
raise BadRequest(detail="Can't unresolve maintenance alert group")
@ -603,6 +621,9 @@ class AlertGroupView(
@action(methods=["post"], detail=True)
def unattach(self, request, pk=None):
"""
Unattach an alert group that is already attached to another alert group
"""
alert_group = self.get_object()
if alert_group.is_maintenance_incident:
raise BadRequest(detail="Can't unattach maintenance alert group")
@ -614,6 +635,9 @@ class AlertGroupView(
@action(methods=["post"], detail=True)
def silence(self, request, pk=None):
"""
Silence an alert group for a specified delay
"""
alert_group = self.get_object()
delay = request.data.get("delay")
@ -635,6 +659,9 @@ class AlertGroupView(
)
@action(methods=["get"], detail=False)
def silence_options(self, request):
"""
Retrieve a list of valid silence options
"""
data = [
{"value": value, "display_name": display_name} for value, display_name in AlertGroup.SILENCE_DELAY_OPTIONS
]
@ -642,6 +669,9 @@ class AlertGroupView(
@action(methods=["post"], detail=True)
def unsilence(self, request, pk=None):
"""
Unsilence a silenced alert group
"""
alert_group = self.get_object()
if not alert_group.silenced:
@ -662,6 +692,9 @@ class AlertGroupView(
@action(methods=["post"], detail=True)
def unpage_user(self, request, pk=None):
"""
Remove a user that was directly paged for the alert group
"""
organization = request.auth.organization
from_user = request.user
alert_group = self.get_object()
@ -681,6 +714,9 @@ class AlertGroupView(
@action(methods=["get"], detail=False)
def filters(self, request):
"""
Retrieve a list of valid filter options that can be used to filter alert groups
"""
filter_name = request.query_params.get("search", None)
api_root = "/api/internal/v1/"
@ -780,6 +816,9 @@ class AlertGroupView(
@action(methods=["post"], detail=False)
def bulk_action(self, request):
"""
Perform a bulk action on a list of alert groups
"""
alert_group_public_pks = self.request.data.get("alert_group_pks", [])
action_with_incidents = self.request.data.get("action", None)
delay = self.request.data.get("delay")
@ -807,6 +846,9 @@ class AlertGroupView(
@action(methods=["get"], detail=False)
def bulk_action_options(self, request):
"""
Retrieve a list of valid bulk action options
"""
return Response(
[{"value": action_name, "display_name": action_name} for action_name in AlertGroup.BULK_ACTIONS]
)

View file

@ -57,5 +57,5 @@ urllib3==1.26.18
prometheus_client==0.16.0
lxml==4.9.2
babel==2.12.1
drf-spectacular==0.26.2
drf-spectacular==0.26.5
grpcio==1.57.0