fix: Add rolling users validation for oncall shift API (#5050)

# What this PR does
Adds validation for rolling users param in the shift api

## Which issue(s) this PR closes
Closes [5041](https://github.com/grafana/oncall/issues/5041)

<!--
*Note*: If you want the issue to be auto-closed once the PR is merged,
change "Related to" to "Closes" in the line above.
If you have more than one GitHub issue that this PR closes, be sure to
preface
each issue link with a [closing
keyword](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/using-keywords-in-issues-and-pull-requests#linking-a-pull-request-to-an-issue).
This ensures that the issue(s) are auto-closed once the PR has been
merged.
-->

## Checklist

- [x] Unit, integration, and e2e (if applicable) tests updated
- [x] Documentation added (or `pr:no public docs` PR label added if not
required)
- [x] Added the relevant release notes label (see labels prefixed w/
`release:`). These labels dictate how your PR will
    show up in the autogenerated release notes.
This commit is contained in:
Ravishankar 2024-09-21 02:36:33 +05:30 committed by GitHub
parent 2cb8f4a24f
commit 1f209cd2bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 155 additions and 3 deletions

View file

@ -29,7 +29,7 @@ class OnCallShiftSerializer(EagerLoadingMixin, serializers.ModelSerializer):
allow_null=True,
required=False,
child=UsersFilteredByOrganizationField(
queryset=User.objects, required=False, allow_null=True
queryset=User.objects, require_all_exist=True, required=False, allow_null=True
), # todo: filter by team?
)
updated_shift = serializers.CharField(read_only=True, allow_null=True, source="updated_shift.public_primary_key")

View file

@ -555,6 +555,48 @@ def test_update_future_on_call_shift_removing_users(
assert response.data["rolling_users"][0] == "User(s) are required"
@pytest.mark.django_db
def test_update_on_call_shift_invalid_rolling_users(
on_call_shift_internal_api_setup,
make_on_call_shift,
make_user_auth_headers,
):
token, user1, _, _, schedule = on_call_shift_internal_api_setup
client = APIClient()
start_date = (timezone.now() + timezone.timedelta(days=1)).replace(microsecond=0)
name = "Test Shift Rotation"
on_call_shift = make_on_call_shift(
schedule.organization,
shift_type=CustomOnCallShift.TYPE_ROLLING_USERS_EVENT,
schedule=schedule,
name=name,
start=start_date,
duration=timezone.timedelta(hours=1),
rotation_start=start_date,
rolling_users=[{user1.pk: user1.public_primary_key}],
)
data_to_update = {
"name": name,
"priority_level": 2,
"shift_start": start_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
"shift_end": (start_date + timezone.timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%SZ"),
"rotation_start": start_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
"until": None,
"frequency": None,
"interval": None,
"by_day": None,
"rolling_users": [["fuzz"]],
}
url = reverse("api-internal:oncall_shifts-detail", kwargs={"pk": on_call_shift.public_primary_key})
response = client.put(url, data=data_to_update, format="json", **make_user_auth_headers(user1, token))
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json() == {"rolling_users": {"0": ["User does not exist {'fuzz'}"]}}
@pytest.mark.django_db
def test_update_started_on_call_shift(
on_call_shift_internal_api_setup,
@ -1202,6 +1244,41 @@ def test_create_on_call_shift_invalid_data_rolling_users(
assert response.data["rolling_users"][0] == "Cannot set multiple user groups for non-recurrent shifts"
@pytest.mark.django_db
def test_create_on_call_shift_invalid_rolling_users(on_call_shift_internal_api_setup, make_user_auth_headers):
token, user1, user2, _, schedule = on_call_shift_internal_api_setup
client = APIClient()
url = reverse("api-internal:oncall_shifts-list")
start_date = timezone.now().replace(microsecond=0, tzinfo=None)
data = {
"name": "Test Shift",
"type": CustomOnCallShift.TYPE_ROLLING_USERS_EVENT,
"schedule": schedule.public_primary_key,
"priority_level": 1,
"shift_start": start_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
"shift_end": (start_date + timezone.timedelta(hours=2)).strftime("%Y-%m-%dT%H:%M:%SZ"),
"rotation_start": start_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
"until": None,
"frequency": 1,
"interval": 1,
"by_day": [
CustomOnCallShift.ICAL_WEEKDAY_MAP[CustomOnCallShift.MONDAY],
CustomOnCallShift.ICAL_WEEKDAY_MAP[CustomOnCallShift.FRIDAY],
],
"week_start": CustomOnCallShift.ICAL_WEEKDAY_MAP[CustomOnCallShift.MONDAY],
"rolling_users": [[user1.public_primary_key], [user2.public_primary_key, "fuzz"]],
}
with patch("apps.schedules.models.CustomOnCallShift.refresh_schedule") as mock_refresh_schedule:
response = client.post(url, data, format="json", **make_user_auth_headers(user1, token))
expected_payload = {"rolling_users": {"1": ["User does not exist {'fuzz'}"]}}
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json() == expected_payload
mock_refresh_schedule.assert_not_called()
@pytest.mark.django_db
def test_create_on_call_shift_override_invalid_data(on_call_shift_internal_api_setup, make_user_auth_headers):
token, user1, _, _, schedule = on_call_shift_internal_api_setup

View file

@ -85,7 +85,9 @@ class CustomOnCallShiftSerializer(EagerLoadingMixin, serializers.ModelSerializer
rolling_users = RollingUsersField(
allow_null=True,
required=False,
child=UsersFilteredByOrganizationField(queryset=User.objects, required=False, allow_null=True),
child=UsersFilteredByOrganizationField(
queryset=User.objects, require_all_exist=True, required=False, allow_null=True
),
)
rotation_start = serializers.DateTimeField(required=False)

View file

@ -463,6 +463,38 @@ def test_create_on_call_shift_invalid_time_zone(make_organization_and_user_with_
assert response.json() == {"time_zone": ["Invalid timezone"]}
@pytest.mark.django_db
def test_create_on_call_shift_invalid_rolling_users(make_organization_and_user_with_token):
_, user, token = make_organization_and_user_with_token()
client = APIClient()
url = reverse("api-public:on_call_shifts-list")
start = timezone.now()
until = start + timezone.timedelta(days=30)
data = {
"team_id": None,
"name": "test name",
"type": "rolling_users",
"level": 1,
"start": start.strftime("%Y-%m-%dT%H:%M:%S"),
"rotation_start": start.strftime("%Y-%m-%dT%H:%M:%S"),
"duration": 10800,
"week_start": "MO",
"frequency": "weekly",
"interval": 2,
"until": until.strftime("%Y-%m-%dT%H:%M:%S"),
"by_day": ["MO", "WE", "FR"],
"time_zone": None,
"rolling_users": [[user.public_primary_key], ["fuzz"]],
}
response = client.post(url, data=data, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json() == {"rolling_users": {"1": ["User does not exist {'fuzz'}"]}}
@pytest.mark.django_db
def test_update_on_call_shift(make_organization_and_user_with_token, make_on_call_shift, make_schedule):
organization, user, token = make_organization_and_user_with_token()
@ -633,6 +665,35 @@ def test_update_on_call_shift_invalid_field(make_organization_and_user_with_toke
assert response.status_code == status.HTTP_400_BAD_REQUEST
@pytest.mark.django_db
def test_update_on_call_shift_invalid_rolling_users(make_organization_and_user_with_token, make_on_call_shift):
organization, user, token = make_organization_and_user_with_token()
client = APIClient()
start_date = timezone.now().replace(microsecond=0)
data = {
"start": start_date,
"rotation_start": start_date,
"duration": timezone.timedelta(seconds=7200),
"frequency": CustomOnCallShift.FREQUENCY_WEEKLY,
"interval": 2,
"by_day": ["MO", "FR"],
"rolling_users": [[user.public_primary_key]],
}
data_to_update = {"rolling_users": [[user.public_primary_key], ["fuzz"]]}
on_call_shift = make_on_call_shift(
organization=organization, shift_type=CustomOnCallShift.TYPE_ROLLING_USERS_EVENT, **data
)
url = reverse("api-public:on_call_shifts-detail", kwargs={"pk": on_call_shift.public_primary_key})
response = client.put(url, data=data_to_update, format="json", HTTP_AUTHORIZATION=f"{token}")
assert response.status_code == status.HTTP_400_BAD_REQUEST
@pytest.mark.django_db
def test_delete_on_call_shift(make_organization_and_user_with_token, make_on_call_shift):
organization, _, token = make_organization_and_user_with_token()

View file

@ -90,6 +90,7 @@ class UsersFilteredByOrganizationField(serializers.Field):
def __init__(self, **kwargs):
self.queryset = kwargs.pop("queryset", None)
self.require_all_exist = kwargs.pop("require_all_exist", False)
super().__init__(**kwargs)
def to_representation(self, value):
@ -102,7 +103,18 @@ class UsersFilteredByOrganizationField(serializers.Field):
if not request or not queryset:
return None
return queryset.filter(organization=request.user.organization, public_primary_key__in=data).distinct()
users = queryset.filter(organization=request.user.organization, public_primary_key__in=data).distinct()
users_ppk = set(u.public_primary_key for u in users)
data_set = set(data)
if not self.require_all_exist:
return users
if len(data_set) != len(users_ppk):
missing_users = data_set - users_ppk
raise ValidationError(f"User does not exist {missing_users}")
return users
class IntegrationFilteredByOrganizationField(serializers.RelatedField):