Add custom ratelimits per org (#5004)
# What this PR does This PR refactors Throttling for public API and integrations API and allows to specify organization ratelimits. ## Which issue(s) this PR closes Related to [issue link here] <!-- *Note*: If you want the issue to be auto-closed once the PR is merged, change "Related to" to "Closes" in the line above. If you have more than one GitHub issue that this PR closes, be sure to preface each issue link with a [closing keyword](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/using-keywords-in-issues-and-pull-requests#linking-a-pull-request-to-an-issue). This ensures that the issue(s) are auto-closed once the PR has been merged. --> ## Checklist - [x] Unit, integration, and e2e (if applicable) tests updated - [ ] Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes.
This commit is contained in:
parent
dd6d2ab161
commit
c718863bd8
15 changed files with 267 additions and 143 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from unittest import mock
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from django.core.cache import cache
|
||||
|
|
@ -1775,17 +1775,10 @@ def test_invalid_working_hours(
|
|||
|
||||
@patch("apps.phone_notifications.phone_backend.PhoneBackend.send_verification_sms", return_value=Mock())
|
||||
@patch("apps.phone_notifications.phone_backend.PhoneBackend.verify_phone_number", return_value=True)
|
||||
@patch(
|
||||
"apps.api.throttlers.GetPhoneVerificationCodeThrottlerPerUser.get_throttle_limits",
|
||||
return_value=(1, 10 * 60),
|
||||
)
|
||||
@patch("apps.api.throttlers.VerifyPhoneNumberThrottlerPerUser.get_throttle_limits", return_value=(1, 10 * 60))
|
||||
@pytest.mark.django_db
|
||||
def test_phone_number_verification_flow_ratelimit_per_user(
|
||||
mock_verification_start,
|
||||
mocked_verification_check,
|
||||
mocked_get_phone_verification_code_get_throttle_limits,
|
||||
mocked_get_phone_verify_phone_number_limits,
|
||||
make_organization_and_user_with_plugin_token,
|
||||
make_user_auth_headers,
|
||||
):
|
||||
|
|
@ -1794,40 +1787,44 @@ def test_phone_number_verification_flow_ratelimit_per_user(
|
|||
client = APIClient()
|
||||
url = reverse("api-internal:user-get-verification-code", kwargs={"pk": user.public_primary_key})
|
||||
|
||||
# first get_verification_code request is succesfull
|
||||
response = client.get(url, format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
with patch(
|
||||
"apps.api.throttlers.GetPhoneVerificationCodeThrottlerPerUser.rate",
|
||||
new_callable=PropertyMock,
|
||||
) as mocked_rate:
|
||||
mocked_rate.return_value = "1/10m"
|
||||
# first get_verification_code request is succesfull
|
||||
response = client.get(url, format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# second get_verification_code request is ratelimited
|
||||
response = client.get(url, format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
# second get_verification_code request is ratelimited
|
||||
response = client.get(url, format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
|
||||
url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key})
|
||||
|
||||
# first verify_number request is succesfull, because it uses different ratelimit scope
|
||||
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
with patch(
|
||||
"apps.api.throttlers.VerifyPhoneNumberThrottlerPerUser.rate",
|
||||
new_callable=PropertyMock,
|
||||
) as mocked_rate:
|
||||
mocked_rate.return_value = "1/10m"
|
||||
|
||||
url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key})
|
||||
# first verify_number request is succesfull, because it uses different ratelimit scope
|
||||
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# second verify_number request is succesfull, because it ratelimited
|
||||
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key})
|
||||
|
||||
# second verify_number request is succesfull, because it ratelimited
|
||||
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
|
||||
|
||||
@patch("apps.phone_notifications.phone_backend.PhoneBackend.send_verification_sms", return_value=Mock())
|
||||
@patch("apps.phone_notifications.phone_backend.PhoneBackend.verify_phone_number", return_value=True)
|
||||
@patch(
|
||||
"apps.api.throttlers.GetPhoneVerificationCodeThrottlerPerOrg.get_throttle_limits",
|
||||
return_value=(1, 10 * 60),
|
||||
)
|
||||
@patch("apps.api.throttlers.VerifyPhoneNumberThrottlerPerOrg.get_throttle_limits", return_value=(1, 10 * 60))
|
||||
@pytest.mark.django_db
|
||||
def test_phone_number_verification_flow_ratelimit_per_org(
|
||||
mock_verification_start,
|
||||
mocked_verification_check,
|
||||
mocked_get_phone_verification_code_get_throttle_limits,
|
||||
mocked_get_phone_verify_phone_number_limits,
|
||||
make_organization_and_user_with_plugin_token,
|
||||
make_user_auth_headers,
|
||||
make_user_for_organization,
|
||||
|
|
@ -1841,21 +1838,33 @@ def test_phone_number_verification_flow_ratelimit_per_org(
|
|||
|
||||
client = APIClient()
|
||||
|
||||
url = reverse("api-internal:user-get-verification-code", kwargs={"pk": user.public_primary_key})
|
||||
response = client.get(url, format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
with patch(
|
||||
"apps.api.throttlers.GetPhoneVerificationCodeThrottlerPerOrg.rate",
|
||||
new_callable=PropertyMock,
|
||||
) as mocked_rate:
|
||||
mocked_rate.return_value = "1/10m"
|
||||
|
||||
url = reverse("api-internal:user-get-verification-code", kwargs={"pk": second_user.public_primary_key})
|
||||
response = client.get(url, format="json", **make_user_auth_headers(second_user, token))
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
url = reverse("api-internal:user-get-verification-code", kwargs={"pk": user.public_primary_key})
|
||||
response = client.get(url, format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key})
|
||||
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
url = reverse("api-internal:user-get-verification-code", kwargs={"pk": second_user.public_primary_key})
|
||||
response = client.get(url, format="json", **make_user_auth_headers(second_user, token))
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
|
||||
url = reverse("api-internal:user-verify-number", kwargs={"pk": second_user.public_primary_key})
|
||||
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(second_user, token))
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
with patch(
|
||||
"apps.api.throttlers.VerifyPhoneNumberThrottlerPerOrg.rate",
|
||||
new_callable=PropertyMock,
|
||||
) as mocked_rate:
|
||||
mocked_rate.return_value = "1/10m"
|
||||
|
||||
url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key})
|
||||
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
url = reverse("api-internal:user-verify-number", kwargs={"pk": second_user.public_primary_key})
|
||||
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(second_user, token))
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
|
||||
|
||||
@patch("apps.phone_notifications.phone_backend.PhoneBackend.send_verification_sms", return_value=Mock())
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from rest_framework.throttling import UserRateThrottle
|
||||
from common.api_helpers.custom_rate_scoped_throttler import CustomRateUserThrottler
|
||||
|
||||
|
||||
class DemoAlertThrottler(UserRateThrottle):
|
||||
class DemoAlertThrottler(CustomRateUserThrottler):
|
||||
scope = "send_demo_alert"
|
||||
rate = "30/m"
|
||||
|
|
|
|||
|
|
@ -1,49 +1,21 @@
|
|||
from common.api_helpers.custom_rate_scoped_throttler import CustomRateScopedThrottler
|
||||
from common.api_helpers.custom_rate_scoped_throttler import CustomRateOrganizationThrottler, CustomRateUserThrottler
|
||||
|
||||
|
||||
class GetPhoneVerificationCodeThrottlerPerUser(CustomRateScopedThrottler):
|
||||
def get_scope(self):
|
||||
return "get_phone_verification_code_per_user"
|
||||
|
||||
def get_throttle_limits(self):
|
||||
return 5, 10 * 60
|
||||
class GetPhoneVerificationCodeThrottlerPerUser(CustomRateUserThrottler):
|
||||
rate = "5/10m"
|
||||
scope = "get_phone_verification_code_per_user"
|
||||
|
||||
|
||||
class VerifyPhoneNumberThrottlerPerUser(CustomRateScopedThrottler):
|
||||
def get_scope(self):
|
||||
return "verify_phone_number_per_user"
|
||||
|
||||
def get_throttle_limits(self):
|
||||
return 50, 10 * 60
|
||||
class VerifyPhoneNumberThrottlerPerUser(CustomRateUserThrottler):
|
||||
rate = "50/10m"
|
||||
scope = "verify_phone_number_per_user"
|
||||
|
||||
|
||||
class GetPhoneVerificationCodeThrottlerPerOrg(CustomRateScopedThrottler):
|
||||
def get_scope(self):
|
||||
return "get_phone_verification_code_per_org"
|
||||
|
||||
def get_throttle_limits(self):
|
||||
return 50, 10 * 60
|
||||
|
||||
def get_cache_key(self, request, view):
|
||||
if request.user.is_authenticated:
|
||||
ident = request.user.organization.pk
|
||||
else:
|
||||
ident = self.get_ident(request)
|
||||
|
||||
return self.cache_format % {"scope": self.scope, "ident": ident}
|
||||
class GetPhoneVerificationCodeThrottlerPerOrg(CustomRateOrganizationThrottler):
|
||||
rate = "50/10m"
|
||||
scope = "get_phone_verification_code_per_org"
|
||||
|
||||
|
||||
class VerifyPhoneNumberThrottlerPerOrg(CustomRateScopedThrottler):
|
||||
def get_scope(self):
|
||||
return "verify_phone_number_per_org"
|
||||
|
||||
def get_throttle_limits(self):
|
||||
return 50, 10 * 60
|
||||
|
||||
def get_cache_key(self, request, view):
|
||||
if request.user.is_authenticated:
|
||||
ident = request.user.organization.pk
|
||||
else:
|
||||
ident = self.get_ident(request)
|
||||
|
||||
return self.cache_format % {"scope": self.scope, "ident": ident}
|
||||
class VerifyPhoneNumberThrottlerPerOrg(CustomRateOrganizationThrottler):
|
||||
rate = "50/10m"
|
||||
scope = "verify_phone_number_per_org"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from rest_framework.throttling import UserRateThrottle
|
||||
from common.api_helpers.custom_rate_scoped_throttler import CustomRateUserThrottler
|
||||
|
||||
|
||||
class TestCallThrottler(UserRateThrottle):
|
||||
class TestCallThrottler(CustomRateUserThrottler):
|
||||
"""
|
||||
set a __test__ = False attribute in classes that pytest should ignore otherwise we end up getting the following:
|
||||
PytestCollectionWarning: cannot collect test class 'TestCallThrottler' because it has a __init__ constructor
|
||||
|
|
@ -13,7 +13,7 @@ class TestCallThrottler(UserRateThrottle):
|
|||
rate = "5/m"
|
||||
|
||||
|
||||
class TestPushThrottler(UserRateThrottle):
|
||||
class TestPushThrottler(CustomRateUserThrottler):
|
||||
"""
|
||||
set a __test__ = False attribute in classes that pytest should ignore otherwise we end up getting the following:
|
||||
PytestCollectionWarning: cannot collect test class 'TestPushThrottler' because it has a __init__ constructor
|
||||
|
|
|
|||
|
|
@ -3,5 +3,6 @@ from .browsable_instruction_mixin import BrowsableInstructionMixin # noqa: F401
|
|||
from .ratelimit_mixin import ( # noqa: F401
|
||||
IntegrationHeartBeatRateLimitMixin,
|
||||
IntegrationRateLimitMixin,
|
||||
RateLimitMixin,
|
||||
is_ratelimit_ignored,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import logging
|
|||
from abc import ABC, abstractmethod
|
||||
from functools import wraps
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.cache import cache
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.views import View
|
||||
|
|
@ -16,6 +17,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
RATELIMIT_INTEGRATION = "300/5m"
|
||||
RATELIMIT_TEAM = "900/5m"
|
||||
RATELIMIT_INTEGRATION_GROUP_NAME = "integration"
|
||||
RATELIMIT_TEAM_GROUP_NAME = "team"
|
||||
RATELIMIT_REASON_INTEGRATION = "channel"
|
||||
RATELIMIT_REASON_TEAM = "team"
|
||||
INTEGRATION_TOKEN_TO_IGNORE_KEY = "integration_tokens_to_ignore_ratelimit"
|
||||
|
|
@ -30,13 +33,30 @@ def get_rate_limit_per_channel_key(_, request):
|
|||
return str(request.alert_receive_channel.pk)
|
||||
|
||||
|
||||
def get_rate_limit_per_team_key(_, request):
|
||||
def get_rate_limit_per_organization_key(_, request):
|
||||
"""
|
||||
Rate limiting based on AlertReceiveChannel's team PK
|
||||
"""
|
||||
return str(request.alert_receive_channel.organization_id)
|
||||
|
||||
|
||||
def get_rate_limit(group, request):
|
||||
custom_ratelimits = settings.CUSTOM_RATELIMITS
|
||||
|
||||
organization_id = str(request.alert_receive_channel.organization_id)
|
||||
|
||||
if group == RATELIMIT_INTEGRATION_GROUP_NAME:
|
||||
if organization_id in custom_ratelimits:
|
||||
return custom_ratelimits[organization_id]["integration"]
|
||||
return RATELIMIT_INTEGRATION
|
||||
elif group == RATELIMIT_TEAM_GROUP_NAME:
|
||||
if organization_id in custom_ratelimits:
|
||||
return custom_ratelimits[organization_id]["organization"]
|
||||
return RATELIMIT_TEAM
|
||||
else:
|
||||
raise Exception("Unknown group")
|
||||
|
||||
|
||||
def ratelimit(group=None, key=None, rate=None, method=ALL, block=False, reason=None):
|
||||
"""
|
||||
This decorator is an updated version of:
|
||||
|
|
@ -171,7 +191,11 @@ class IntegrationHeartBeatRateLimitMixin(RateLimitMixin, View):
|
|||
block=True, # use block=True so integration rate limit 429s are not counted towards the team rate limit
|
||||
)
|
||||
@ratelimit(
|
||||
key=get_rate_limit_per_team_key, rate=RATELIMIT_TEAM, group="team", reason=RATELIMIT_REASON_TEAM, block=True
|
||||
key=get_rate_limit_per_organization_key,
|
||||
rate=RATELIMIT_TEAM,
|
||||
group="team",
|
||||
reason=RATELIMIT_REASON_TEAM,
|
||||
block=True,
|
||||
)
|
||||
def execute_rate_limit(self, *args, **kwargs):
|
||||
pass
|
||||
|
|
@ -201,13 +225,17 @@ class IntegrationRateLimitMixin(RateLimitMixin, View):
|
|||
|
||||
@ratelimit(
|
||||
key=get_rate_limit_per_channel_key,
|
||||
rate=RATELIMIT_INTEGRATION,
|
||||
group="integration",
|
||||
rate=get_rate_limit,
|
||||
group=RATELIMIT_INTEGRATION_GROUP_NAME,
|
||||
reason=RATELIMIT_REASON_INTEGRATION,
|
||||
block=True, # use block=True so integration rate limit 429s are not counted towards the team rate limit
|
||||
)
|
||||
@ratelimit(
|
||||
key=get_rate_limit_per_team_key, rate=RATELIMIT_TEAM, group="team", reason=RATELIMIT_REASON_TEAM, block=True
|
||||
key=get_rate_limit_per_organization_key,
|
||||
rate=get_rate_limit,
|
||||
group=RATELIMIT_TEAM_GROUP_NAME,
|
||||
reason=RATELIMIT_REASON_TEAM,
|
||||
block=True,
|
||||
)
|
||||
def execute_rate_limit(self, *args, **kwargs):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import json
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from django.core.cache import cache
|
||||
from django.test import Client
|
||||
from django.test import Client, override_settings
|
||||
from django.urls import reverse
|
||||
from rest_framework import status
|
||||
|
||||
|
|
@ -35,9 +36,9 @@ def test_ratelimit_alerts_per_integration(
|
|||
|
||||
c = Client()
|
||||
|
||||
response = c.post(url, data={"message": "This is the test alert from amixr"})
|
||||
response = c.post(url, data={"message": "This is the test alert"})
|
||||
assert response.status_code == 200
|
||||
response = c.post(url, data={"message": "This is the test alert from amixr"})
|
||||
response = c.post(url, data={"message": "This is the test alert"})
|
||||
assert response.status_code == 429
|
||||
|
||||
assert mocked_task.call_count == 1
|
||||
|
|
@ -150,3 +151,82 @@ def test_ratelimit_integration_and_organization(
|
|||
response = client.post(urls[3])
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
assert response.content.decode() == IntegrationRateLimitMixin.TEXT_WORKSPACE
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_custom_throttling(make_organization, make_alert_receive_channel):
|
||||
organization_with_custom_ratelimit = make_organization()
|
||||
integration_with_custom_ratelimit = make_alert_receive_channel(
|
||||
organization_with_custom_ratelimit, integration=AlertReceiveChannel.INTEGRATION_WEBHOOK
|
||||
)
|
||||
url_with_custom_ratelimit = reverse(
|
||||
"integrations:universal",
|
||||
kwargs={
|
||||
"integration_type": AlertReceiveChannel.INTEGRATION_WEBHOOK,
|
||||
"alert_channel_key": integration_with_custom_ratelimit.token,
|
||||
},
|
||||
)
|
||||
|
||||
integration_with_custom_ratelimit_2 = make_alert_receive_channel(
|
||||
organization_with_custom_ratelimit, integration=AlertReceiveChannel.INTEGRATION_WEBHOOK
|
||||
)
|
||||
url_with_custom_ratelimit_2 = reverse(
|
||||
"integrations:universal",
|
||||
kwargs={
|
||||
"integration_type": AlertReceiveChannel.INTEGRATION_WEBHOOK,
|
||||
"alert_channel_key": integration_with_custom_ratelimit_2.token,
|
||||
},
|
||||
)
|
||||
|
||||
organization_with_default_ratelimit = make_organization()
|
||||
integration_with_default_ratelimit = make_alert_receive_channel(
|
||||
organization_with_default_ratelimit, integration=AlertReceiveChannel.INTEGRATION_WEBHOOK
|
||||
)
|
||||
url_with_default_ratelimit = reverse(
|
||||
"integrations:universal",
|
||||
kwargs={
|
||||
"integration_type": AlertReceiveChannel.INTEGRATION_WEBHOOK,
|
||||
"alert_channel_key": integration_with_default_ratelimit.token,
|
||||
},
|
||||
)
|
||||
cache.clear()
|
||||
|
||||
CUSTOM_RATELIMITS_STR = (
|
||||
'{"'
|
||||
+ str(organization_with_custom_ratelimit.pk)
|
||||
+ '": {"integration": "2/m","organization": "3/m","public_api": "1/m"}}'
|
||||
)
|
||||
|
||||
with override_settings(CUSTOM_RATELIMITS=json.loads(CUSTOM_RATELIMITS_STR)):
|
||||
client = Client()
|
||||
|
||||
# Organization without custom ratelimit should use default ratelimit
|
||||
for _ in range(5):
|
||||
response = client.post(url_with_default_ratelimit)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Organization with custom ratelimit will be ratelimited after 2 requests because of integration rate limit
|
||||
response = client.post(url_with_custom_ratelimit)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
response = client.post(url_with_custom_ratelimit)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
response = client.post(url_with_custom_ratelimit)
|
||||
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
assert response.content.decode() == IntegrationRateLimitMixin.TEXT_INTEGRATION.format(
|
||||
integration=integration_with_custom_ratelimit.verbal_name
|
||||
)
|
||||
|
||||
# Organization with custom ratelimit will be ratelimited after 3 requests because of organization rate limit
|
||||
response = client.post(url_with_custom_ratelimit_2)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
response = client.post(url_with_custom_ratelimit_2)
|
||||
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
assert response.content.decode() == IntegrationRateLimitMixin.TEXT_WORKSPACE
|
||||
|
|
|
|||
|
|
@ -6,19 +6,19 @@ from firebase_admin.messaging import AndroidConfig, APNSConfig, APNSPayload, Aps
|
|||
from rest_framework import status
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.throttling import UserRateThrottle
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from apps.auth_token.auth import ApiTokenAuthentication
|
||||
from apps.mobile_app.models import FCMDevice
|
||||
from apps.mobile_app.utils import send_message_to_fcm_device
|
||||
from common.api_helpers.custom_rate_scoped_throttler import CustomRateUserThrottler
|
||||
from common.custom_celery_tasks import shared_dedicated_queue_retry_task
|
||||
|
||||
task_logger = get_task_logger(__name__)
|
||||
task_logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
class FCMRelayThrottler(UserRateThrottle):
|
||||
class FCMRelayThrottler(CustomRateUserThrottler):
|
||||
scope = "fcm_relay"
|
||||
rate = "300/m"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import json
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from django.core.cache import cache
|
||||
from django.test import override_settings
|
||||
from django.urls import reverse
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
|
@ -29,3 +31,38 @@ def test_throttling(make_organization_and_user_with_token):
|
|||
|
||||
# make sure RateLimitHeadersMixin used
|
||||
assert response.has_header("RateLimit-Reset")
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_custom_throttling(make_organization_and_user_with_token):
|
||||
organization_with_custom_ratelimit, _, token_with_custom_ratelimit = make_organization_and_user_with_token()
|
||||
_, _, token_with_default_ratelimit = make_organization_and_user_with_token()
|
||||
cache.clear()
|
||||
|
||||
CUSTOM_RATELIMITS_STR = (
|
||||
'{"'
|
||||
+ str(organization_with_custom_ratelimit.pk)
|
||||
+ '": {"integration": "10/5m","organization": "15/5m","public_api": "1/m"}}'
|
||||
)
|
||||
|
||||
with override_settings(CUSTOM_RATELIMITS=json.loads(CUSTOM_RATELIMITS_STR)):
|
||||
client = APIClient()
|
||||
|
||||
url = reverse("api-public:alert_groups-list")
|
||||
|
||||
# Organization without custom ratelimit should use default ratelimit
|
||||
for _ in range(5):
|
||||
response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token_with_default_ratelimit}")
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Organization with custom ratelimit will be ratelimited after 1 request
|
||||
response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token_with_custom_ratelimit}")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token_with_custom_ratelimit}")
|
||||
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
|
||||
# make sure RateLimitHeadersMixin used
|
||||
assert response.has_header("RateLimit-Reset")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from rest_framework.throttling import UserRateThrottle
|
||||
from common.api_helpers.custom_rate_scoped_throttler import CustomRateUserThrottler
|
||||
|
||||
|
||||
class InfoThrottler(UserRateThrottle):
|
||||
class InfoThrottler(CustomRateUserThrottler):
|
||||
scope = "info"
|
||||
rate = "100/m"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from rest_framework.throttling import UserRateThrottle
|
||||
from common.api_helpers.custom_rate_scoped_throttler import CustomRateUserThrottler
|
||||
|
||||
|
||||
class PhoneNotificationThrottler(UserRateThrottle):
|
||||
class PhoneNotificationThrottler(CustomRateUserThrottler):
|
||||
scope = "phone_notification"
|
||||
rate = "60/m"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from rest_framework.throttling import UserRateThrottle
|
||||
from common.api_helpers.custom_rate_scoped_throttler import CustomRateUserThrottler
|
||||
|
||||
|
||||
class UserThrottle(UserRateThrottle):
|
||||
class UserThrottle(CustomRateUserThrottler):
|
||||
scope = "public_api"
|
||||
rate = "300/m"
|
||||
|
|
|
|||
|
|
@ -1,55 +1,32 @@
|
|||
from rest_framework.throttling import SimpleRateThrottle
|
||||
from django.conf import settings
|
||||
from ratelimit.utils import _split_rate
|
||||
from rest_framework.throttling import UserRateThrottle
|
||||
|
||||
|
||||
class CustomRateScopedThrottler(SimpleRateThrottle):
|
||||
"""
|
||||
Abstract class to create throttlers with custom amount of seconds and custom scope.
|
||||
The unique cache key will be generated by concatenating the
|
||||
user id of the request, and the scope from get_scope() method.
|
||||
class CustomRateUserThrottler(UserRateThrottle):
|
||||
""" """
|
||||
|
||||
Should not be used directly.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.scope = self.get_scope()
|
||||
self.num_requests, self.duration = self.get_throttle_limits()
|
||||
|
||||
def get_throttle_limits(self):
|
||||
"""
|
||||
:return tuple requests/seconds
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_scope(self):
|
||||
"""
|
||||
:return ratelimit scope
|
||||
"""
|
||||
raise NotImplementedError
|
||||
def parse_rate(self, rate):
|
||||
"Use django ratelimit format to parse rate, i.e. '30/1m', instead of '30/m'"
|
||||
return _split_rate(rate)
|
||||
|
||||
def allow_request(self, request, view):
|
||||
"""
|
||||
Overriden allow_request method.
|
||||
The difference is that overriden method doesn't check rate property.
|
||||
"""
|
||||
# Override default rate limit, if organization id is specified in CUSTOM_RATELIMITS
|
||||
custom_ratelimits = settings.CUSTOM_RATELIMITS
|
||||
organization_id = str(request.user.organization_id)
|
||||
if organization_id in custom_ratelimits:
|
||||
self.rate = custom_ratelimits[organization_id]["public_api"]
|
||||
self.num_requests, self.duration = self.parse_rate(self.rate)
|
||||
|
||||
self.key = self.get_cache_key(request, view)
|
||||
if self.key is None:
|
||||
return True
|
||||
return super().allow_request(request, view)
|
||||
|
||||
self.history = self.cache.get(self.key, [])
|
||||
self.now = self.timer()
|
||||
|
||||
# Drop any requests from the history which have now passed the
|
||||
# throttle duration
|
||||
while self.history and self.history[-1] <= self.now - self.duration:
|
||||
self.history.pop()
|
||||
if len(self.history) >= self.num_requests:
|
||||
return self.throttle_failure()
|
||||
return self.throttle_success()
|
||||
class CustomRateOrganizationThrottler(CustomRateUserThrottler):
|
||||
scope = "organization"
|
||||
|
||||
def get_cache_key(self, request, view):
|
||||
if request.user.is_authenticated:
|
||||
ident = request.user.pk
|
||||
ident = request.user.organization.pk
|
||||
else:
|
||||
ident = self.get_ident(request)
|
||||
|
||||
|
|
|
|||
8
engine/common/api_helpers/custom_ratelimit.py
Normal file
8
engine/common/api_helpers/custom_ratelimit.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomRateLimit:
|
||||
integration: str
|
||||
organization: str
|
||||
public_api: str
|
||||
|
|
@ -7,6 +7,7 @@ from random import randrange
|
|||
from celery.schedules import crontab
|
||||
from firebase_admin import credentials, initialize_app
|
||||
|
||||
from common.api_helpers.custom_ratelimit import CustomRateLimit
|
||||
from common.utils import getenv_boolean, getenv_integer, getenv_list
|
||||
|
||||
VERSION = "dev-oss"
|
||||
|
|
@ -964,6 +965,17 @@ DETACHED_INTEGRATIONS_SERVER = getenv_boolean("DETACHED_INTEGRATIONS_SERVER", de
|
|||
|
||||
ACKNOWLEDGE_REMINDER_TASK_EXPIRY_DAYS = os.environ.get("ACKNOWLEDGE_REMINDER_TASK_EXPIRY_DAYS", default=14)
|
||||
|
||||
# The CUSTOM_RATELIMITS environment variable is expected to be a JSON string that defines rate limits
|
||||
# for different levels (e.g., integration, organization, public API).
|
||||
# Example of CUSTOM_RATELIMITS in environment variable:
|
||||
# CUSTOM_RATELIMITS={"1": {"integration": "10/5m", "organization": "15/5m", "public_api": "10/5m"}}
|
||||
# Where, "1" is the pk of the organization
|
||||
|
||||
# Load the environment variable and parse it into a dictionary, falling back to an empty dictionary if not set.
|
||||
CUSTOM_RATELIMITS: typing.Dict[str, CustomRateLimit] = json.loads(os.getenv("CUSTOM_RATELIMITS", "{}"))
|
||||
# Convert the parsed JSON into a dictionary of RateLimit dataclasses
|
||||
CUSTOM_RATELIMITS = {key: CustomRateLimit(**value) for key, value in CUSTOM_RATELIMITS.items()}
|
||||
|
||||
SYNC_V2_MAX_TASKS = getenv_integer("SYNC_V2_MAX_TASKS", 6)
|
||||
SYNC_V2_PERIOD_SECONDS = getenv_integer("SYNC_V2_PERIOD_SECONDS", 240)
|
||||
SYNC_V2_BATCH_SIZE = getenv_integer("SYNC_V2_BATCH_SIZE", 500)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue