Add shift swaps endpoints to public API (#2775)

Related to https://github.com/grafana/oncall/issues/2678
This commit is contained in:
Matias Bordese 2023-08-14 09:26:21 -03:00 committed by GitHub
parent a9bf4f5521
commit 5e5eecc480
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 579 additions and 53 deletions

View file

@ -14,6 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix issue with updating "Require resolution note" setting by @Ferril ([#2782](https://github.com/grafana/oncall/pull/2782))
- Don't send notifications about past SSRs when turning on info notifications by @vadimkerr ([#2783](https://github.com/grafana/oncall/pull/2783))
### Added
- Shift swap requests public API ([#2775](https://github.com/grafana/oncall/pull/2775))
## v1.3.23 (2023-08-10)
### Added

View file

@ -14,6 +14,7 @@ from apps.mobile_app.auth import MobileAppAuthTokenAuthentication
from apps.schedules import exceptions
from apps.schedules.models import ShiftSwapRequest
from apps.schedules.tasks.shift_swaps import create_shift_swap_request_message, update_shift_swap_request_message
from apps.user_management.models import User
from common.api_helpers.exceptions import BadRequest
from common.api_helpers.mixins import PublicPrimaryKeyMixin
from common.api_helpers.paginators import FiftyPageSizePaginator
@ -22,7 +23,66 @@ from common.insight_log import EntityEvent, write_resource_insight_log
logger = logging.getLogger(__name__)
class ShiftSwapViewSet(PublicPrimaryKeyMixin[ShiftSwapRequest], ModelViewSet):
class BaseShiftSwapViewSet(ModelViewSet):
model = ShiftSwapRequest
serializer_class = ShiftSwapRequestSerializer
pagination_class = FiftyPageSizePaginator
def _do_create(self, beneficiary: User, serializer: BaseSerializer[ShiftSwapRequest]) -> None:
shift_swap_request = serializer.save(beneficiary=beneficiary)
write_resource_insight_log(instance=shift_swap_request, author=self.request.user, event=EntityEvent.CREATED)
create_shift_swap_request_message.apply_async((shift_swap_request.pk,))
def _do_take(self, benefactor: User) -> dict:
shift_swap = self.get_object()
try:
shift_swap.take(benefactor)
except exceptions.ShiftSwapRequestNotOpenForTaking:
raise BadRequest(detail="The shift swap request is not in a state which allows it to be taken")
except exceptions.BeneficiaryCannotTakeOwnShiftSwapRequest:
raise BadRequest(detail="A shift swap request cannot be created and taken by the same user")
return ShiftSwapRequestSerializer(shift_swap).data
def get_serializer_class(self):
return ShiftSwapRequestListSerializer if self.action == "list" else super().get_serializer_class()
def get_queryset(self):
queryset = ShiftSwapRequest.objects.filter(schedule__organization=self.request.auth.organization)
return self.serializer_class.setup_eager_loading(queryset)
def perform_destroy(self, instance: ShiftSwapRequest) -> None:
# TODO: should we allow deleting a taken request?
super().perform_destroy(instance)
write_resource_insight_log(instance=instance, author=self.request.user, event=EntityEvent.DELETED)
update_shift_swap_request_message.apply_async((instance.pk,))
def perform_create(self, serializer: BaseSerializer[ShiftSwapRequest]) -> None:
# default to create swap request with logged in user as beneficiary
self._do_create(self.request.user, serializer=serializer)
def perform_update(self, serializer: BaseSerializer[ShiftSwapRequest]) -> None:
prev_state = serializer.instance.insight_logs_serialized
serializer.save()
shift_swap_request = serializer.instance
write_resource_insight_log(
instance=shift_swap_request,
author=self.request.user,
event=EntityEvent.UPDATED,
prev_state=prev_state,
new_state=shift_swap_request.insight_logs_serialized,
)
update_shift_swap_request_message.apply_async((shift_swap_request.pk,))
class ShiftSwapViewSet(PublicPrimaryKeyMixin[ShiftSwapRequest], BaseShiftSwapViewSet):
authentication_classes = (MobileAppAuthTokenAuthentication, PluginAuthentication)
permission_classes = (IsAuthenticated, RBACPermission)
@ -49,57 +109,7 @@ class ShiftSwapViewSet(PublicPrimaryKeyMixin[ShiftSwapRequest], ModelViewSet):
],
}
model = ShiftSwapRequest
serializer_class = ShiftSwapRequestSerializer
pagination_class = FiftyPageSizePaginator
def get_serializer_class(self):
return ShiftSwapRequestListSerializer if self.action == "list" else super().get_serializer_class()
def get_queryset(self):
queryset = ShiftSwapRequest.objects.filter(schedule__organization=self.request.auth.organization)
return self.serializer_class.setup_eager_loading(queryset)
def perform_destroy(self, instance: ShiftSwapRequest) -> None:
# TODO: should we allow deleting a taken request?
super().perform_destroy(instance)
write_resource_insight_log(instance=instance, author=self.request.user, event=EntityEvent.DELETED)
update_shift_swap_request_message.apply_async((instance.pk,))
def perform_create(self, serializer: BaseSerializer[ShiftSwapRequest]) -> None:
beneficiary = self.request.user
shift_swap_request = serializer.save(beneficiary=beneficiary)
write_resource_insight_log(instance=shift_swap_request, author=beneficiary, event=EntityEvent.CREATED)
create_shift_swap_request_message.apply_async((shift_swap_request.pk,))
def perform_update(self, serializer: BaseSerializer[ShiftSwapRequest]) -> None:
prev_state = serializer.instance.insight_logs_serialized
serializer.save()
shift_swap_request = serializer.instance
write_resource_insight_log(
instance=shift_swap_request,
author=self.request.user,
event=EntityEvent.UPDATED,
prev_state=prev_state,
new_state=shift_swap_request.insight_logs_serialized,
)
update_shift_swap_request_message.apply_async((shift_swap_request.pk,))
@action(methods=["post"], detail=True)
def take(self, request: AuthenticatedRequest, pk: str) -> Response:
shift_swap = self.get_object()
try:
shift_swap.take(request.user)
except exceptions.ShiftSwapRequestNotOpenForTaking:
raise BadRequest(detail="The shift swap request is not in a state which allows it to be taken")
except exceptions.BeneficiaryCannotTakeOwnShiftSwapRequest:
raise BadRequest(detail="A shift swap request cannot be created and taken by the same user")
return Response(ShiftSwapRequestSerializer(shift_swap).data, status=status.HTTP_200_OK)
serialized_shift_swap = self._do_take(benefactor=request.user)
return Response(serialized_shift_swap, status=status.HTTP_200_OK)

View file

@ -0,0 +1,413 @@
from unittest.mock import patch
import pytest
from django.urls import reverse
from django.utils import timezone
from rest_framework import status
from rest_framework.test import APIClient
from apps.schedules.models import CustomOnCallShift, OnCallScheduleWeb, ShiftSwapRequest
from common.api_helpers.utils import serialize_datetime_as_utc_timestamp
from common.insight_log import EntityEvent
@pytest.fixture
def setup_swap(make_user_for_organization, make_schedule, make_shift_swap_request):
def _setup_swap(organization, **kwargs):
user = make_user_for_organization(organization)
schedule = make_schedule(organization, schedule_class=OnCallScheduleWeb)
today = timezone.now().replace(hour=0, minute=0, second=0, microsecond=0)
tomorrow = today + timezone.timedelta(days=1)
two_days_from_now = tomorrow + timezone.timedelta(days=1)
swap = make_shift_swap_request(schedule, user, swap_start=tomorrow, swap_end=two_days_from_now)
return swap
return _setup_swap
def assert_swap_response(response, request_data):
response_data = response.json()
swap = ShiftSwapRequest.objects.get(public_primary_key=response_data["id"])
# check description
assert swap.description == response_data["description"]
if "description" in request_data:
assert response_data["description"] == request_data["description"]
# check datetime fields
for field in ("swap_start", "swap_end"):
db_value = serialize_datetime_as_utc_timestamp(getattr(swap, field))
assert db_value == response_data[field]
if field in request_data:
assert db_value == request_data[field]
# check FK fields
for field in ("schedule", "beneficiary", "benefactor"):
value = response_data[field]
if value:
assert getattr(swap, field).public_primary_key == response_data[field]
else:
assert getattr(swap, field) is None
if field in request_data:
assert value == request_data[field]
@pytest.mark.django_db
def test_list_filters(
make_organization_and_user_with_token,
make_user_for_organization,
make_schedule,
make_shift_swap_request,
):
organization, user, token = make_organization_and_user_with_token()
user2 = make_user_for_organization(organization)
schedule1 = make_schedule(organization, schedule_class=OnCallScheduleWeb)
schedule2 = make_schedule(organization, schedule_class=OnCallScheduleWeb)
today = timezone.now().replace(hour=0, minute=0, second=0, microsecond=0)
yesterday = today - timezone.timedelta(days=1)
tomorrow = today + timezone.timedelta(days=1)
two_days_from_now = tomorrow + timezone.timedelta(days=1)
# open
swap1 = make_shift_swap_request(schedule1, user, swap_start=tomorrow, swap_end=two_days_from_now)
# past due
swap2 = make_shift_swap_request(schedule1, user2, swap_start=yesterday, swap_end=today)
# past due / in-progress
swap3 = make_shift_swap_request(schedule2, user2, swap_start=today, swap_end=tomorrow)
# taken
swap4 = make_shift_swap_request(schedule2, user2, swap_start=tomorrow, swap_end=two_days_from_now, benefactor=user)
def assert_expected(response, expected):
assert response.status_code == status.HTTP_200_OK
returned = [s["id"] for s in response.json().get("results", [])]
assert returned == [s.public_primary_key for s in expected]
client = APIClient()
base_url = reverse("api-public:shift_swap-list")
url = base_url
response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_200_OK
assert_expected(response, (swap1, swap4))
url = base_url + f"?schedule_id={schedule1.public_primary_key}"
response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_200_OK
assert_expected(response, (swap1,))
url = base_url + "?open_only=true"
response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_200_OK
assert_expected(response, (swap1,))
starting_after = serialize_datetime_as_utc_timestamp(yesterday)
url = base_url + f"?beneficiary={user2.public_primary_key}&starting_after={starting_after}"
response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_200_OK
assert_expected(response, (swap2, swap3, swap4))
url = base_url + f"?benefactor={user.public_primary_key}"
response = client.get(url, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_200_OK
assert_expected(response, (swap4,))
@patch("apps.api.views.shift_swap.write_resource_insight_log")
@patch("apps.api.views.shift_swap.create_shift_swap_request_message")
@pytest.mark.django_db
def test_create(
mock_create_shift_swap_request_message,
mock_write_resource_insight_log,
make_organization_and_user_with_token,
make_user_for_organization,
make_schedule,
):
organization, user, token = make_organization_and_user_with_token()
another_user = make_user_for_organization(organization)
schedule = make_schedule(organization, schedule_class=OnCallScheduleWeb)
today = timezone.now().replace(hour=0, minute=0, second=0, microsecond=0)
tomorrow = today + timezone.timedelta(days=1)
two_days_from_now = tomorrow + timezone.timedelta(days=1)
data = {
"schedule": schedule.public_primary_key,
"description": "Taking a few days off",
"swap_start": serialize_datetime_as_utc_timestamp(tomorrow),
"swap_end": serialize_datetime_as_utc_timestamp(two_days_from_now),
"beneficiary": another_user.public_primary_key,
}
client = APIClient()
url = reverse("api-public:shift_swap-list")
response = client.post(url, data, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_201_CREATED
assert_swap_response(response, data)
ssr = ShiftSwapRequest.objects.get(public_primary_key=response.json()["id"])
mock_write_resource_insight_log.assert_called_once_with(instance=ssr, author=user, event=EntityEvent.CREATED)
mock_create_shift_swap_request_message.apply_async.assert_called_once_with((ssr.pk,))
@pytest.mark.django_db
def test_create_requires_beneficiary(
make_organization_and_user_with_token,
make_schedule,
):
organization, user, token = make_organization_and_user_with_token()
schedule = make_schedule(organization, schedule_class=OnCallScheduleWeb)
today = timezone.now().replace(hour=0, minute=0, second=0, microsecond=0)
tomorrow = today + timezone.timedelta(days=1)
two_days_from_now = tomorrow + timezone.timedelta(days=1)
data = {
"schedule": schedule.public_primary_key,
"description": "Taking a few days off",
"swap_start": serialize_datetime_as_utc_timestamp(tomorrow),
"swap_end": serialize_datetime_as_utc_timestamp(two_days_from_now),
}
client = APIClient()
url = reverse("api-public:shift_swap-list")
response = client.post(url, data, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert ShiftSwapRequest.objects.count() == 0
@patch("apps.api.views.shift_swap.write_resource_insight_log")
@patch("apps.api.views.shift_swap.update_shift_swap_request_message")
@pytest.mark.django_db
def test_update(
mock_update_shift_swap_request_message,
mock_write_resource_insight_log,
make_organization_and_user_with_token,
setup_swap,
):
organization, user, token = make_organization_and_user_with_token()
swap = setup_swap(organization)
assert swap.description is None
insights_log_prev_state = swap.insight_logs_serialized
data = {
"description": "Taking a few days off",
"schedule": swap.schedule.public_primary_key,
"swap_start": serialize_datetime_as_utc_timestamp(swap.swap_start),
"swap_end": serialize_datetime_as_utc_timestamp(swap.swap_end),
}
client = APIClient()
url = reverse("api-public:shift_swap-detail", kwargs={"pk": swap.public_primary_key})
response = client.put(url, data, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_200_OK
assert_swap_response(response, data)
swap.refresh_from_db()
mock_write_resource_insight_log.assert_called_once_with(
instance=swap,
author=user,
event=EntityEvent.UPDATED,
prev_state=insights_log_prev_state,
new_state=swap.insight_logs_serialized,
)
mock_update_shift_swap_request_message.apply_async.assert_called_once_with((swap.pk,))
@patch("apps.api.views.shift_swap.write_resource_insight_log")
@patch("apps.api.views.shift_swap.update_shift_swap_request_message")
@pytest.mark.django_db
def test_partial_update(
mock_update_shift_swap_request_message,
mock_write_resource_insight_log,
make_organization_and_user_with_token,
setup_swap,
):
organization, user, token = make_organization_and_user_with_token()
swap = setup_swap(organization)
assert swap.description is None
insights_log_prev_state = swap.insight_logs_serialized
data = {"description": "Taking a few days off"}
client = APIClient()
url = reverse("api-public:shift_swap-detail", kwargs={"pk": swap.public_primary_key})
response = client.patch(url, data, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_200_OK
assert_swap_response(response, data)
swap.refresh_from_db()
mock_write_resource_insight_log.assert_called_once_with(
instance=swap,
author=user,
event=EntityEvent.UPDATED,
prev_state=insights_log_prev_state,
new_state=swap.insight_logs_serialized,
)
mock_update_shift_swap_request_message.apply_async.assert_called_once_with((swap.pk,))
@pytest.mark.django_db
def test_details(
make_organization_and_user_with_token,
make_on_call_shift,
setup_swap,
):
organization, _, token = make_organization_and_user_with_token()
swap = setup_swap(organization)
schedule = swap.schedule
user = swap.beneficiary
today = timezone.now().replace(hour=0, minute=0, second=0, microsecond=0)
start = today + timezone.timedelta(days=1)
duration = timezone.timedelta(hours=8)
data = {
"start": start,
"rotation_start": start,
"duration": duration,
"priority_level": 1,
"frequency": CustomOnCallShift.FREQUENCY_DAILY,
"schedule": schedule,
}
on_call_shift = make_on_call_shift(
organization=organization, shift_type=CustomOnCallShift.TYPE_ROLLING_USERS_EVENT, **data
)
on_call_shift.add_rolling_users([[user]])
client = APIClient()
url = reverse("api-public:shift_swap-detail", kwargs={"pk": swap.public_primary_key})
response = client.get(url, HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_200_OK
assert_swap_response(response, {})
# include involved shifts information
shifts_data = response.json()["shifts"]
assert len(shifts_data) == 1
expected = [
# start, end, user, swap request ID
(
start.strftime("%Y-%m-%dT%H:%M:%SZ"),
(start + duration).strftime("%Y-%m-%dT%H:%M:%SZ"),
user.public_primary_key,
swap.public_primary_key,
),
]
returned_events = [
(e["start"], e["end"], e["users"][0]["pk"], e["users"][0]["swap_request"]["pk"]) for e in shifts_data
]
assert returned_events == expected
@patch("apps.api.views.shift_swap.write_resource_insight_log")
@patch("apps.api.views.shift_swap.update_shift_swap_request_message")
@pytest.mark.django_db
def test_delete(
mock_update_shift_swap_request_message,
mock_write_resource_insight_log,
make_organization_and_user_with_token,
setup_swap,
):
organization, user, token = make_organization_and_user_with_token()
swap = setup_swap(organization)
client = APIClient()
url = reverse("api-public:shift_swap-detail", kwargs={"pk": swap.public_primary_key})
response = client.delete(url, HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_204_NO_CONTENT
response = client.get(url, HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_404_NOT_FOUND
mock_write_resource_insight_log.assert_called_once_with(
instance=swap,
author=user,
event=EntityEvent.DELETED,
)
mock_update_shift_swap_request_message.apply_async.assert_called_once_with((swap.pk,))
@pytest.mark.django_db
def test_take(
make_organization_and_user_with_token,
make_user_for_organization,
setup_swap,
):
organization, user, token = make_organization_and_user_with_token()
another_user = make_user_for_organization(organization)
swap = setup_swap(organization)
client = APIClient()
url = reverse("api-public:shift_swap-take", kwargs={"pk": swap.public_primary_key})
data = {"benefactor": another_user.public_primary_key}
response = client.post(url, data, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_200_OK
assert_swap_response(response, data)
swap.refresh_from_db()
assert swap.status == ShiftSwapRequest.Statuses.TAKEN
assert swap.benefactor == another_user
@pytest.mark.django_db
def test_take_requires_benefactor(
make_organization_and_user_with_token,
setup_swap,
):
organization, user, token = make_organization_and_user_with_token()
swap = setup_swap(organization)
client = APIClient()
url = reverse("api-public:shift_swap-take", kwargs={"pk": swap.public_primary_key})
data = {}
response = client.post(url, data, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_400_BAD_REQUEST
swap.refresh_from_db()
assert swap.status == ShiftSwapRequest.Statuses.OPEN
assert swap.benefactor is None
@pytest.mark.django_db
def test_take_errors(
make_organization_and_user_with_token,
make_user_for_organization,
setup_swap,
):
organization, user, token = make_organization_and_user_with_token()
another_user = make_user_for_organization(organization)
swap = setup_swap(organization)
client = APIClient()
url = reverse("api-public:shift_swap-take", kwargs={"pk": swap.public_primary_key})
# same user taking the swap
data = {"benefactor": swap.beneficiary.public_primary_key}
response = client.post(url, data, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_400_BAD_REQUEST
# already taken
swap.take(another_user)
data = {"benefactor": another_user.public_primary_key}
response = client.post(url, data, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_400_BAD_REQUEST
# deleted
swap = setup_swap(organization)
swap.delete()
data = {"benefactor": another_user.public_primary_key}
response = client.post(url, data, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_400_BAD_REQUEST
# past due
swap = setup_swap(organization)
swap.swap_start = timezone.now() - timezone.timedelta(days=2)
swap.save()
data = {"benefactor": another_user.public_primary_key}
response = client.post(url, data, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_400_BAD_REQUEST

View file

@ -25,6 +25,7 @@ router.register(r"actions", views.ActionView, basename="actions")
router.register(r"user_groups", views.UserGroupView, basename="user_groups")
router.register(r"on_call_shifts", views.CustomOnCallShiftView, basename="on_call_shifts")
router.register(r"teams", views.TeamView, basename="teams")
router.register(r"shift_swaps", views.ShiftSwapViewSet, basename="shift_swap")
urlpatterns = [

View file

@ -12,6 +12,7 @@ from .phone_notifications import MakeCallView, SendSMSView # noqa: F401
from .resolution_notes import ResolutionNoteView # noqa: F401
from .routes import ChannelFilterView # noqa: F401
from .schedules import OnCallScheduleChannelView # noqa: F401
from .shift_swap import ShiftSwapViewSet # noqa: F401
from .slack_channels import SlackChannelView # noqa: F401
from .teams import TeamView # noqa: F401
from .user_groups import UserGroupView # noqa: F401

View file

@ -0,0 +1,97 @@
import logging
from django.utils import timezone
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.exceptions import NotFound
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.serializers import BaseSerializer
from apps.api.permissions import AuthenticatedRequest
from apps.api.views.shift_swap import BaseShiftSwapViewSet
from apps.auth_token.auth import ApiTokenAuthentication
from apps.public_api.throttlers.user_throttle import UserThrottle
from apps.schedules.models import ShiftSwapRequest
from apps.user_management.models import User
from common.api_helpers.custom_fields import TimeZoneAwareDatetimeField
from common.api_helpers.exceptions import BadRequest
from common.api_helpers.mixins import RateLimitHeadersMixin
logger = logging.getLogger(__name__)
class ShiftSwapViewSet(RateLimitHeadersMixin, BaseShiftSwapViewSet):
# set authentication and permission classes
authentication_classes = (ApiTokenAuthentication,)
permission_classes = (IsAuthenticated,)
# public API customizations
throttle_classes = [UserThrottle]
def get_queryset(self):
schedule_id = self.request.query_params.get("schedule_id", None)
beneficiary = self.request.query_params.get("beneficiary", None)
benefactor = self.request.query_params.get("benefactor", None)
starting_after = self.request.query_params.get("starting_after", None)
open_only = self.request.query_params.get("open_only", "false") == "true"
now = timezone.now()
if starting_after:
f = TimeZoneAwareDatetimeField()
# trigger datetime format validation
# will raise ValidationError if invalid timestamp is provided
starting_after = f.to_internal_value(starting_after)
else:
starting_after = now
# base queryset filters by organization
queryset = super().get_queryset()
queryset = queryset.filter(swap_start__gte=starting_after)
if schedule_id:
queryset = queryset.filter(schedule__public_primary_key=schedule_id)
if beneficiary:
queryset = queryset.filter(beneficiary__public_primary_key=beneficiary)
if benefactor:
queryset = queryset.filter(benefactor__public_primary_key=benefactor)
if benefactor:
queryset = queryset.filter(benefactor__public_primary_key=benefactor)
if open_only:
queryset = queryset.filter(benefactor__isnull=True, deleted_at__isnull=True, swap_start__gt=now)
return queryset.order_by("swap_start")
def get_object(self):
public_primary_key = self.kwargs["pk"]
try:
return self.get_queryset().get(public_primary_key=public_primary_key)
except ShiftSwapRequest.DoesNotExist:
raise NotFound
def _get_user(self, field_name: str):
"""Require and return user from ID given by field_name."""
user_pk = self.request.data.pop(field_name, None)
if not user_pk:
raise BadRequest(detail=f"{field_name} user ID is required")
try:
user = User.objects.get(organization=self.request.auth.organization, public_primary_key=user_pk)
except User.DoesNotExist:
raise BadRequest(detail=f"Invalid {field_name} user ID")
return user
def perform_create(self, serializer: BaseSerializer[ShiftSwapRequest]) -> None:
beneficiary = self._get_user("beneficiary")
self._do_create(beneficiary=beneficiary, serializer=serializer)
@action(methods=["post"], detail=True)
def take(self, request: AuthenticatedRequest, pk: str) -> Response:
# check the swap request exists and it's accessible
self.get_object()
benefactor = self._get_user("benefactor")
serialized_shift_swap = self._do_take(benefactor=benefactor)
return Response(serialized_shift_swap, status=status.HTTP_200_OK)