oncall-engine/engine/apps/integrations/tests/test_ratelimit.py
Vadim Stepanov b7e2dc14f8
Fix ratelimit bug (#4108)
# What this PR does

Fixes a bug in the ratelimit logic where 429s returned by an integration-specific ratelimit
were still counted towards the organization-wide ratelimit.

## Which issue(s) this PR closes

Related to https://github.com/grafana/support-escalations/issues/9579

## Checklist

- [x] Unit, integration, and e2e (if applicable) tests updated
- [x] Documentation added (or `pr:no public docs` PR label added if not
required)
- [x] Added the relevant release notes label (see labels prefixed w/
`release:`). These labels dictate how your PR will
    show up in the autogenerated release notes.
2024-03-26 17:20:05 +00:00

152 lines
5.2 KiB
Python

from unittest import mock
import pytest
from django.core.cache import cache
from django.test import Client
from django.urls import reverse
from rest_framework import status
from apps.alerts.models import AlertReceiveChannel
from apps.integrations.mixins import IntegrationRateLimitMixin
from apps.integrations.mixins.ratelimit_mixin import RATELIMIT_INTEGRATION
@pytest.fixture(autouse=True)
def clear_cache():
    """Reset the Django cache around every test in this module.

    Ratelimit keys are stored in the cache, so it is cleared both before the test
    runs and after it finishes (teardown after ``yield``). This keeps each test
    independent of counters left behind by any other test.
    """
    cache.clear()
    yield
    # Teardown: don't leak ratelimit counters into tests outside this module.
    cache.clear()
@mock.patch("ratelimit.utils._split_rate", return_value=(1, 60))
@mock.patch("apps.integrations.tasks.create_alert.apply_async", return_value=None)
@pytest.mark.django_db
def test_ratelimit_alerts_per_integration(
    mocked_task,
    mocked_rate,
    make_organization,
    make_alert_receive_channel,
):
    """A second alert sent to the same integration inside the window is rejected.

    Every ratelimit is mocked to 1 request per 60 seconds. Only the first POST is
    accepted, and only that one enqueues ``create_alert``.
    """
    alert_payload = {"message": "This is the test alert from amixr"}
    webhook = make_alert_receive_channel(make_organization(), integration=AlertReceiveChannel.INTEGRATION_WEBHOOK)
    endpoint = reverse(
        "integrations:universal",
        kwargs={"integration_type": AlertReceiveChannel.INTEGRATION_WEBHOOK, "alert_channel_key": webhook.token},
    )
    client = Client()

    accepted = client.post(endpoint, data=alert_payload)
    assert accepted.status_code == status.HTTP_200_OK

    limited = client.post(endpoint, data=alert_payload)
    assert limited.status_code == status.HTTP_429_TOO_MANY_REQUESTS

    assert mocked_task.call_count == 1
@mock.patch("ratelimit.utils._split_rate", return_value=(1, 60))
@mock.patch("apps.integrations.tasks.create_alert.apply_async", return_value=None)
@pytest.mark.django_db
def test_ratelimit_alerts_per_team(
    mocked_task,
    mocked_rate,
    make_organization,
    make_alert_receive_channel,
):
    """Two integrations of one organization share the organization-wide ratelimit.

    Every ratelimit is mocked to 1 request per 60 seconds. The second integration
    has not received any request yet, so its 429 comes from the shared
    organization-wide limit rather than from its own integration limit, despite
    the "team" in this test's name.
    """
    organization = make_organization()

    def new_webhook_url():
        # Create a fresh webhook integration in the shared organization and return its endpoint.
        channel = make_alert_receive_channel(organization, integration=AlertReceiveChannel.INTEGRATION_WEBHOOK)
        return reverse(
            "integrations:universal",
            kwargs={"integration_type": AlertReceiveChannel.INTEGRATION_WEBHOOK, "alert_channel_key": channel.token},
        )

    first_url = new_webhook_url()
    second_url = new_webhook_url()
    client = Client()

    accepted = client.post(first_url, data={"message": "This is the test alert from amixr"})
    assert accepted.status_code == status.HTTP_200_OK

    limited = client.post(second_url, data={"message": "This is the test alert from amixr"})
    assert limited.status_code == status.HTTP_429_TOO_MANY_REQUESTS

    assert mocked_task.call_count == 1
@mock.patch("ratelimit.utils._split_rate", return_value=(1, 60))
@mock.patch("apps.heartbeat.tasks.process_heartbeat_task.apply_async", return_value=None)
@pytest.mark.django_db
def test_ratelimit_integration_heartbeats(
    mocked_task,
    mocked_rate,
    make_organization,
    make_alert_receive_channel,
):
    """The heartbeat endpoint is ratelimited for both POST and GET.

    Every ratelimit is mocked to 1 request per 60 seconds. The first POST is
    accepted, and every later request (POST or GET) gets a 429. Rejected requests
    must not enqueue ``process_heartbeat_task``.
    """
    organization = make_organization()
    integration = make_alert_receive_channel(organization, integration=AlertReceiveChannel.INTEGRATION_WEBHOOK)
    url = reverse("integrations:webhook_heartbeat", kwargs={"alert_channel_key": integration.token})
    c = Client()

    response = c.post(url)
    assert response.status_code == 200

    response = c.post(url)
    assert response.status_code == 429

    # GET shares the same limit as POST.
    response = c.get(url)
    assert response.status_code == 429

    # Only the single accepted request should have dispatched the heartbeat task.
    assert mocked_task.call_count == 1
# mocking rate limits to 1/m per integration and 3/m per organization
@mock.patch("ratelimit.utils._split_rate", new=lambda rate: (1, 60) if rate == RATELIMIT_INTEGRATION else (3, 60))
@mock.patch("apps.integrations.tasks.create_alert.apply_async", return_value=None)
@pytest.mark.django_db
def test_ratelimit_integration_and_organization(
    mocked_task,
    make_organization,
    make_alert_receive_channel,
):
    """Integration-specific 429s must not count towards the organization-wide ratelimit.

    Limits are mocked to 1/m per integration and 3/m per organization. The first
    two integrations each get one accepted request and one integration-limited
    request. The third integration is still accepted, which is only possible if
    those 429s were not counted. The fourth integration is then rejected by the
    organization-wide limit.
    """
    organization = make_organization()
    integrations = [
        make_alert_receive_channel(organization, integration=AlertReceiveChannel.INTEGRATION_WEBHOOK) for _ in range(4)
    ]
    urls = [
        reverse(
            "integrations:universal",
            kwargs={
                "integration_type": AlertReceiveChannel.INTEGRATION_WEBHOOK,
                "alert_channel_key": integration.token,
            },
        )
        for integration in integrations
    ]
    client = Client()

    # Exhaust the 1/m limit of the first two integrations; the second request to
    # each is rejected with the integration-specific message.
    for integration, url in zip(integrations[:2], urls[:2]):
        response = client.post(url)
        assert response.status_code == status.HTTP_200_OK
        response = client.post(url)
        assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
        assert response.content.decode() == IntegrationRateLimitMixin.TEXT_INTEGRATION.format(
            integration=integration.verbal_name
        )

    # Third accepted request overall: still within the 3/m organization limit.
    response = client.post(urls[2])
    assert response.status_code == status.HTTP_200_OK

    # The fourth integration's own limit is untouched, but the organization limit is now exhausted.
    response = client.post(urls[3])
    assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
    assert response.content.decode() == IntegrationRateLimitMixin.TEXT_WORKSPACE

    # Only the three accepted requests should have enqueued alert creation.
    assert mocked_task.call_count == 3