From 5e5eecc48068d466b21ed63fb10ebd0b8dee9a49 Mon Sep 17 00:00:00 2001 From: Matias Bordese Date: Mon, 14 Aug 2023 09:26:21 -0300 Subject: [PATCH] Add shift swaps endpoints to public API (#2775) Related to https://github.com/grafana/oncall/issues/2678 --- CHANGELOG.md | 4 + engine/apps/api/views/shift_swap.py | 116 ++--- .../apps/public_api/tests/test_shift_swap.py | 413 ++++++++++++++++++ engine/apps/public_api/urls.py | 1 + engine/apps/public_api/views/__init__.py | 1 + engine/apps/public_api/views/shift_swap.py | 97 ++++ 6 files changed, 579 insertions(+), 53 deletions(-) create mode 100644 engine/apps/public_api/tests/test_shift_swap.py create mode 100644 engine/apps/public_api/views/shift_swap.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e24a957f..a871ecbc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix issue with updating "Require resolution note" setting by @Ferril ([#2782](https://github.com/grafana/oncall/pull/2782)) - Don't send notifications about past SSRs when turning on info notifications by @vadimkerr ([#2783](https://github.com/grafana/oncall/pull/2783)) +### Added + +- Shift swap requests public API ([#2775](https://github.com/grafana/oncall/pull/2775)) + ## v1.3.23 (2023-08-10) ### Added diff --git a/engine/apps/api/views/shift_swap.py b/engine/apps/api/views/shift_swap.py index 87889a3a..8d47589f 100644 --- a/engine/apps/api/views/shift_swap.py +++ b/engine/apps/api/views/shift_swap.py @@ -14,6 +14,7 @@ from apps.mobile_app.auth import MobileAppAuthTokenAuthentication from apps.schedules import exceptions from apps.schedules.models import ShiftSwapRequest from apps.schedules.tasks.shift_swaps import create_shift_swap_request_message, update_shift_swap_request_message +from apps.user_management.models import User from common.api_helpers.exceptions import BadRequest from common.api_helpers.mixins import PublicPrimaryKeyMixin from common.api_helpers.paginators import FiftyPageSizePaginator @@ -22,7 +23,66 @@ from common.insight_log import EntityEvent, write_resource_insight_log logger = logging.getLogger(__name__) -class ShiftSwapViewSet(PublicPrimaryKeyMixin[ShiftSwapRequest], ModelViewSet): +class BaseShiftSwapViewSet(ModelViewSet): + model = ShiftSwapRequest + serializer_class = ShiftSwapRequestSerializer + pagination_class = FiftyPageSizePaginator + + def _do_create(self, beneficiary: User, serializer: BaseSerializer[ShiftSwapRequest]) -> None: + shift_swap_request = serializer.save(beneficiary=beneficiary) + + write_resource_insight_log(instance=shift_swap_request, author=self.request.user, event=EntityEvent.CREATED) + + create_shift_swap_request_message.apply_async((shift_swap_request.pk,)) + + def _do_take(self, benefactor: User) -> dict: + shift_swap = self.get_object() + + try: + shift_swap.take(benefactor) + except exceptions.ShiftSwapRequestNotOpenForTaking: + raise BadRequest(detail="The shift swap request is not in a state which allows it to be taken") + except exceptions.BeneficiaryCannotTakeOwnShiftSwapRequest: + raise BadRequest(detail="A shift swap request cannot be created and taken by the same user") + + return ShiftSwapRequestSerializer(shift_swap).data + + def get_serializer_class(self): + return ShiftSwapRequestListSerializer if self.action == "list" else super().get_serializer_class() + + def get_queryset(self): + queryset = ShiftSwapRequest.objects.filter(schedule__organization=self.request.auth.organization) + return self.serializer_class.setup_eager_loading(queryset) + + def perform_destroy(self, instance: ShiftSwapRequest) -> None: + # TODO: should we allow deleting a taken request? + + super().perform_destroy(instance) + write_resource_insight_log(instance=instance, author=self.request.user, event=EntityEvent.DELETED) + + update_shift_swap_request_message.apply_async((instance.pk,)) + + def perform_create(self, serializer: BaseSerializer[ShiftSwapRequest]) -> None: + # default to create swap request with logged in user as beneficiary + self._do_create(self.request.user, serializer=serializer) + + def perform_update(self, serializer: BaseSerializer[ShiftSwapRequest]) -> None: + prev_state = serializer.instance.insight_logs_serialized + serializer.save() + shift_swap_request = serializer.instance + + write_resource_insight_log( + instance=shift_swap_request, + author=self.request.user, + event=EntityEvent.UPDATED, + prev_state=prev_state, + new_state=shift_swap_request.insight_logs_serialized, + ) + + update_shift_swap_request_message.apply_async((shift_swap_request.pk,)) + + +class ShiftSwapViewSet(PublicPrimaryKeyMixin[ShiftSwapRequest], BaseShiftSwapViewSet): authentication_classes = (MobileAppAuthTokenAuthentication, PluginAuthentication) permission_classes = (IsAuthenticated, RBACPermission) @@ -49,57 +109,7 @@ class ShiftSwapViewSet(PublicPrimaryKeyMixin[ShiftSwapRequest], ModelViewSet): ], } - model = ShiftSwapRequest - serializer_class = ShiftSwapRequestSerializer - pagination_class = FiftyPageSizePaginator - - def get_serializer_class(self): - return ShiftSwapRequestListSerializer if self.action == "list" else super().get_serializer_class() - - def get_queryset(self): - queryset = ShiftSwapRequest.objects.filter(schedule__organization=self.request.auth.organization) - return self.serializer_class.setup_eager_loading(queryset) - - def perform_destroy(self, instance: ShiftSwapRequest) -> None: - # TODO: should we allow deleting a taken request? - - super().perform_destroy(instance) - write_resource_insight_log(instance=instance, author=self.request.user, event=EntityEvent.DELETED) - - update_shift_swap_request_message.apply_async((instance.pk,)) - - def perform_create(self, serializer: BaseSerializer[ShiftSwapRequest]) -> None: - beneficiary = self.request.user - shift_swap_request = serializer.save(beneficiary=beneficiary) - - write_resource_insight_log(instance=shift_swap_request, author=beneficiary, event=EntityEvent.CREATED) - - create_shift_swap_request_message.apply_async((shift_swap_request.pk,)) - - def perform_update(self, serializer: BaseSerializer[ShiftSwapRequest]) -> None: - prev_state = serializer.instance.insight_logs_serialized - serializer.save() - shift_swap_request = serializer.instance - - write_resource_insight_log( - instance=shift_swap_request, - author=self.request.user, - event=EntityEvent.UPDATED, - prev_state=prev_state, - new_state=shift_swap_request.insight_logs_serialized, - ) - - update_shift_swap_request_message.apply_async((shift_swap_request.pk,)) - @action(methods=["post"], detail=True) def take(self, request: AuthenticatedRequest, pk: str) -> Response: - shift_swap = self.get_object() - - try: - shift_swap.take(request.user) - except exceptions.ShiftSwapRequestNotOpenForTaking: - raise BadRequest(detail="The shift swap request is not in a state which allows it to be taken") - except exceptions.BeneficiaryCannotTakeOwnShiftSwapRequest: - raise BadRequest(detail="A shift swap request cannot be created and taken by the same user") - - return Response(ShiftSwapRequestSerializer(shift_swap).data, status=status.HTTP_200_OK) + serialized_shift_swap = self._do_take(benefactor=request.user) + return Response(serialized_shift_swap, status=status.HTTP_200_OK) diff --git a/engine/apps/public_api/tests/test_shift_swap.py b/engine/apps/public_api/tests/test_shift_swap.py new file mode 100644 index 00000000..913144f7 --- /dev/null +++ b/engine/apps/public_api/tests/test_shift_swap.py @@ -0,0 +1,413 @@ +from unittest.mock import patch + +import pytest +from django.urls import reverse +from django.utils import timezone +from rest_framework import status +from rest_framework.test import APIClient + +from apps.schedules.models import CustomOnCallShift, OnCallScheduleWeb, ShiftSwapRequest +from common.api_helpers.utils import serialize_datetime_as_utc_timestamp +from common.insight_log import EntityEvent + + +@pytest.fixture +def setup_swap(make_user_for_organization, make_schedule, make_shift_swap_request): + def _setup_swap(organization, **kwargs): + user = make_user_for_organization(organization) + schedule = make_schedule(organization, schedule_class=OnCallScheduleWeb) + today = timezone.now().replace(hour=0, minute=0, second=0, microsecond=0) + tomorrow = today + timezone.timedelta(days=1) + two_days_from_now = tomorrow + timezone.timedelta(days=1) + + swap = make_shift_swap_request(schedule, user, swap_start=tomorrow, swap_end=two_days_from_now) + return swap + + return _setup_swap + + +def assert_swap_response(response, request_data): + response_data = response.json() + swap = ShiftSwapRequest.objects.get(public_primary_key=response_data["id"]) + # check description + assert swap.description == response_data["description"] + if "description" in request_data: + assert response_data["description"] == request_data["description"] + # check datetime fields + for field in ("swap_start", "swap_end"): + db_value = serialize_datetime_as_utc_timestamp(getattr(swap, field)) + assert db_value == response_data[field] + if field in request_data: + assert db_value == request_data[field] + # check FK fields + for field in ("schedule", "beneficiary", "benefactor"): + value = response_data[field] + if value: + assert getattr(swap, field).public_primary_key == response_data[field] + else: + assert getattr(swap, field) is None + if field in request_data: + assert value == request_data[field] + + +@pytest.mark.django_db +def test_list_filters( + make_organization_and_user_with_token, + make_user_for_organization, + make_schedule, + make_shift_swap_request, +): + organization, user, token = make_organization_and_user_with_token() + user2 = make_user_for_organization(organization) + + schedule1 = make_schedule(organization, schedule_class=OnCallScheduleWeb) + schedule2 = make_schedule(organization, schedule_class=OnCallScheduleWeb) + + today = timezone.now().replace(hour=0, minute=0, second=0, microsecond=0) + yesterday = today - timezone.timedelta(days=1) + tomorrow = today + timezone.timedelta(days=1) + two_days_from_now = tomorrow + timezone.timedelta(days=1) + + # open + swap1 = make_shift_swap_request(schedule1, user, swap_start=tomorrow, swap_end=two_days_from_now) + # past due + swap2 = make_shift_swap_request(schedule1, user2, swap_start=yesterday, swap_end=today) + # past due / in-progress + swap3 = make_shift_swap_request(schedule2, user2, swap_start=today, swap_end=tomorrow) + # taken + swap4 = make_shift_swap_request(schedule2, user2, swap_start=tomorrow, swap_end=two_days_from_now, benefactor=user) + + def assert_expected(response, expected): + assert response.status_code == status.HTTP_200_OK + returned = [s["id"] for s in response.json().get("results", [])] + assert returned == [s.public_primary_key for s in expected] + + client = APIClient() + base_url = reverse("api-public:shift_swap-list") + + url = base_url + response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}") + assert response.status_code == status.HTTP_200_OK + assert_expected(response, (swap1, swap4)) + + url = base_url + f"?schedule_id={schedule1.public_primary_key}" + response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}") + assert response.status_code == status.HTTP_200_OK + assert_expected(response, (swap1,)) + + url = base_url + "?open_only=true" + response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}") + assert response.status_code == status.HTTP_200_OK + assert_expected(response, (swap1,)) + + starting_after = serialize_datetime_as_utc_timestamp(yesterday) + url = base_url + f"?beneficiary={user2.public_primary_key}&starting_after={starting_after}" + response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}") + assert response.status_code == status.HTTP_200_OK + assert_expected(response, (swap2, swap3, swap4)) + + url = base_url + f"?benefactor={user.public_primary_key}" + response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}") + assert response.status_code == status.HTTP_200_OK + assert_expected(response, (swap4,)) + + +@patch("apps.api.views.shift_swap.write_resource_insight_log") +@patch("apps.api.views.shift_swap.create_shift_swap_request_message") +@pytest.mark.django_db +def test_create( + mock_create_shift_swap_request_message, + mock_write_resource_insight_log, + make_organization_and_user_with_token, + make_user_for_organization, + make_schedule, +): + organization, user, token = make_organization_and_user_with_token() + another_user = make_user_for_organization(organization) + schedule = make_schedule(organization, schedule_class=OnCallScheduleWeb) + today = timezone.now().replace(hour=0, minute=0, second=0, microsecond=0) + tomorrow = today + timezone.timedelta(days=1) + two_days_from_now = tomorrow + timezone.timedelta(days=1) + + data = { + "schedule": schedule.public_primary_key, + "description": "Taking a few days off", + "swap_start": serialize_datetime_as_utc_timestamp(tomorrow), + "swap_end": serialize_datetime_as_utc_timestamp(two_days_from_now), + "beneficiary": another_user.public_primary_key, + } + + client = APIClient() + url = reverse("api-public:shift_swap-list") + response = client.post(url, data, format="json", HTTP_AUTHORIZATION=f"{token}") + + assert response.status_code == status.HTTP_201_CREATED + assert_swap_response(response, data) + + ssr = ShiftSwapRequest.objects.get(public_primary_key=response.json()["id"]) + mock_write_resource_insight_log.assert_called_once_with(instance=ssr, author=user, event=EntityEvent.CREATED) + mock_create_shift_swap_request_message.apply_async.assert_called_once_with((ssr.pk,)) + + +@pytest.mark.django_db +def test_create_requires_beneficiary( + make_organization_and_user_with_token, + make_schedule, +): + organization, user, token = make_organization_and_user_with_token() + + schedule = make_schedule(organization, schedule_class=OnCallScheduleWeb) + today = timezone.now().replace(hour=0, minute=0, second=0, microsecond=0) + tomorrow = today + timezone.timedelta(days=1) + two_days_from_now = tomorrow + timezone.timedelta(days=1) + + data = { + "schedule": schedule.public_primary_key, + "description": "Taking a few days off", + "swap_start": serialize_datetime_as_utc_timestamp(tomorrow), + "swap_end": serialize_datetime_as_utc_timestamp(two_days_from_now), + } + + client = APIClient() + url = reverse("api-public:shift_swap-list") + response = client.post(url, data, format="json", HTTP_AUTHORIZATION=f"{token}") + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert ShiftSwapRequest.objects.count() == 0 + + +@patch("apps.api.views.shift_swap.write_resource_insight_log") +@patch("apps.api.views.shift_swap.update_shift_swap_request_message") +@pytest.mark.django_db +def test_update( + mock_update_shift_swap_request_message, + mock_write_resource_insight_log, + make_organization_and_user_with_token, + setup_swap, +): + organization, user, token = make_organization_and_user_with_token() + swap = setup_swap(organization) + assert swap.description is None + insights_log_prev_state = swap.insight_logs_serialized + + data = { + "description": "Taking a few days off", + "schedule": swap.schedule.public_primary_key, + "swap_start": serialize_datetime_as_utc_timestamp(swap.swap_start), + "swap_end": serialize_datetime_as_utc_timestamp(swap.swap_end), + } + + client = APIClient() + url = reverse("api-public:shift_swap-detail", kwargs={"pk": swap.public_primary_key}) + response = client.put(url, data, format="json", HTTP_AUTHORIZATION=f"{token}") + + assert response.status_code == status.HTTP_200_OK + assert_swap_response(response, data) + + swap.refresh_from_db() + mock_write_resource_insight_log.assert_called_once_with( + instance=swap, + author=user, + event=EntityEvent.UPDATED, + prev_state=insights_log_prev_state, + new_state=swap.insight_logs_serialized, + ) + mock_update_shift_swap_request_message.apply_async.assert_called_once_with((swap.pk,)) + + +@patch("apps.api.views.shift_swap.write_resource_insight_log") +@patch("apps.api.views.shift_swap.update_shift_swap_request_message") +@pytest.mark.django_db +def test_partial_update( + mock_update_shift_swap_request_message, + mock_write_resource_insight_log, + make_organization_and_user_with_token, + setup_swap, +): + organization, user, token = make_organization_and_user_with_token() + swap = setup_swap(organization) + assert swap.description is None + insights_log_prev_state = swap.insight_logs_serialized + + data = {"description": "Taking a few days off"} + + client = APIClient() + url = reverse("api-public:shift_swap-detail", kwargs={"pk": swap.public_primary_key}) + response = client.patch(url, data, format="json", HTTP_AUTHORIZATION=f"{token}") + + assert response.status_code == status.HTTP_200_OK + assert_swap_response(response, data) + + swap.refresh_from_db() + mock_write_resource_insight_log.assert_called_once_with( + instance=swap, + author=user, + event=EntityEvent.UPDATED, + prev_state=insights_log_prev_state, + new_state=swap.insight_logs_serialized, + ) + mock_update_shift_swap_request_message.apply_async.assert_called_once_with((swap.pk,)) + + +@pytest.mark.django_db +def test_details( + make_organization_and_user_with_token, + make_on_call_shift, + setup_swap, +): + organization, _, token = make_organization_and_user_with_token() + swap = setup_swap(organization) + schedule = swap.schedule + user = swap.beneficiary + + today = timezone.now().replace(hour=0, minute=0, second=0, microsecond=0) + start = today + timezone.timedelta(days=1) + duration = timezone.timedelta(hours=8) + data = { + "start": start, + "rotation_start": start, + "duration": duration, + "priority_level": 1, + "frequency": CustomOnCallShift.FREQUENCY_DAILY, + "schedule": schedule, + } + on_call_shift = make_on_call_shift( + organization=organization, shift_type=CustomOnCallShift.TYPE_ROLLING_USERS_EVENT, **data + ) + on_call_shift.add_rolling_users([[user]]) + + client = APIClient() + url = reverse("api-public:shift_swap-detail", kwargs={"pk": swap.public_primary_key}) + response = client.get(url, HTTP_AUTHORIZATION=f"{token}") + + assert response.status_code == status.HTTP_200_OK + assert_swap_response(response, {}) + + # include involved shifts information + shifts_data = response.json()["shifts"] + assert len(shifts_data) == 1 + expected = [ + # start, end, user, swap request ID + ( + start.strftime("%Y-%m-%dT%H:%M:%SZ"), + (start + duration).strftime("%Y-%m-%dT%H:%M:%SZ"), + user.public_primary_key, + swap.public_primary_key, + ), + ] + returned_events = [ + (e["start"], e["end"], e["users"][0]["pk"], e["users"][0]["swap_request"]["pk"]) for e in shifts_data + ] + assert returned_events == expected + + +@patch("apps.api.views.shift_swap.write_resource_insight_log") +@patch("apps.api.views.shift_swap.update_shift_swap_request_message") +@pytest.mark.django_db +def test_delete( + mock_update_shift_swap_request_message, + mock_write_resource_insight_log, + make_organization_and_user_with_token, + setup_swap, +): + organization, user, token = make_organization_and_user_with_token() + swap = setup_swap(organization) + + client = APIClient() + url = reverse("api-public:shift_swap-detail", kwargs={"pk": swap.public_primary_key}) + + response = client.delete(url, HTTP_AUTHORIZATION=f"{token}") + assert response.status_code == status.HTTP_204_NO_CONTENT + + response = client.get(url, HTTP_AUTHORIZATION=f"{token}") + assert response.status_code == status.HTTP_404_NOT_FOUND + + mock_write_resource_insight_log.assert_called_once_with( + instance=swap, + author=user, + event=EntityEvent.DELETED, + ) + mock_update_shift_swap_request_message.apply_async.assert_called_once_with((swap.pk,)) + + +@pytest.mark.django_db +def test_take( + make_organization_and_user_with_token, + make_user_for_organization, + setup_swap, +): + organization, user, token = make_organization_and_user_with_token() + another_user = make_user_for_organization(organization) + swap = setup_swap(organization) + + client = APIClient() + url = reverse("api-public:shift_swap-take", kwargs={"pk": swap.public_primary_key}) + + data = {"benefactor": another_user.public_primary_key} + response = client.post(url, data, format="json", HTTP_AUTHORIZATION=f"{token}") + assert response.status_code == status.HTTP_200_OK + + assert_swap_response(response, data) + swap.refresh_from_db() + assert swap.status == ShiftSwapRequest.Statuses.TAKEN + assert swap.benefactor == another_user + + +@pytest.mark.django_db +def test_take_requires_benefactor( + make_organization_and_user_with_token, + setup_swap, +): + organization, user, token = make_organization_and_user_with_token() + swap = setup_swap(organization) + + client = APIClient() + url = reverse("api-public:shift_swap-take", kwargs={"pk": swap.public_primary_key}) + + data = {} + response = client.post(url, data, format="json", HTTP_AUTHORIZATION=f"{token}") + assert response.status_code == status.HTTP_400_BAD_REQUEST + + swap.refresh_from_db() + assert swap.status == ShiftSwapRequest.Statuses.OPEN + assert swap.benefactor is None + + +@pytest.mark.django_db +def test_take_errors( + make_organization_and_user_with_token, + make_user_for_organization, + setup_swap, +): + organization, user, token = make_organization_and_user_with_token() + another_user = make_user_for_organization(organization) + swap = setup_swap(organization) + + client = APIClient() + url = reverse("api-public:shift_swap-take", kwargs={"pk": swap.public_primary_key}) + + # same user taking the swap + data = {"benefactor": swap.beneficiary.public_primary_key} + response = client.post(url, data, format="json", HTTP_AUTHORIZATION=f"{token}") + assert response.status_code == status.HTTP_400_BAD_REQUEST + + # already taken + swap.take(another_user) + data = {"benefactor": another_user.public_primary_key} + response = client.post(url, data, format="json", HTTP_AUTHORIZATION=f"{token}") + assert response.status_code == status.HTTP_400_BAD_REQUEST + + # deleted + swap = setup_swap(organization) + swap.delete() + data = {"benefactor": another_user.public_primary_key} + response = client.post(url, data, format="json", HTTP_AUTHORIZATION=f"{token}") + assert response.status_code == status.HTTP_400_BAD_REQUEST + + # past due + swap = setup_swap(organization) + swap.swap_start = timezone.now() - timezone.timedelta(days=2) + swap.save() + data = {"benefactor": another_user.public_primary_key} + response = client.post(url, data, format="json", HTTP_AUTHORIZATION=f"{token}") + assert response.status_code == status.HTTP_400_BAD_REQUEST diff --git a/engine/apps/public_api/urls.py b/engine/apps/public_api/urls.py index a91898df..c12369b9 100644 --- a/engine/apps/public_api/urls.py +++ b/engine/apps/public_api/urls.py @@ -25,6 +25,7 @@ router.register(r"actions", views.ActionView, basename="actions") router.register(r"user_groups", views.UserGroupView, basename="user_groups") router.register(r"on_call_shifts", views.CustomOnCallShiftView, basename="on_call_shifts") router.register(r"teams", views.TeamView, basename="teams") +router.register(r"shift_swaps", views.ShiftSwapViewSet, basename="shift_swap") urlpatterns = [ diff --git a/engine/apps/public_api/views/__init__.py b/engine/apps/public_api/views/__init__.py index 4ffcec04..4a8af434 100644 --- a/engine/apps/public_api/views/__init__.py +++ b/engine/apps/public_api/views/__init__.py @@ -12,6 +12,7 @@ from .phone_notifications import MakeCallView, SendSMSView # noqa: F401 from .resolution_notes import ResolutionNoteView # noqa: F401 from .routes import ChannelFilterView # noqa: F401 from .schedules import OnCallScheduleChannelView # noqa: F401 +from .shift_swap import ShiftSwapViewSet # noqa: F401 from .slack_channels import SlackChannelView # noqa: F401 from .teams import TeamView # noqa: F401 from .user_groups import UserGroupView # noqa: F401 diff --git a/engine/apps/public_api/views/shift_swap.py b/engine/apps/public_api/views/shift_swap.py new file mode 100644 index 00000000..73c725b1 --- /dev/null +++ b/engine/apps/public_api/views/shift_swap.py @@ -0,0 +1,97 @@ +import logging + +from django.utils import timezone +from rest_framework import status +from rest_framework.decorators import action +from rest_framework.exceptions import NotFound +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from rest_framework.serializers import BaseSerializer + +from apps.api.permissions import AuthenticatedRequest +from apps.api.views.shift_swap import BaseShiftSwapViewSet +from apps.auth_token.auth import ApiTokenAuthentication +from apps.public_api.throttlers.user_throttle import UserThrottle +from apps.schedules.models import ShiftSwapRequest +from apps.user_management.models import User +from common.api_helpers.custom_fields import TimeZoneAwareDatetimeField +from common.api_helpers.exceptions import BadRequest +from common.api_helpers.mixins import RateLimitHeadersMixin + +logger = logging.getLogger(__name__) + + +class ShiftSwapViewSet(RateLimitHeadersMixin, BaseShiftSwapViewSet): + # set authentication and permission classes + authentication_classes = (ApiTokenAuthentication,) + permission_classes = (IsAuthenticated,) + + # public API customizations + throttle_classes = [UserThrottle] + + def get_queryset(self): + schedule_id = self.request.query_params.get("schedule_id", None) + beneficiary = self.request.query_params.get("beneficiary", None) + benefactor = self.request.query_params.get("benefactor", None) + starting_after = self.request.query_params.get("starting_after", None) + open_only = self.request.query_params.get("open_only", "false") == "true" + + now = timezone.now() + if starting_after: + f = TimeZoneAwareDatetimeField() + # trigger datetime format validation + # will raise ValidationError if invalid timestamp is provided + starting_after = f.to_internal_value(starting_after) + else: + starting_after = now + + # base queryset filters by organization + queryset = super().get_queryset() + queryset = queryset.filter(swap_start__gte=starting_after) + + if schedule_id: + queryset = queryset.filter(schedule__public_primary_key=schedule_id) + + if beneficiary: + queryset = queryset.filter(beneficiary__public_primary_key=beneficiary) + + if benefactor: + queryset = queryset.filter(benefactor__public_primary_key=benefactor) + + if benefactor: + queryset = queryset.filter(benefactor__public_primary_key=benefactor) + + if open_only: + queryset = queryset.filter(benefactor__isnull=True, deleted_at__isnull=True, swap_start__gt=now) + + return queryset.order_by("swap_start") + + def get_object(self): + public_primary_key = self.kwargs["pk"] + try: + return self.get_queryset().get(public_primary_key=public_primary_key) + except ShiftSwapRequest.DoesNotExist: + raise NotFound + + def _get_user(self, field_name: str): + """Require and return user from ID given by field_name.""" + user_pk = self.request.data.pop(field_name, None) + if not user_pk: + raise BadRequest(detail=f"{field_name} user ID is required") + try: + user = User.objects.get(organization=self.request.auth.organization, public_primary_key=user_pk) + except User.DoesNotExist: + raise BadRequest(detail=f"Invalid {field_name} user ID") + return user + + def perform_create(self, serializer: BaseSerializer[ShiftSwapRequest]) -> None: + beneficiary = self._get_user("beneficiary") + self._do_create(beneficiary=beneficiary, serializer=serializer) + + @action(methods=["post"], detail=True) + def take(self, request: AuthenticatedRequest, pk: str) -> Response: + # check the swap request exists and it's accessible + self.get_object() + benefactor = self._get_user("benefactor") + serialized_shift_swap = self._do_take(benefactor=benefactor) + return Response(serialized_shift_swap, status=status.HTTP_200_OK)