Add ratelimit for phone number verification (#1354)
# What this PR does ## Which issue(s) this PR fixes ## Checklist - [x] Tests updated - [x] `CHANGELOG.md` updated --------- Co-authored-by: Joey Orlando <joey.orlando@grafana.com>
This commit is contained in:
parent
adac88f1c0
commit
61fdcfdc72
9 changed files with 250 additions and 59 deletions
|
|
@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
|
|||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## Unreleased
|
||||
|
||||
### Changed
|
||||
|
||||
- Add ratelimits for phone verification
|
||||
|
||||
## v1.1.26 (2023-02-20)
|
||||
|
||||
### Fixed
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from django.core.cache import cache
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.urls import reverse
|
||||
from django.utils import timezone
|
||||
|
|
@ -13,6 +14,12 @@ from apps.base.models import UserNotificationPolicy
|
|||
from apps.user_management.models.user import default_working_hours
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
# Ratelimit keys are stored in cache, clean to prevent ratelimits
|
||||
cache.clear()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_user(
|
||||
make_organization,
|
||||
|
|
@ -653,7 +660,6 @@ def test_admin_can_verify_own_phone(
|
|||
make_user_auth_headers,
|
||||
):
|
||||
_, user, token = make_organization_and_user_with_plugin_token(role=LegacyAccessControlRole.ADMIN)
|
||||
|
||||
client = APIClient()
|
||||
url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key})
|
||||
|
||||
|
|
@ -1499,3 +1505,88 @@ def test_check_availability_other_user(make_organization_and_user_with_plugin_to
|
|||
response = client.get(url, **make_user_auth_headers(user, token))
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
|
||||
@patch("apps.twilioapp.phone_manager.PhoneManager.send_verification_code", return_value=Mock())
|
||||
@patch("apps.twilioapp.phone_manager.PhoneManager.verify_phone_number", return_value=(True, None))
|
||||
@patch(
|
||||
"apps.api.throttlers.GetPhoneVerificationCodeThrottlerPerUser.get_throttle_limits",
|
||||
return_value=(1, 10 * 60),
|
||||
)
|
||||
@patch("apps.api.throttlers.VerifyPhoneNumberThrottlerPerUser.get_throttle_limits", return_value=(1, 10 * 60))
|
||||
@pytest.mark.django_db
|
||||
def test_phone_number_verification_flow_ratelimit_per_user(
|
||||
mock_verification_start,
|
||||
mocked_verification_check,
|
||||
mocked_get_phone_verification_code_get_throttle_limits,
|
||||
mocked_get_phone_verify_phone_number_limits,
|
||||
make_organization_and_user_with_plugin_token,
|
||||
make_user_auth_headers,
|
||||
):
|
||||
_, user, token = make_organization_and_user_with_plugin_token()
|
||||
|
||||
client = APIClient()
|
||||
url = reverse("api-internal:user-get-verification-code", kwargs={"pk": user.public_primary_key})
|
||||
|
||||
# first get_verification_code request is succesfull
|
||||
response = client.get(url, format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# second get_verification_code request is ratelimited
|
||||
response = client.get(url, format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
|
||||
url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key})
|
||||
|
||||
# first verify_number request is succesfull, because it uses different ratelimit scope
|
||||
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key})
|
||||
|
||||
# second verify_number request is succesfull, because it ratelimited
|
||||
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
|
||||
|
||||
@patch("apps.twilioapp.phone_manager.PhoneManager.send_verification_code", return_value=Mock())
|
||||
@patch("apps.twilioapp.phone_manager.PhoneManager.verify_phone_number", return_value=(True, None))
|
||||
@patch(
|
||||
"apps.api.throttlers.GetPhoneVerificationCodeThrottlerPerOrg.get_throttle_limits",
|
||||
return_value=(1, 10 * 60),
|
||||
)
|
||||
@patch("apps.api.throttlers.VerifyPhoneNumberThrottlerPerOrg.get_throttle_limits", return_value=(1, 10 * 60))
|
||||
@pytest.mark.django_db
|
||||
def test_phone_number_verification_flow_ratelimit_per_org(
|
||||
mock_verification_start,
|
||||
mocked_verification_check,
|
||||
mocked_get_phone_verification_code_get_throttle_limits,
|
||||
mocked_get_phone_verify_phone_number_limits,
|
||||
make_organization_and_user_with_plugin_token,
|
||||
make_user_auth_headers,
|
||||
make_user_for_organization,
|
||||
):
|
||||
"""
|
||||
This test is checks per-org ratelimits for phone verification flow.
|
||||
It makes two get_verification_code and two verify_number requests from different users and expect that second call will be ratelimited.
|
||||
"""
|
||||
org, user, token = make_organization_and_user_with_plugin_token()
|
||||
second_user = make_user_for_organization(org)
|
||||
|
||||
client = APIClient()
|
||||
|
||||
url = reverse("api-internal:user-get-verification-code", kwargs={"pk": user.public_primary_key})
|
||||
response = client.get(url, format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
url = reverse("api-internal:user-get-verification-code", kwargs={"pk": second_user.public_primary_key})
|
||||
response = client.get(url, format="json", **make_user_auth_headers(second_user, token))
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
|
||||
url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key})
|
||||
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token))
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
url = reverse("api-internal:user-verify-number", kwargs={"pk": second_user.public_primary_key})
|
||||
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(second_user, token))
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
|
|
|
|||
|
|
@ -1 +1,8 @@
|
|||
from .demo_alert_throttler import DemoAlertThrottler # noqa: F401
|
||||
from .phone_verification_throttler import ( # noqa: F401
|
||||
GetPhoneVerificationCodeThrottlerPerOrg,
|
||||
GetPhoneVerificationCodeThrottlerPerUser,
|
||||
VerifyPhoneNumberThrottlerPerOrg,
|
||||
VerifyPhoneNumberThrottlerPerUser,
|
||||
)
|
||||
from .test_call_throttler import TestCallThrottler # noqa: F401
|
||||
|
|
|
|||
49
engine/apps/api/throttlers/phone_verification_throttler.py
Normal file
49
engine/apps/api/throttlers/phone_verification_throttler.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
from common.api_helpers.custom_rate_scoped_throttler import CustomRateScopedThrottler
|
||||
|
||||
|
||||
class GetPhoneVerificationCodeThrottlerPerUser(CustomRateScopedThrottler):
|
||||
def get_scope(self):
|
||||
return "get_phone_verification_code_per_user"
|
||||
|
||||
def get_throttle_limits(self):
|
||||
return 5, 10 * 60
|
||||
|
||||
|
||||
class VerifyPhoneNumberThrottlerPerUser(CustomRateScopedThrottler):
|
||||
def get_scope(self):
|
||||
return "verify_phone_number_per_user"
|
||||
|
||||
def get_throttle_limits(self):
|
||||
return 50, 10 * 60
|
||||
|
||||
|
||||
class GetPhoneVerificationCodeThrottlerPerOrg(CustomRateScopedThrottler):
|
||||
def get_scope(self):
|
||||
return "get_phone_verification_code_per_org"
|
||||
|
||||
def get_throttle_limits(self):
|
||||
return 50, 10 * 60
|
||||
|
||||
def get_cache_key(self, request, view):
|
||||
if request.user.is_authenticated:
|
||||
ident = request.user.organization.pk
|
||||
else:
|
||||
ident = self.get_ident(request)
|
||||
|
||||
return self.cache_format % {"scope": self.scope, "ident": ident}
|
||||
|
||||
|
||||
class VerifyPhoneNumberThrottlerPerOrg(CustomRateScopedThrottler):
|
||||
def get_scope(self):
|
||||
return "verify_phone_number_per_org"
|
||||
|
||||
def get_throttle_limits(self):
|
||||
return 50, 10 * 60
|
||||
|
||||
def get_cache_key(self, request, view):
|
||||
if request.user.is_authenticated:
|
||||
ident = request.user.organization.pk
|
||||
else:
|
||||
ident = self.get_ident(request)
|
||||
|
||||
return self.cache_format % {"scope": self.scope, "ident": ident}
|
||||
6
engine/apps/api/throttlers/test_call_throttler.py
Normal file
6
engine/apps/api/throttlers/test_call_throttler.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from rest_framework.throttling import UserRateThrottle
|
||||
|
||||
|
||||
class TestCallThrottler(UserRateThrottle):
|
||||
scope = "make_test_call"
|
||||
rate = "5/m"
|
||||
|
|
@ -24,6 +24,13 @@ from apps.api.permissions import (
|
|||
)
|
||||
from apps.api.serializers.team import TeamSerializer
|
||||
from apps.api.serializers.user import FilterUserSerializer, UserHiddenFieldsSerializer, UserSerializer
|
||||
from apps.api.throttlers import (
|
||||
GetPhoneVerificationCodeThrottlerPerOrg,
|
||||
GetPhoneVerificationCodeThrottlerPerUser,
|
||||
TestCallThrottler,
|
||||
VerifyPhoneNumberThrottlerPerOrg,
|
||||
VerifyPhoneNumberThrottlerPerUser,
|
||||
)
|
||||
from apps.auth_token.auth import PluginAuthentication
|
||||
from apps.auth_token.constants import SCHEDULE_EXPORT_TOKEN_NAME
|
||||
from apps.auth_token.models import UserScheduleExportAuthToken
|
||||
|
|
@ -279,7 +286,11 @@ class UserView(
|
|||
def timezone_options(self, request):
|
||||
return Response(pytz.common_timezones)
|
||||
|
||||
@action(detail=True, methods=["get"])
|
||||
@action(
|
||||
detail=True,
|
||||
methods=["get"],
|
||||
throttle_classes=[GetPhoneVerificationCodeThrottlerPerUser, GetPhoneVerificationCodeThrottlerPerOrg],
|
||||
)
|
||||
def get_verification_code(self, request, pk):
|
||||
user = self.get_object()
|
||||
phone_manager = PhoneManager(user)
|
||||
|
|
@ -289,7 +300,11 @@ class UserView(
|
|||
return Response(status=status.HTTP_400_BAD_REQUEST)
|
||||
return Response(status=status.HTTP_200_OK)
|
||||
|
||||
@action(detail=True, methods=["put"])
|
||||
@action(
|
||||
detail=True,
|
||||
methods=["put"],
|
||||
throttle_classes=[VerifyPhoneNumberThrottlerPerUser, VerifyPhoneNumberThrottlerPerOrg],
|
||||
)
|
||||
def verify_number(self, request, pk):
|
||||
target_user = self.get_object()
|
||||
code = request.query_params.get("token", None)
|
||||
|
|
@ -327,7 +342,7 @@ class UserView(
|
|||
)
|
||||
return Response(status=status.HTTP_200_OK)
|
||||
|
||||
@action(detail=True, methods=["post"])
|
||||
@action(detail=True, methods=["post"], throttle_classes=[TestCallThrottler])
|
||||
def make_test_call(self, request, pk):
|
||||
user = self.get_object()
|
||||
phone_number = user.verified_phone_number
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from unittest.mock import patch
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from django.core.cache import cache
|
||||
|
|
@ -7,27 +7,25 @@ from rest_framework import status
|
|||
from rest_framework.test import APIClient
|
||||
|
||||
|
||||
@patch("apps.public_api.throttlers.user_throttle.UserThrottle.get_throttle_limits")
|
||||
@pytest.mark.django_db
|
||||
def test_throttling(mocked_throttle_limits, make_organization_and_user_with_token):
|
||||
MAX_REQUESTS = 1
|
||||
PERIOD = 360
|
||||
def test_throttling(make_organization_and_user_with_token):
|
||||
with patch("apps.public_api.throttlers.user_throttle.UserThrottle.rate", new_callable=PropertyMock) as mocked_rate:
|
||||
mocked_rate.return_value = "1/m"
|
||||
|
||||
_, _, token = make_organization_and_user_with_token()
|
||||
cache.clear()
|
||||
_, _, token = make_organization_and_user_with_token()
|
||||
cache.clear()
|
||||
|
||||
client = APIClient()
|
||||
client = APIClient()
|
||||
|
||||
mocked_throttle_limits.return_value = MAX_REQUESTS, PERIOD
|
||||
url = reverse("api-public:alert_groups-list")
|
||||
url = reverse("api-public:alert_groups-list")
|
||||
|
||||
response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}")
|
||||
response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}")
|
||||
response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}")
|
||||
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
|
||||
# make sure RateLimitHeadersMixin used
|
||||
assert response.has_header("RateLimit-Reset")
|
||||
# make sure RateLimitHeadersMixin used
|
||||
assert response.has_header("RateLimit-Reset")
|
||||
|
|
|
|||
|
|
@ -2,42 +2,5 @@ from rest_framework.throttling import UserRateThrottle
|
|||
|
||||
|
||||
class UserThrottle(UserRateThrottle):
|
||||
"""
|
||||
__init__ and allow_request are overridden because we want rate 300/5m,
|
||||
but default rate parser implementation doesn't allow to specify length of period (only m, d, etc.)
|
||||
(See SimpleRateThrottle.parse_rate)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.num_requests, self.duration = self.get_throttle_limits()
|
||||
|
||||
def get_throttle_limits(self):
|
||||
"""
|
||||
This method exits for speed up tests.
|
||||
:return tuple requests/seconds
|
||||
"""
|
||||
return 300, 60
|
||||
|
||||
def allow_request(self, request, view):
|
||||
"""
|
||||
Implement the check to see if the request should be throttled.
|
||||
|
||||
On success calls `throttle_success`.
|
||||
On failure calls `throttle_failure`.
|
||||
"""
|
||||
|
||||
self.key = self.get_cache_key(request, view)
|
||||
if self.key is None:
|
||||
return True
|
||||
|
||||
self.history = self.cache.get(self.key, [])
|
||||
self.now = self.timer()
|
||||
|
||||
# Drop any requests from the history which have now passed the
|
||||
# throttle duration
|
||||
while self.history and self.history[-1] <= self.now - self.duration:
|
||||
self.history.pop()
|
||||
if len(self.history) >= self.num_requests:
|
||||
return self.throttle_failure()
|
||||
return self.throttle_success()
|
||||
scope = "public_api"
|
||||
rate = "300/m"
|
||||
|
|
|
|||
56
engine/common/api_helpers/custom_rate_scoped_throttler.py
Normal file
56
engine/common/api_helpers/custom_rate_scoped_throttler.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
from rest_framework.throttling import SimpleRateThrottle
|
||||
|
||||
|
||||
class CustomRateScopedThrottler(SimpleRateThrottle):
|
||||
"""
|
||||
Abstract class to create throttlers with custom amount of seconds and custom scope.
|
||||
The unique cache key will be generated by concatenating the
|
||||
user id of the request, and the scope from get_scope() method.
|
||||
|
||||
Should not be used directly.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.scope = self.get_scope()
|
||||
self.num_requests, self.duration = self.get_throttle_limits()
|
||||
|
||||
def get_throttle_limits(self):
|
||||
"""
|
||||
:return tuple requests/seconds
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_scope(self):
|
||||
"""
|
||||
:return ratelimit scope
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def allow_request(self, request, view):
|
||||
"""
|
||||
Overriden allow_request method.
|
||||
The difference is that overriden method doesn't check rate property.
|
||||
"""
|
||||
|
||||
self.key = self.get_cache_key(request, view)
|
||||
if self.key is None:
|
||||
return True
|
||||
|
||||
self.history = self.cache.get(self.key, [])
|
||||
self.now = self.timer()
|
||||
|
||||
# Drop any requests from the history which have now passed the
|
||||
# throttle duration
|
||||
while self.history and self.history[-1] <= self.now - self.duration:
|
||||
self.history.pop()
|
||||
if len(self.history) >= self.num_requests:
|
||||
return self.throttle_failure()
|
||||
return self.throttle_success()
|
||||
|
||||
def get_cache_key(self, request, view):
|
||||
if request.user.is_authenticated:
|
||||
ident = request.user.pk
|
||||
else:
|
||||
ident = self.get_ident(request)
|
||||
|
||||
return self.cache_format % {"scope": self.scope, "ident": ident}
|
||||
Loading…
Add table
Reference in a new issue