diff --git a/engine/apps/integrations/mixins/ratelimit_mixin.py b/engine/apps/integrations/mixins/ratelimit_mixin.py index e3e2616e..9d8aa6c9 100644 --- a/engine/apps/integrations/mixins/ratelimit_mixin.py +++ b/engine/apps/integrations/mixins/ratelimit_mixin.py @@ -47,11 +47,11 @@ def get_rate_limit(group, request): if group == RATELIMIT_INTEGRATION_GROUP_NAME: if organization_id in custom_ratelimits: - return custom_ratelimits[organization_id]["integration"] + return custom_ratelimits[organization_id].integration return RATELIMIT_INTEGRATION elif group == RATELIMIT_TEAM_GROUP_NAME: if organization_id in custom_ratelimits: - return custom_ratelimits[organization_id]["organization"] + return custom_ratelimits[organization_id].organization return RATELIMIT_TEAM else: raise Exception("Unknown group") diff --git a/engine/apps/integrations/tests/test_ratelimit.py b/engine/apps/integrations/tests/test_ratelimit.py index 78d4be09..b8b50c53 100644 --- a/engine/apps/integrations/tests/test_ratelimit.py +++ b/engine/apps/integrations/tests/test_ratelimit.py @@ -1,4 +1,3 @@ -import json from unittest import mock import pytest @@ -10,6 +9,7 @@ from rest_framework import status from apps.alerts.models import AlertReceiveChannel from apps.integrations.mixins import IntegrationRateLimitMixin from apps.integrations.mixins.ratelimit_mixin import RATELIMIT_INTEGRATION +from common.api_helpers.custom_ratelimit import load_custom_ratelimits @pytest.fixture(autouse=True) @@ -197,7 +197,7 @@ def test_custom_throttling(make_organization, make_alert_receive_channel): + '": {"integration": "2/m","organization": "3/m","public_api": "1/m"}}' ) - with override_settings(CUSTOM_RATELIMITS=json.loads(CUSTOM_RATELIMITS_STR)): + with override_settings(CUSTOM_RATELIMITS=load_custom_ratelimits(CUSTOM_RATELIMITS_STR)): client = Client() # Organization without custom ratelimit should use default ratelimit diff --git a/engine/apps/public_api/tests/test_ratelimit.py b/engine/apps/public_api/tests/test_ratelimit.py index 8e2b7691..202e9039 100644 --- a/engine/apps/public_api/tests/test_ratelimit.py +++ b/engine/apps/public_api/tests/test_ratelimit.py @@ -1,4 +1,3 @@ -import json from unittest.mock import PropertyMock, patch import pytest @@ -8,6 +7,8 @@ from django.urls import reverse from rest_framework import status from rest_framework.test import APIClient +from common.api_helpers.custom_ratelimit import load_custom_ratelimits + @pytest.mark.django_db def test_throttling(make_organization_and_user_with_token): @@ -45,7 +46,7 @@ def test_custom_throttling(make_organization_and_user_with_token): + '": {"integration": "10/5m","organization": "15/5m","public_api": "1/m"}}' ) - with override_settings(CUSTOM_RATELIMITS=json.loads(CUSTOM_RATELIMITS_STR)): + with override_settings(CUSTOM_RATELIMITS=load_custom_ratelimits(CUSTOM_RATELIMITS_STR)): client = APIClient() url = reverse("api-public:alert_groups-list") diff --git a/engine/common/api_helpers/custom_rate_scoped_throttler.py b/engine/common/api_helpers/custom_rate_scoped_throttler.py index 888d1385..06874969 100644 --- a/engine/common/api_helpers/custom_rate_scoped_throttler.py +++ b/engine/common/api_helpers/custom_rate_scoped_throttler.py @@ -15,7 +15,7 @@ class CustomRateUserThrottler(UserRateThrottle): custom_ratelimits = settings.CUSTOM_RATELIMITS organization_id = str(request.user.organization_id) if organization_id in custom_ratelimits: - self.rate = custom_ratelimits[organization_id]["public_api"] + self.rate = custom_ratelimits[organization_id].public_api self.num_requests, self.duration = self.parse_rate(self.rate) return super().allow_request(request, view) diff --git a/engine/common/api_helpers/custom_ratelimit.py b/engine/common/api_helpers/custom_ratelimit.py index e9aced30..35563074 100644 --- a/engine/common/api_helpers/custom_ratelimit.py +++ b/engine/common/api_helpers/custom_ratelimit.py @@ -1,3 +1,6 @@ +import json +import os +import typing from dataclasses import dataclass @@ -6,3 +9,19 @@ class CustomRateLimit: integration: str organization: str public_api: str + + +def getenv_custom_ratelimit(variable_name: str, default: dict) -> typing.Dict[str, CustomRateLimit]: + custom_ratelimits_str = os.environ.get(variable_name) + if custom_ratelimits_str is None: + return default + value = load_custom_ratelimits(custom_ratelimits_str) + return value + + +def load_custom_ratelimits(custom_ratelimits_str: str) -> typing.Dict[str, CustomRateLimit]: + custom_ratelimits_dict = json.loads(custom_ratelimits_str) + # Convert the parsed JSON into a dictionary of RateLimit dataclasses + custom_ratelimits = {key: CustomRateLimit(**value) for key, value in custom_ratelimits_dict.items()} + + return custom_ratelimits diff --git a/engine/settings/base.py b/engine/settings/base.py index 47ccb77d..9f7f6418 100644 --- a/engine/settings/base.py +++ b/engine/settings/base.py @@ -7,7 +7,7 @@ from random import randrange from celery.schedules import crontab from firebase_admin import credentials, initialize_app -from common.api_helpers.custom_ratelimit import CustomRateLimit +from common.api_helpers.custom_ratelimit import getenv_custom_ratelimit from common.utils import getenv_boolean, getenv_integer, getenv_list VERSION = "dev-oss" @@ -971,10 +971,8 @@ ACKNOWLEDGE_REMINDER_TASK_EXPIRY_DAYS = os.environ.get("ACKNOWLEDGE_REMINDER_TAS # CUSTOM_RATELIMITS={"1": {"integration": "10/5m", "organization": "15/5m", "public_api": "10/5m"}} # Where, "1" is the pk of the organization -# Load the environment variable and parse it into a dictionary, falling back to an empty dictionary if not set. -CUSTOM_RATELIMITS: typing.Dict[str, CustomRateLimit] = json.loads(os.getenv("CUSTOM_RATELIMITS", "{}")) -# Convert the parsed JSON into a dictionary of RateLimit dataclasses -CUSTOM_RATELIMITS = {key: CustomRateLimit(**value) for key, value in CUSTOM_RATELIMITS.items()} +# Load the environment variable and parse it into a dictionary of custom ralimits, falling back to an empty dictionary if not set. +CUSTOM_RATELIMITS = getenv_custom_ratelimit("CUSTOM_RATELIMITS", default={}) SYNC_V2_MAX_TASKS = getenv_integer("SYNC_V2_MAX_TASKS", 6) SYNC_V2_PERIOD_SECONDS = getenv_integer("SYNC_V2_PERIOD_SECONDS", 240)