Use dataclass methods in custom ratelimits and fix tests (#5036)
# What this PR does Follow up PR for https://github.com/grafana/oncall/pull/5004 Tests haven’t caught a bug, so the method and the tests are fixed ## Which issue(s) this PR closes Related to [issue link here] <!-- *Note*: If you want the issue to be auto-closed once the PR is merged, change "Related to" to "Closes" in the line above. If you have more than one GitHub issue that this PR closes, be sure to preface each issue link with a [closing keyword](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/using-keywords-in-issues-and-pull-requests#linking-a-pull-request-to-an-issue). This ensures that the issue(s) are auto-closed once the PR has been merged. --> ## Checklist - [ ] Unit, integration, and e2e (if applicable) tests updated - [ ] Documentation added (or `pr:no public docs` PR label added if not required) - [ ] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes.
This commit is contained in:
parent
61902d5889
commit
c7a7a3f81a
6 changed files with 30 additions and 12 deletions
|
|
@ -47,11 +47,11 @@ def get_rate_limit(group, request):
|
|||
|
||||
if group == RATELIMIT_INTEGRATION_GROUP_NAME:
|
||||
if organization_id in custom_ratelimits:
|
||||
return custom_ratelimits[organization_id]["integration"]
|
||||
return custom_ratelimits[organization_id].integration
|
||||
return RATELIMIT_INTEGRATION
|
||||
elif group == RATELIMIT_TEAM_GROUP_NAME:
|
||||
if organization_id in custom_ratelimits:
|
||||
return custom_ratelimits[organization_id]["organization"]
|
||||
return custom_ratelimits[organization_id].organization
|
||||
return RATELIMIT_TEAM
|
||||
else:
|
||||
raise Exception("Unknown group")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
|
@ -10,6 +9,7 @@ from rest_framework import status
|
|||
from apps.alerts.models import AlertReceiveChannel
|
||||
from apps.integrations.mixins import IntegrationRateLimitMixin
|
||||
from apps.integrations.mixins.ratelimit_mixin import RATELIMIT_INTEGRATION
|
||||
from common.api_helpers.custom_ratelimit import load_custom_ratelimits
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
|
@ -197,7 +197,7 @@ def test_custom_throttling(make_organization, make_alert_receive_channel):
|
|||
+ '": {"integration": "2/m","organization": "3/m","public_api": "1/m"}}'
|
||||
)
|
||||
|
||||
with override_settings(CUSTOM_RATELIMITS=json.loads(CUSTOM_RATELIMITS_STR)):
|
||||
with override_settings(CUSTOM_RATELIMITS=load_custom_ratelimits(CUSTOM_RATELIMITS_STR)):
|
||||
client = Client()
|
||||
|
||||
# Organization without custom ratelimit should use default ratelimit
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
|
@ -8,6 +7,8 @@ from django.urls import reverse
|
|||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from common.api_helpers.custom_ratelimit import load_custom_ratelimits
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_throttling(make_organization_and_user_with_token):
|
||||
|
|
@ -45,7 +46,7 @@ def test_custom_throttling(make_organization_and_user_with_token):
|
|||
+ '": {"integration": "10/5m","organization": "15/5m","public_api": "1/m"}}'
|
||||
)
|
||||
|
||||
with override_settings(CUSTOM_RATELIMITS=json.loads(CUSTOM_RATELIMITS_STR)):
|
||||
with override_settings(CUSTOM_RATELIMITS=load_custom_ratelimits(CUSTOM_RATELIMITS_STR)):
|
||||
client = APIClient()
|
||||
|
||||
url = reverse("api-public:alert_groups-list")
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class CustomRateUserThrottler(UserRateThrottle):
|
|||
custom_ratelimits = settings.CUSTOM_RATELIMITS
|
||||
organization_id = str(request.user.organization_id)
|
||||
if organization_id in custom_ratelimits:
|
||||
self.rate = custom_ratelimits[organization_id]["public_api"]
|
||||
self.rate = custom_ratelimits[organization_id].public_api
|
||||
self.num_requests, self.duration = self.parse_rate(self.rate)
|
||||
|
||||
return super().allow_request(request, view)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
import json
|
||||
import os
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
|
|
@ -6,3 +9,19 @@ class CustomRateLimit:
|
|||
integration: str
|
||||
organization: str
|
||||
public_api: str
|
||||
|
||||
|
||||
def getenv_custom_ratelimit(variable_name: str, default: dict) -> typing.Dict[str, CustomRateLimit]:
|
||||
custom_ratelimits_str = os.environ.get(variable_name)
|
||||
if custom_ratelimits_str is None:
|
||||
return default
|
||||
value = load_custom_ratelimits(custom_ratelimits_str)
|
||||
return value
|
||||
|
||||
|
||||
def load_custom_ratelimits(custom_ratelimits_str: str) -> typing.Dict[str, CustomRateLimit]:
|
||||
custom_ratelimits_dict = json.loads(custom_ratelimits_str)
|
||||
# Convert the parsed JSON into a dictionary of RateLimit dataclasses
|
||||
custom_ratelimits = {key: CustomRateLimit(**value) for key, value in custom_ratelimits_dict.items()}
|
||||
|
||||
return custom_ratelimits
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from random import randrange
|
|||
from celery.schedules import crontab
|
||||
from firebase_admin import credentials, initialize_app
|
||||
|
||||
from common.api_helpers.custom_ratelimit import CustomRateLimit
|
||||
from common.api_helpers.custom_ratelimit import getenv_custom_ratelimit
|
||||
from common.utils import getenv_boolean, getenv_integer, getenv_list
|
||||
|
||||
VERSION = "dev-oss"
|
||||
|
|
@ -971,10 +971,8 @@ ACKNOWLEDGE_REMINDER_TASK_EXPIRY_DAYS = os.environ.get("ACKNOWLEDGE_REMINDER_TAS
|
|||
# CUSTOM_RATELIMITS={"1": {"integration": "10/5m", "organization": "15/5m", "public_api": "10/5m"}}
|
||||
# Where, "1" is the pk of the organization
|
||||
|
||||
# Load the environment variable and parse it into a dictionary, falling back to an empty dictionary if not set.
|
||||
CUSTOM_RATELIMITS: typing.Dict[str, CustomRateLimit] = json.loads(os.getenv("CUSTOM_RATELIMITS", "{}"))
|
||||
# Convert the parsed JSON into a dictionary of RateLimit dataclasses
|
||||
CUSTOM_RATELIMITS = {key: CustomRateLimit(**value) for key, value in CUSTOM_RATELIMITS.items()}
|
||||
# Load the environment variable and parse it into a dictionary of custom ralimits, falling back to an empty dictionary if not set.
|
||||
CUSTOM_RATELIMITS = getenv_custom_ratelimit("CUSTOM_RATELIMITS", default={})
|
||||
|
||||
SYNC_V2_MAX_TASKS = getenv_integer("SYNC_V2_MAX_TASKS", 6)
|
||||
SYNC_V2_PERIOD_SECONDS = getenv_integer("SYNC_V2_PERIOD_SECONDS", 240)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue