diff --git a/CHANGELOG.md b/CHANGELOG.md index bbbcd3d4..b9c27be4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## Unreleased + +### Changed + +- Add ratelimits for phone verification + ## v1.1.26 (2023-02-20) ### Fixed diff --git a/engine/apps/api/tests/test_user.py b/engine/apps/api/tests/test_user.py index 9ad3b6d5..ad90e264 100644 --- a/engine/apps/api/tests/test_user.py +++ b/engine/apps/api/tests/test_user.py @@ -1,6 +1,7 @@ from unittest.mock import Mock, patch import pytest +from django.core.cache import cache from django.core.exceptions import ObjectDoesNotExist from django.urls import reverse from django.utils import timezone @@ -13,6 +14,12 @@ from apps.base.models import UserNotificationPolicy from apps.user_management.models.user import default_working_hours +@pytest.fixture(autouse=True) +def clear_cache(): + # Ratelimit keys are stored in cache, clean to prevent ratelimits + cache.clear() + + @pytest.mark.django_db def test_update_user( make_organization, @@ -653,7 +660,6 @@ def test_admin_can_verify_own_phone( make_user_auth_headers, ): _, user, token = make_organization_and_user_with_plugin_token(role=LegacyAccessControlRole.ADMIN) - client = APIClient() url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key}) @@ -1499,3 +1505,88 @@ def test_check_availability_other_user(make_organization_and_user_with_plugin_to response = client.get(url, **make_user_auth_headers(user, token)) assert response.status_code == status.HTTP_200_OK + + +@patch("apps.twilioapp.phone_manager.PhoneManager.send_verification_code", return_value=Mock()) +@patch("apps.twilioapp.phone_manager.PhoneManager.verify_phone_number", return_value=(True, None)) +@patch( + "apps.api.throttlers.GetPhoneVerificationCodeThrottlerPerUser.get_throttle_limits", + return_value=(1, 10 * 60), +) +@patch("apps.api.throttlers.VerifyPhoneNumberThrottlerPerUser.get_throttle_limits", return_value=(1, 10 * 60)) +@pytest.mark.django_db +def test_phone_number_verification_flow_ratelimit_per_user( + mock_verification_start, + mocked_verification_check, + mocked_get_phone_verification_code_get_throttle_limits, + mocked_get_phone_verify_phone_number_limits, + make_organization_and_user_with_plugin_token, + make_user_auth_headers, +): + _, user, token = make_organization_and_user_with_plugin_token() + + client = APIClient() + url = reverse("api-internal:user-get-verification-code", kwargs={"pk": user.public_primary_key}) + + # first get_verification_code request is succesfull + response = client.get(url, format="json", **make_user_auth_headers(user, token)) + assert response.status_code == status.HTTP_200_OK + + # second get_verification_code request is ratelimited + response = client.get(url, format="json", **make_user_auth_headers(user, token)) + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + + url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key}) + + # first verify_number request is succesfull, because it uses different ratelimit scope + response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token)) + assert response.status_code == status.HTTP_200_OK + + url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key}) + + # second verify_number request is succesfull, because it ratelimited + response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token)) + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + + +@patch("apps.twilioapp.phone_manager.PhoneManager.send_verification_code", return_value=Mock()) +@patch("apps.twilioapp.phone_manager.PhoneManager.verify_phone_number", return_value=(True, None)) +@patch( + "apps.api.throttlers.GetPhoneVerificationCodeThrottlerPerOrg.get_throttle_limits", + return_value=(1, 10 * 60), +) +@patch("apps.api.throttlers.VerifyPhoneNumberThrottlerPerOrg.get_throttle_limits", return_value=(1, 10 * 60)) +@pytest.mark.django_db +def test_phone_number_verification_flow_ratelimit_per_org( + mock_verification_start, + mocked_verification_check, + mocked_get_phone_verification_code_get_throttle_limits, + mocked_get_phone_verify_phone_number_limits, + make_organization_and_user_with_plugin_token, + make_user_auth_headers, + make_user_for_organization, +): + """ + This test is checks per-org ratelimits for phone verification flow. + It makes two get_verification_code and two verify_number requests from different users and expect that second call will be ratelimited. + """ + org, user, token = make_organization_and_user_with_plugin_token() + second_user = make_user_for_organization(org) + + client = APIClient() + + url = reverse("api-internal:user-get-verification-code", kwargs={"pk": user.public_primary_key}) + response = client.get(url, format="json", **make_user_auth_headers(user, token)) + assert response.status_code == status.HTTP_200_OK + + url = reverse("api-internal:user-get-verification-code", kwargs={"pk": second_user.public_primary_key}) + response = client.get(url, format="json", **make_user_auth_headers(second_user, token)) + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + + url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key}) + response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token)) + assert response.status_code == status.HTTP_200_OK + + url = reverse("api-internal:user-verify-number", kwargs={"pk": second_user.public_primary_key}) + response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(second_user, token)) + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS diff --git a/engine/apps/api/throttlers/__init__.py b/engine/apps/api/throttlers/__init__.py index 3d04dedf..6eb19e81 100644 --- a/engine/apps/api/throttlers/__init__.py +++ b/engine/apps/api/throttlers/__init__.py @@ -1 +1,8 @@ from .demo_alert_throttler import DemoAlertThrottler # noqa: F401 +from .phone_verification_throttler import ( # noqa: F401 + GetPhoneVerificationCodeThrottlerPerOrg, + GetPhoneVerificationCodeThrottlerPerUser, + VerifyPhoneNumberThrottlerPerOrg, + VerifyPhoneNumberThrottlerPerUser, +) +from .test_call_throttler import TestCallThrottler # noqa: F401 diff --git a/engine/apps/api/throttlers/phone_verification_throttler.py b/engine/apps/api/throttlers/phone_verification_throttler.py new file mode 100644 index 00000000..1bf6fe02 --- /dev/null +++ b/engine/apps/api/throttlers/phone_verification_throttler.py @@ -0,0 +1,49 @@ +from common.api_helpers.custom_rate_scoped_throttler import CustomRateScopedThrottler + + +class GetPhoneVerificationCodeThrottlerPerUser(CustomRateScopedThrottler): + def get_scope(self): + return "get_phone_verification_code_per_user" + + def get_throttle_limits(self): + return 5, 10 * 60 + + +class VerifyPhoneNumberThrottlerPerUser(CustomRateScopedThrottler): + def get_scope(self): + return "verify_phone_number_per_user" + + def get_throttle_limits(self): + return 50, 10 * 60 + + +class GetPhoneVerificationCodeThrottlerPerOrg(CustomRateScopedThrottler): + def get_scope(self): + return "get_phone_verification_code_per_org" + + def get_throttle_limits(self): + return 50, 10 * 60 + + def get_cache_key(self, request, view): + if request.user.is_authenticated: + ident = request.user.organization.pk + else: + ident = self.get_ident(request) + + return self.cache_format % {"scope": self.scope, "ident": ident} + + +class VerifyPhoneNumberThrottlerPerOrg(CustomRateScopedThrottler): + def get_scope(self): + return "verify_phone_number_per_org" + + def get_throttle_limits(self): + return 50, 10 * 60 + + def get_cache_key(self, request, view): + if request.user.is_authenticated: + ident = request.user.organization.pk + else: + ident = self.get_ident(request) + + return self.cache_format % {"scope": self.scope, "ident": ident} diff --git a/engine/apps/api/throttlers/test_call_throttler.py b/engine/apps/api/throttlers/test_call_throttler.py new file mode 100644 index 00000000..93517163 --- /dev/null +++ b/engine/apps/api/throttlers/test_call_throttler.py @@ -0,0 +1,6 @@ +from rest_framework.throttling import UserRateThrottle + + +class TestCallThrottler(UserRateThrottle): + scope = "make_test_call" + rate = "5/m" diff --git a/engine/apps/api/views/user.py b/engine/apps/api/views/user.py index d5fc8af7..6fcd2123 100644 --- a/engine/apps/api/views/user.py +++ b/engine/apps/api/views/user.py @@ -24,6 +24,13 @@ from apps.api.permissions import ( ) from apps.api.serializers.team import TeamSerializer from apps.api.serializers.user import FilterUserSerializer, UserHiddenFieldsSerializer, UserSerializer +from apps.api.throttlers import ( + GetPhoneVerificationCodeThrottlerPerOrg, + GetPhoneVerificationCodeThrottlerPerUser, + TestCallThrottler, + VerifyPhoneNumberThrottlerPerOrg, + VerifyPhoneNumberThrottlerPerUser, +) from apps.auth_token.auth import PluginAuthentication from apps.auth_token.constants import SCHEDULE_EXPORT_TOKEN_NAME from apps.auth_token.models import UserScheduleExportAuthToken @@ -279,7 +286,11 @@ class UserView( def timezone_options(self, request): return Response(pytz.common_timezones) - @action(detail=True, methods=["get"]) + @action( + detail=True, + methods=["get"], + throttle_classes=[GetPhoneVerificationCodeThrottlerPerUser, GetPhoneVerificationCodeThrottlerPerOrg], + ) def get_verification_code(self, request, pk): user = self.get_object() phone_manager = PhoneManager(user) @@ -289,7 +300,11 @@ class UserView( return Response(status=status.HTTP_400_BAD_REQUEST) return Response(status=status.HTTP_200_OK) - @action(detail=True, methods=["put"]) + @action( + detail=True, + methods=["put"], + throttle_classes=[VerifyPhoneNumberThrottlerPerUser, VerifyPhoneNumberThrottlerPerOrg], + ) def verify_number(self, request, pk): target_user = self.get_object() code = request.query_params.get("token", None) @@ -327,7 +342,7 @@ class UserView( ) return Response(status=status.HTTP_200_OK) - @action(detail=True, methods=["post"]) + @action(detail=True, methods=["post"], throttle_classes=[TestCallThrottler]) def make_test_call(self, request, pk): user = self.get_object() phone_number = user.verified_phone_number diff --git a/engine/apps/public_api/tests/test_ratelimit.py b/engine/apps/public_api/tests/test_ratelimit.py index eb71827c..d6a74587 100644 --- a/engine/apps/public_api/tests/test_ratelimit.py +++ b/engine/apps/public_api/tests/test_ratelimit.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import PropertyMock, patch import pytest from django.core.cache import cache @@ -7,27 +7,25 @@ from rest_framework import status from rest_framework.test import APIClient -@patch("apps.public_api.throttlers.user_throttle.UserThrottle.get_throttle_limits") @pytest.mark.django_db -def test_throttling(mocked_throttle_limits, make_organization_and_user_with_token): - MAX_REQUESTS = 1 - PERIOD = 360 +def test_throttling(make_organization_and_user_with_token): + with patch("apps.public_api.throttlers.user_throttle.UserThrottle.rate", new_callable=PropertyMock) as mocked_rate: + mocked_rate.return_value = "1/m" - _, _, token = make_organization_and_user_with_token() - cache.clear() + _, _, token = make_organization_and_user_with_token() + cache.clear() - client = APIClient() + client = APIClient() - mocked_throttle_limits.return_value = MAX_REQUESTS, PERIOD - url = reverse("api-public:alert_groups-list") + url = reverse("api-public:alert_groups-list") - response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}") + response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}") - assert response.status_code == status.HTTP_200_OK + assert response.status_code == status.HTTP_200_OK - response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}") + response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}") - assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS - # make sure RateLimitHeadersMixin used - assert response.has_header("RateLimit-Reset") + # make sure RateLimitHeadersMixin used + assert response.has_header("RateLimit-Reset") diff --git a/engine/apps/public_api/throttlers/user_throttle.py b/engine/apps/public_api/throttlers/user_throttle.py index 4c46e259..7c176f2e 100644 --- a/engine/apps/public_api/throttlers/user_throttle.py +++ b/engine/apps/public_api/throttlers/user_throttle.py @@ -2,42 +2,5 @@ from rest_framework.throttling import UserRateThrottle class UserThrottle(UserRateThrottle): - """ - __init__ and allow_request are overridden because we want rate 300/5m, - but default rate parser implementation doesn't allow to specify length of period (only m, d, etc.) - (See SimpleRateThrottle.parse_rate) - - """ - - def __init__(self): - self.num_requests, self.duration = self.get_throttle_limits() - - def get_throttle_limits(self): - """ - This method exits for speed up tests. - :return tuple requests/seconds - """ - return 300, 60 - - def allow_request(self, request, view): - """ - Implement the check to see if the request should be throttled. - - On success calls `throttle_success`. - On failure calls `throttle_failure`. - """ - - self.key = self.get_cache_key(request, view) - if self.key is None: - return True - - self.history = self.cache.get(self.key, []) - self.now = self.timer() - - # Drop any requests from the history which have now passed the - # throttle duration - while self.history and self.history[-1] <= self.now - self.duration: - self.history.pop() - if len(self.history) >= self.num_requests: - return self.throttle_failure() - return self.throttle_success() + scope = "public_api" + rate = "300/m" diff --git a/engine/common/api_helpers/custom_rate_scoped_throttler.py b/engine/common/api_helpers/custom_rate_scoped_throttler.py new file mode 100644 index 00000000..c8965d20 --- /dev/null +++ b/engine/common/api_helpers/custom_rate_scoped_throttler.py @@ -0,0 +1,56 @@ +from rest_framework.throttling import SimpleRateThrottle + + +class CustomRateScopedThrottler(SimpleRateThrottle): + """ + Abstract class to create throttlers with custom amount of seconds and custom scope. + The unique cache key will be generated by concatenating the + user id of the request, and the scope from get_scope() method. + + Should not be used directly. + """ + + def __init__(self): + self.scope = self.get_scope() + self.num_requests, self.duration = self.get_throttle_limits() + + def get_throttle_limits(self): + """ + :return tuple requests/seconds + """ + raise NotImplementedError + + def get_scope(self): + """ + :return ratelimit scope + """ + raise NotImplementedError + + def allow_request(self, request, view): + """ + Overriden allow_request method. + The difference is that overriden method doesn't check rate property. + """ + + self.key = self.get_cache_key(request, view) + if self.key is None: + return True + + self.history = self.cache.get(self.key, []) + self.now = self.timer() + + # Drop any requests from the history which have now passed the + # throttle duration + while self.history and self.history[-1] <= self.now - self.duration: + self.history.pop() + if len(self.history) >= self.num_requests: + return self.throttle_failure() + return self.throttle_success() + + def get_cache_key(self, request, view): + if request.user.is_authenticated: + ident = request.user.pk + else: + ident = self.get_ident(request) + + return self.cache_format % {"scope": self.scope, "ident": ident}