diff --git a/engine/apps/integrations/tests/test_integration_backsync.py b/engine/apps/integrations/tests/test_integration_backsync.py new file mode 100644 index 00000000..a30b8702 --- /dev/null +++ b/engine/apps/integrations/tests/test_integration_backsync.py @@ -0,0 +1,62 @@ +from unittest.mock import PropertyMock, patch + +import pytest +from django.core.cache import cache +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APIClient + + +@pytest.mark.django_db +def test_integration_backsync_endpoint( + make_organization, + make_alert_receive_channel, + make_token_for_integration, +): + organization = make_organization() + alert_receive_channel = make_alert_receive_channel(organization=organization) + _, token = make_token_for_integration(alert_receive_channel, organization) + + client = APIClient() + url = reverse("integrations:integration_backsync") + + response = client.post(url, format="json", HTTP_AUTHORIZATION=token) + assert response.status_code == status.HTTP_200_OK + + +@pytest.mark.django_db +def test_integration_backsync_endpoint_wrong_token( + make_organization, + make_alert_receive_channel, +): + client = APIClient() + url = reverse("integrations:integration_backsync") + response = client.post(url, format="json", HTTP_AUTHORIZATION="randomtesttoken") + assert response.status_code == status.HTTP_403_FORBIDDEN + + +@pytest.mark.django_db +def test_integration_backsync_endpoint_throttling( + make_organization, + make_alert_receive_channel, + make_token_for_integration, +): + organization = make_organization() + alert_receive_channel = make_alert_receive_channel(organization=organization) + _, token = make_token_for_integration(alert_receive_channel, organization) + + client = APIClient() + url = reverse("integrations:integration_backsync") + cache.clear() + + with patch( + "apps.integrations.throttlers.integration_backsync_throttler.BacksyncRateThrottle.rate", + new_callable=PropertyMock, + ) as mocked_rate: + mocked_rate.return_value = "1/m" + + response = client.post(url, format="json", HTTP_AUTHORIZATION=token) + assert response.status_code == status.HTTP_200_OK + + response = client.post(url, format="json", HTTP_AUTHORIZATION=f"{token}") + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS diff --git a/engine/apps/integrations/throttlers/__init__.py b/engine/apps/integrations/throttlers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/engine/apps/integrations/throttlers/integration_backsync_throttler.py b/engine/apps/integrations/throttlers/integration_backsync_throttler.py new file mode 100644 index 00000000..dacfd258 --- /dev/null +++ b/engine/apps/integrations/throttlers/integration_backsync_throttler.py @@ -0,0 +1,14 @@ +from rest_framework.throttling import SimpleRateThrottle + + +class BacksyncRateThrottle(SimpleRateThrottle): + """ + Integration backsync rate limit + """ + + scope = "backsync" + rate = "300/m" + + def get_cache_key(self, request, view): + ident = request.auth.alert_receive_channel.pk + return self.cache_format % {"scope": self.scope, "ident": ident} diff --git a/engine/apps/integrations/urls.py b/engine/apps/integrations/urls.py index 9186f98c..f0170510 100644 --- a/engine/apps/integrations/urls.py +++ b/engine/apps/integrations/urls.py @@ -11,6 +11,7 @@ from .views import ( AmazonSNS, GrafanaAlertingAPIView, GrafanaAPIView, + IntegrationBacksyncAPIView, IntegrationHeartBeatAPIView, UniversalAPIView, ) @@ -32,6 +33,8 @@ urlpatterns = [ path("alertmanager//", AlertManagerAPIView.as_view(), name="alertmanager"), path("amazon_sns//", AmazonSNS.as_view(), name="amazon_sns"), path("//", UniversalAPIView.as_view(), name="universal"), + # integration backsync + path("backsync/", IntegrationBacksyncAPIView.as_view(), name="integration_backsync"), ] if settings.FEATURE_INBOUND_EMAIL_ENABLED: diff --git a/engine/apps/integrations/views.py b/engine/apps/integrations/views.py index 003734f2..c51e8783 100644 --- a/engine/apps/integrations/views.py +++ b/engine/apps/integrations/views.py @@ -8,10 +8,12 @@ from django.utils import timezone from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt from django_sns_view.views import SNSEndpoint +from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.views import APIView from apps.alerts.models import AlertReceiveChannel +from apps.auth_token.auth import IntegrationBacksyncAuthentication from apps.heartbeat.tasks import process_heartbeat_task from apps.integrations.legacy_prefix import has_legacy_prefix from apps.integrations.mixins import ( @@ -22,6 +24,7 @@ from apps.integrations.mixins import ( is_ratelimit_ignored, ) from apps.integrations.tasks import create_alert, create_alertmanager_alerts +from apps.integrations.throttlers.integration_backsync_throttler import BacksyncRateThrottle from apps.user_management.exceptions import OrganizationDeletedException, OrganizationMovedException from common.api_helpers.utils import create_engine_url @@ -353,3 +356,16 @@ class IntegrationHeartBeatAPIView(AlertChannelDefiningMixin, IntegrationHeartBea process_heartbeat_task.apply_async( (alert_receive_channel.pk,), ) + + +class IntegrationBacksyncAPIView(APIView): + authentication_classes = (IntegrationBacksyncAuthentication,) + permission_classes = (IsAuthenticated,) + throttle_classes = (BacksyncRateThrottle,) + + def post(self, request): + alert_receive_channel = request.auth.alert_receive_channel + integration_backsync_func = getattr(alert_receive_channel.config, "integration_backsync", None) + if integration_backsync_func: + integration_backsync_func(alert_receive_channel, request.data) + return Response(status=200)