diff --git a/engine/apps/api/tests/test_oncall_shift.py b/engine/apps/api/tests/test_oncall_shift.py index 7d5ef169..b31aa598 100644 --- a/engine/apps/api/tests/test_oncall_shift.py +++ b/engine/apps/api/tests/test_oncall_shift.py @@ -1406,3 +1406,112 @@ def test_on_call_shift_preview_merge_events( if not e["is_override"] and not e["is_gap"] ] assert returned_events == expected_events + + +@pytest.mark.django_db +def test_on_call_shift_preview_update( + make_organization_and_user_with_plugin_token, + make_user_for_organization, + make_user_auth_headers, + make_schedule, + make_on_call_shift, +): + organization, user, token = make_organization_and_user_with_plugin_token() + client = APIClient() + + schedule = make_schedule( + organization, + schedule_class=OnCallScheduleWeb, + name="test_web_schedule", + ) + + now = timezone.now().replace(hour=0, minute=0, second=0, microsecond=0) + start_date = now - timezone.timedelta(days=7) + request_date = start_date + + user = make_user_for_organization(organization) + other_user = make_user_for_organization(organization) + + data = { + "start": start_date + timezone.timedelta(hours=8), + "rotation_start": start_date + timezone.timedelta(hours=8), + "duration": timezone.timedelta(hours=1), + "priority_level": 1, + "interval": 4, + "frequency": CustomOnCallShift.FREQUENCY_HOURLY, + "schedule": schedule, + } + on_call_shift = make_on_call_shift( + organization=organization, shift_type=CustomOnCallShift.TYPE_ROLLING_USERS_EVENT, **data + ) + on_call_shift.add_rolling_users([[user]]) + + url = "{}?date={}&days={}".format( + reverse("api-internal:oncall_shifts-preview"), request_date.strftime("%Y-%m-%d"), 1 + ) + shift_start = (start_date + timezone.timedelta(hours=10)).strftime("%Y-%m-%dT%H:%M:%SZ") + shift_end = (start_date + timezone.timedelta(hours=18)).strftime("%Y-%m-%dT%H:%M:%SZ") + shift_data = { + "schedule": schedule.public_primary_key, + "shift_pk": on_call_shift.public_primary_key, + "type": CustomOnCallShift.TYPE_ROLLING_USERS_EVENT, + "rotation_start": shift_start, + "shift_start": shift_start, + "shift_end": shift_end, + "rolling_users": [[other_user.public_primary_key]], + "priority_level": 1, + "frequency": CustomOnCallShift.FREQUENCY_DAILY, + } + response = client.post(url, shift_data, format="json", **make_user_auth_headers(user, token)) + assert response.status_code == status.HTTP_200_OK + + # check rotation events + rotation_events = response.json()["rotation"] + expected_rotation_events = [ + { + "calendar_type": OnCallSchedule.TYPE_ICAL_PRIMARY, + "start": shift_start, + "end": shift_end, + "all_day": False, + "is_override": False, + "is_empty": False, + "is_gap": False, + "priority_level": 1, + "missing_users": [], + "users": [{"display_name": other_user.username, "pk": other_user.public_primary_key}], + "source": "web", + } + ] + # there isn't a saved shift, we don't care/know the temp pk + _ = [r.pop("shift") for r in rotation_events] + assert rotation_events == expected_rotation_events + + # check final schedule events + final_events = response.json()["final"] + expected = ( + # start (h), duration (H), user, priority + (8, 1, user.username, 1), # 8-9 user + (10, 8, other_user.username, 1), # 10-18 other_user + ) + expected_events = [ + { + "end": (start_date + timezone.timedelta(hours=start + duration)).strftime("%Y-%m-%dT%H:%M:%SZ"), + "priority_level": priority, + "start": (start_date + timezone.timedelta(hours=start, milliseconds=1 if start == 0 else 0)).strftime( + "%Y-%m-%dT%H:%M:%SZ" + ), + "user": user, + } + for start, duration, user, priority in expected + ] + returned_events = [ + { + "end": e["end"], + "priority_level": e["priority_level"], + "start": e["start"], + "user": e["users"][0]["display_name"] if e["users"] else None, + } + for e in final_events + if not e["is_override"] and not e["is_gap"] + ] + assert returned_events == expected_events diff --git a/engine/apps/api/views/on_call_shifts.py b/engine/apps/api/views/on_call_shifts.py index 5ccce3b4..20853f7c 100644 --- a/engine/apps/api/views/on_call_shifts.py +++ b/engine/apps/api/views/on_call_shifts.py @@ -89,9 +89,12 @@ class OnCallShiftView(PublicPrimaryKeyMixin, UpdateSerializerMixin, ModelViewSet validated_data = serializer._correct_validated_data( serializer.validated_data["type"], serializer.validated_data ) + updated_shift_pk = self.request.data.get("shift_pk") shift = CustomOnCallShift(**validated_data) schedule = shift.schedule - shift_events, final_events = schedule.preview_shift(shift, user_tz, starting_date, days) + shift_events, final_events = schedule.preview_shift( + shift, user_tz, starting_date, days, updated_shift_pk=updated_shift_pk + ) data = { "rotation": shift_events, "final": final_events, diff --git a/engine/apps/schedules/models/on_call_schedule.py b/engine/apps/schedules/models/on_call_schedule.py index 1b9a4b7c..16d8a9e6 100644 --- a/engine/apps/schedules/models/on_call_schedule.py +++ b/engine/apps/schedules/models/on_call_schedule.py @@ -623,7 +623,7 @@ class OnCallScheduleWeb(OnCallSchedule): self.cached_ical_file_overrides = self._generate_ical_file_overrides() self.save(update_fields=["cached_ical_file_overrides", "prev_ical_file_overrides"]) - def preview_shift(self, custom_shift, user_tz, starting_date, days): + def preview_shift(self, custom_shift, user_tz, starting_date, days, updated_shift_pk=None): """Return unsaved rotation and final schedule preview events.""" if custom_shift.type == CustomOnCallShift.TYPE_OVERRIDE: qs = self.custom_shifts.filter(type=CustomOnCallShift.TYPE_OVERRIDE) @@ -643,7 +643,18 @@ class OnCallScheduleWeb(OnCallSchedule): except AttributeError: pass - ical_file = self._generate_ical_file_from_shifts(qs, extra_shifts=[custom_shift]) + extra_shifts = [custom_shift] + if updated_shift_pk is not None: + try: + update_shift = qs.get(public_primary_key=updated_shift_pk) + except CustomOnCallShift.DoesNotExist: + pass + else: + update_shift.until = custom_shift.rotation_start + qs = qs.exclude(public_primary_key=updated_shift_pk) + extra_shifts.append(update_shift) + + ical_file = self._generate_ical_file_from_shifts(qs, extra_shifts=extra_shifts) original_value = getattr(self, ical_attr) _invalidate_cache(self, ical_property)