From eada4a4355084058fd53d97f3b286cd82cad650b Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Wed, 21 Jun 2023 12:13:56 +0100 Subject: [PATCH] Fix duplicate orders for user notification policies (#2278) # What this PR does Fixes an issue when multiple user notification policies have duplicated order values, leading to the following unexpected behaviours: 1. Not possible to rearrange notification policies that have duplicated orders. 2. The notification system only executes the first policy from each order group. For example, if there are policies with orders `[0, 0, 0, 0]`, only the first policy will be executed, and all others will be skipped. So the user will see four policies in the UI, while only one of them will be actually executed. This PR fixes the issue by adding a unique index on `(user_id, important, order)` for `UserNotificationPolicy` model. However, it's not possible to add that unique index using the ordering library that we use due to it's implementation details. I added a new abstract Django model `OrderedModel` that's able to work with such unique indices + under concurrent load. Important info on this new `OrderedModel` abstract model: - Orders are unique on the DB level - Orders are allowed to be non-consecutive, for example order sequence `[100, 150, 400]` is valid - When deleting an instance, orders of other instances don't change. This is a notable difference from the library we use. I think it's better to only delete the instance without changing any other orders, because it reduces the number of dependencies between instances (e.g. Terraform drift will be much smaller this way if a policy is deleted via the web UI). ## Which issue(s) this PR fixes Related to https://github.com/grafana/oncall-private/issues/1680 ## Checklist - [x] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] `CHANGELOG.md` updated (or `pr:no changelog` PR label added if not required) --- CHANGELOG.md | 4 + docker-compose-developer.yml | 2 +- .../serializers/user_notification_policy.py | 30 +- .../tests/test_user_notification_policy.py | 25 +- .../api/views/user_notification_policy.py | 7 +- .../migrations/0004_auto_20230616_1510.py | 50 +++ engine/apps/base/models/ordered_model.py | 272 ++++++++++++ .../base/models/user_notification_policy.py | 37 +- engine/apps/base/tests/test_ordered_model.py | 404 ++++++++++++++++++ .../personal_notification_rules.py | 58 ++- engine/tox.ini | 2 +- 11 files changed, 816 insertions(+), 75 deletions(-) create mode 100644 engine/apps/base/migrations/0004_auto_20230616_1510.py create mode 100644 engine/apps/base/models/ordered_model.py create mode 100644 engine/apps/base/tests/test_ordered_model.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 61b10cff..c5d90b53 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Change mobile shift notifications title and subtitle by @imtoori ([#2288](https://github.com/grafana/oncall/pull/2288)) +## Fixed + +- Fix duplicate orders for user notification policies by @vadimkerr ([#2278](https://github.com/grafana/oncall/pull/2278)) + ## v1.2.45 (2023-06-19) ### Changed diff --git a/docker-compose-developer.yml b/docker-compose-developer.yml index 69ab2e3b..c44633e4 100644 --- a/docker-compose-developer.yml +++ b/docker-compose-developer.yml @@ -208,7 +208,7 @@ services: container_name: mysql labels: *oncall-labels image: mysql:8.0.32 - command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci + command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci --max_connections=1024 restart: always environment: MYSQL_ROOT_PASSWORD: empty diff --git a/engine/apps/api/serializers/user_notification_policy.py b/engine/apps/api/serializers/user_notification_policy.py index 79eb845f..ba3e7172 100644 --- a/engine/apps/api/serializers/user_notification_policy.py +++ b/engine/apps/api/serializers/user_notification_policy.py @@ -7,7 +7,7 @@ from apps.base.models import UserNotificationPolicy from apps.base.models.user_notification_policy import NotificationChannelAPIOptions from apps.user_management.models import User from common.api_helpers.custom_fields import OrganizationFilteredPrimaryKeyRelatedField -from common.api_helpers.exceptions import BadRequest, Forbidden +from common.api_helpers.exceptions import Forbidden from common.api_helpers.mixins import EagerLoadingMixin @@ -34,6 +34,7 @@ class UserNotificationPolicyBaseSerializer(EagerLoadingMixin, serializers.ModelS class Meta: model = UserNotificationPolicy fields = ["id", "step", "order", "notify_by", "wait_delay", "important", "user"] + read_only_fields = ["order"] def to_internal_value(self, data): if data.get("wait_delay", None): @@ -67,7 +68,6 @@ class UserNotificationPolicyBaseSerializer(EagerLoadingMixin, serializers.ModelS class UserNotificationPolicySerializer(UserNotificationPolicyBaseSerializer): - prev_step = serializers.CharField(required=False, write_only=True, allow_null=True) user = OrganizationFilteredPrimaryKeyRelatedField( queryset=User.objects, required=False, @@ -80,36 +80,16 @@ class UserNotificationPolicySerializer(UserNotificationPolicyBaseSerializer): default=NotificationChannelAPIOptions.DEFAULT_NOTIFICATION_CHANNEL, ) - class Meta(UserNotificationPolicyBaseSerializer.Meta): - fields = [*UserNotificationPolicyBaseSerializer.Meta.fields, "prev_step"] - read_only_fields = ("order",) - def create(self, validated_data): - prev_step = validated_data.pop("prev_step", None) - - user = validated_data.get("user") + user = validated_data.get("user") or self.context["request"].user organization = self.context["request"].auth.organization - if not user: - user = self.context["request"].user - self_or_admin = user.self_or_admin(user_to_check=self.context["request"].user, organization=organization) if not self_or_admin: raise Forbidden() - if prev_step is not None: - try: - prev_step = UserNotificationPolicy.objects.get(public_primary_key=prev_step) - except UserNotificationPolicy.DoesNotExist: - raise BadRequest(detail="Prev step does not exist") - if prev_step.user != user or prev_step.important != validated_data.get("important", False): - raise BadRequest(detail="UserNotificationPolicy can be created only with the same user and importance") - instance = UserNotificationPolicy.objects.create(**validated_data) - instance.to(prev_step.order + 1) - return instance - else: - instance = UserNotificationPolicy.objects.create(**validated_data) - return instance + instance = UserNotificationPolicy.objects.create(**validated_data) + return instance class UserNotificationPolicyUpdateSerializer(UserNotificationPolicyBaseSerializer): diff --git a/engine/apps/api/tests/test_user_notification_policy.py b/engine/apps/api/tests/test_user_notification_policy.py index 3cda1f0d..996775cc 100644 --- a/engine/apps/api/tests/test_user_notification_policy.py +++ b/engine/apps/api/tests/test_user_notification_policy.py @@ -110,7 +110,7 @@ def test_user_cant_create_notification_policy_for_user( @pytest.mark.django_db -def test_create_notification_policy_from_step( +def test_create_notification_policy_order_is_ignored( user_notification_policy_internal_api_setup, make_user_auth_headers, ): @@ -121,7 +121,7 @@ def test_create_notification_policy_from_step( url = reverse("api-internal:notification_policy-list") data = { - "prev_step": wait_notification_step.public_primary_key, + "position": 2023, "step": UserNotificationPolicy.Step.NOTIFY, "notify_by": UserNotificationPolicy.NotificationChannel.SLACK, "wait_delay": None, @@ -130,26 +130,19 @@ def test_create_notification_policy_from_step( } response = client.post(url, data, format="json", **make_user_auth_headers(admin, token)) assert response.status_code == status.HTTP_201_CREATED - assert response.data["order"] == 1 + assert response.data["order"] == 2 @pytest.mark.django_db -def test_create_invalid_notification_policy(user_notification_policy_internal_api_setup, make_user_auth_headers): +def test_move_to_position_position_error(user_notification_policy_internal_api_setup, make_user_auth_headers): token, steps, users = user_notification_policy_internal_api_setup - wait_notification_step, _, _, _ = steps admin, _ = users + step = steps[0] client = APIClient() - url = reverse("api-internal:notification_policy-list") + url = reverse("api-internal:notification_policy-move-to-position", kwargs={"pk": step.public_primary_key}) - data = { - "prev_step": wait_notification_step.public_primary_key, - "step": UserNotificationPolicy.Step.NOTIFY, - "notify_by": UserNotificationPolicy.NotificationChannel.SLACK, - "wait_delay": None, - "important": True, - "user": admin.public_primary_key, - } - response = client.post(url, data, format="json", **make_user_auth_headers(admin, token)) + # position value only can be 0 or 1 for this test setup, because there are only 2 steps + response = client.put(f"{url}?position=2", content_type="application/json", **make_user_auth_headers(admin, token)) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -221,7 +214,7 @@ def test_admin_can_move_user_step(user_notification_policy_internal_api_setup, m "api-internal:notification_policy-move-to-position", kwargs={"pk": second_user_step.public_primary_key} ) - response = client.put(f"{url}?position=1", content_type="application/json", **make_user_auth_headers(admin, token)) + response = client.put(f"{url}?position=0", content_type="application/json", **make_user_auth_headers(admin, token)) assert response.status_code == status.HTTP_200_OK diff --git a/engine/apps/api/views/user_notification_policy.py b/engine/apps/api/views/user_notification_policy.py index e05fc121..a7efb87f 100644 --- a/engine/apps/api/views/user_notification_policy.py +++ b/engine/apps/api/views/user_notification_policy.py @@ -142,7 +142,12 @@ class UserNotificationPolicyView(UpdateSerializerMixin, ModelViewSet): def move_to_position(self, request, pk): instance = self.get_object() position = get_move_to_position_param(request) - instance.to(position) + + try: + instance.to_index(position) + except IndexError: + raise BadRequest(detail="Invalid position") + return Response(status=status.HTTP_200_OK) @action(detail=False, methods=["get"]) diff --git a/engine/apps/base/migrations/0004_auto_20230616_1510.py b/engine/apps/base/migrations/0004_auto_20230616_1510.py new file mode 100644 index 00000000..09878106 --- /dev/null +++ b/engine/apps/base/migrations/0004_auto_20230616_1510.py @@ -0,0 +1,50 @@ +# Generated by Django 3.2.19 on 2023-06-16 15:10 + +from django.db import migrations, models +from django.db.models import Count + +from common.database import get_random_readonly_database_key_if_present_otherwise_default +import django_migration_linter as linter + + +def fix_duplicate_order_user_notification_policy(apps, schema_editor): + UserNotificationPolicy = apps.get_model('base', 'UserNotificationPolicy') + + # it should be safe to use a readonly database because duplicates are pretty infrequent + db = get_random_readonly_database_key_if_present_otherwise_default() + + # find all (user_id, important, order) tuples that have more than one entry (meaning duplicates) + items_with_duplicate_orders = UserNotificationPolicy.objects.using(db).values( + "user_id", "important", "order" + ).annotate(count=Count("order")).order_by().filter(count__gt=1) # use order_by() to reset any existing ordering + + # make sure we don't fix the same (user_id, important) pair more than once + values_to_fix = set((item["user_id"], item["important"]) for item in items_with_duplicate_orders) + + for user_id, important in values_to_fix: + policies = UserNotificationPolicy.objects.filter(user_id=user_id, important=important).order_by("order", "id") + # assign correct sequential order for each policy starting from 0 + for idx, policy in enumerate(policies): + policy.order = idx + UserNotificationPolicy.objects.bulk_update(policies, fields=["order"]) + + +class Migration(migrations.Migration): + + dependencies = [ + ('base', '0003_delete_organizationlogrecord'), + ] + + operations = [ + linter.IgnoreMigration(), # adding a unique constraint after fixing duplicates should be fine + migrations.AlterField( + model_name='usernotificationpolicy', + name='order', + field=models.PositiveIntegerField(db_index=True, editable=False, null=True), + ), + migrations.RunPython(fix_duplicate_order_user_notification_policy, migrations.RunPython.noop), + migrations.AddConstraint( + model_name='usernotificationpolicy', + constraint=models.UniqueConstraint(fields=('user_id', 'important', 'order'), name='unique_user_notification_policy_order'), + ), + ] diff --git a/engine/apps/base/models/ordered_model.py b/engine/apps/base/models/ordered_model.py new file mode 100644 index 00000000..b286520d --- /dev/null +++ b/engine/apps/base/models/ordered_model.py @@ -0,0 +1,272 @@ +import random +import time +import typing +from functools import wraps + +from django.db import IntegrityError, OperationalError, models, transaction + + +def _retry(exc: typing.Type[Exception] | tuple[typing.Type[Exception], ...], max_attempts: int = 5) -> typing.Callable: + """ + A utility decorator for retrying a function on a given exception(s) up to max_attempts times. + """ + + def _retry_with_params(f): + @wraps(f) + def wrapper(*args, **kwargs): + attempts = 0 + while attempts < max_attempts: + try: + return f(*args, **kwargs) + except exc: + if attempts == max_attempts - 1: + raise + attempts += 1 + time.sleep(random.random()) + + return wrapper + + return _retry_with_params + + +Self = typing.TypeVar("Self", bound="OrderedModel") + + +class OrderedModel(models.Model): + """ + This class is intended to be used as a mixin for models that need to be ordered. + It's similar to django-ordered-model: https://github.com/django-ordered-model/django-ordered-model. + The key difference of this implementation is that it allows orders to be unique at the database level and + is designed to work correctly under concurrent load. + + Notable differences compared to django-ordered-model: + - order can be unique at the database level; + - order can temporarily be set to NULL while performing moving operations; + - instance.delete() only deletes the instance, and doesn't shift other instances' orders; + - some methods are not implemented because they're not used in the codebase; + + Example usage: + class Step(OrderedModel): + user = models.ForeignKey(User, on_delete=models.CASCADE) + order_with_respect_to = ["user_id"] # steps are ordered per user + + class Meta: + ordering = ["order"] # to make queryset ordering correct and consistent + unique_together = ["user_id", "order"] # orders are unique per user at the database level + + It's possible for orders to be non-sequential, e.g. order sequence [100, 150, 400] is totally possible and valid. + """ + + order = models.PositiveIntegerField(editable=False, db_index=True, null=True) + order_with_respect_to: list[str] = [] + + class Meta: + abstract = True + ordering = ["order"] + constraints = [ + models.UniqueConstraint(fields=["order"], name="unique_order"), + ] + + def save(self, *args, **kwargs) -> None: + if self.order is None: + self._save_no_order_provided() + else: + super().save() + + @_retry(OperationalError) # retry on deadlock + def delete(self, *args, **kwargs) -> None: + with transaction.atomic(): + # lock ordering queryset to prevent deleting instances that are used by other transactions + self._lock_ordering_queryset() + super().delete(*args, **kwargs) + + @_retry((IntegrityError, OperationalError)) # retry on duplicate order or deadlock + def _save_no_order_provided(self) -> None: + """ + Save self to DB without an order provided (e.g on creation). + Order is set to the next available order, or 0 if there are no other instances. + Example: + a = OrderedModel.objects.create() + b = OrderedModel.objects.create() + c = OrderedModel.objects.create(order=10) + d = OrderedModel.objects.create() + + assert (a.order, b.order, c.order, d.order) == (0, 1, 10, 11) + """ + with transaction.atomic(): + instances = self._lock_ordering_queryset() # lock ordering queryset to prevent reading inconsistent data + max_order = max(instance.order for instance in instances) if instances else -1 + self.order = max_order + 1 + super().save() + + @_retry(OperationalError) # retry on deadlock + def to(self, order: int) -> None: + """ + Move self to a given order, adjusting other instances' orders if necessary. + Example: + a = OrderedModel(order=1) + b = OrderedModel(order=2) + c = OrderedModel(order=3) + + a.to(3) # move the first element to the last order + assert (a.order, b.order, c.order) == (3, 1, 2) # [a, b, c] -> [b, c, a] + """ + self._validate_positive_integer(order) + with transaction.atomic(): + instances = self._lock_ordering_queryset() + self._move_instances_to_order(instances, order) + + @_retry(OperationalError) # retry on deadlock + def to_index(self, index: int) -> None: + """ + Move self to a given index, adjusting other instances' orders if necessary. + Similar with to(), but accepts an index instead of an order. + This might be handy as orders might be non-sequential, but most clients assume that they are sequential. + + Example: + a = OrderedModel(order=1) + b = OrderedModel(order=5) + c = OrderedModel(order=10) + + a.to_index(2) # move the first element to the second index (where c is) + assert (a.order, b.order, c.order) == (10, 4, 9) # [a, b, c] -> [b, c, a] + """ + self._validate_positive_integer(index) + with transaction.atomic(): + instances = self._lock_ordering_queryset() + order = instances[index].order # get order of the instance at the given index + self._move_instances_to_order(instances, order) + + def _move_instances_to_order(self, instances: list[Self], order: int) -> None: + """ + Helper method for moving self to a given order, adjusting other instances' orders if necessary. + Must be called within a transaction that locks the ordering queryset. + """ + + # Get the up-to-date instance from the database, because it might have been updated by another transaction. + try: + _self = next(instance for instance in instances if instance.pk == self.pk) + self.order = _self.order + assert self.order is not None + except StopIteration: + raise self.DoesNotExist() + + # If the order is already correct, do nothing. + if self.order == order: + return + + # Figure out instances that need to be moved and their new orders. + instances_to_move = [] + if self.order < order: + for instance in instances: + if instance.order is not None and self.order < instance.order <= order: + instance.order -= 1 + instances_to_move.append(instance) + else: + for instance in instances: + if instance.order is not None and order <= instance.order < self.order: + instance.order += 1 + instances_to_move.append(instance) + + # If there's nothing to move, just update self.order and return. + if not instances_to_move: + self.order = order + self.save(update_fields=["order"]) + return + + # Temporarily set order values to NULL to avoid unique constraint violations. + pks = [self.pk] + [instance.pk for instance in instances_to_move] + self._manager.filter(pk__in=pks).update(order=None) + + # Update orders to appropriate unique values. + self.order = order + self._manager.filter(pk__in=pks).bulk_update([self] + instances_to_move, fields=["order"]) + + @_retry(OperationalError) # retry on deadlock + def swap(self, order: int) -> None: + """ + Swap self with an instance at a given order. + Example: + a = OrderedModel(order=1) + b = OrderedModel(order=2) + c = OrderedModel(order=3) + d = OrderedModel(order=4) + + a.swap(4) # swap the first element with the last element + assert (a.order, b.order, c.order, d.order) == (4, 2, 3, 1) # [a, b, c, d] -> [d, b, c, a] + """ + self._validate_positive_integer(order) + with transaction.atomic(): + instances = self._lock_ordering_queryset() + + # Get the up-to-date instance from the database, because it might have been updated by another transaction. + try: + _self = next(instance for instance in instances if instance.pk == self.pk) + self.order = _self.order + assert self.order is not None + except StopIteration: + raise self.DoesNotExist() + + # If the order is already correct, do nothing. + if self.order == order: + return + + # Get the instance to swap with. + try: + other = next(instance for instance in instances if instance.order == order) + except StopIteration: + other = None + + # If there's no instance to swap with, just update self.order and return. + if not other: + self.order = order + self.save(update_fields=["order"]) + return + + # Temporarily set order values to NULL to avoid unique constraint violations. + self._manager.filter(pk__in=[self.pk, other.pk]).update(order=None) + + # Swap order values. + self.order, other.order = other.order, self.order + self._manager.filter(pk__in=[self.pk, other.pk]).bulk_update([self, other], fields=["order"]) + + def next(self) -> Self | None: + """ + Return the next instance in the ordering queryset, or None if there's no next instance. + Example: + a = OrderedModel(order=1) + b = OrderedModel(order=2) + + assert a.next() == b + assert b.next() is None + """ + return self._get_ordering_queryset().filter(order__gt=self.order).first() + + def max_order(self) -> int | None: + """ + Return the maximum order value in the ordering queryset or None if there are no instances. + """ + return self._get_ordering_queryset().aggregate(models.Max("order"))["order__max"] + + @staticmethod + def _validate_positive_integer(value: int | None) -> None: + if value is None or not isinstance(value, int) or value < 0: + raise ValueError("Value must be a positive integer.") + + def _get_ordering_queryset(self) -> models.QuerySet[Self]: + return self._manager.filter(**self._ordering_params) + + def _lock_ordering_queryset(self) -> list[Self]: + """ + Locks the ordering queryset with SELECT FOR UPDATE and returns the queryset as a list. + This allows to prevent concurrent updates from different transactions. + """ + return list(self._get_ordering_queryset().select_for_update().only("pk", "order")) + + @property + def _manager(self): + return self._meta.default_manager + + @property + def _ordering_params(self) -> dict[str, typing.Any]: + return {field: getattr(self, field) for field in self.order_with_respect_to} diff --git a/engine/apps/base/models/user_notification_policy.py b/engine/apps/base/models/user_notification_policy.py index b4393916..7e93af63 100644 --- a/engine/apps/base/models/user_notification_policy.py +++ b/engine/apps/base/models/user_notification_policy.py @@ -5,11 +5,11 @@ from typing import Tuple from django.conf import settings from django.core.exceptions import ValidationError from django.core.validators import MinLengthValidator -from django.db import models -from django.db.models import Q, QuerySet -from ordered_model.models import OrderedModel +from django.db import IntegrityError, models +from django.db.models import Q from apps.base.messaging import get_messaging_backends +from apps.base.models.ordered_model import OrderedModel from apps.user_management.models import User from common.exceptions import UserNotificationPolicyCouldNotBeDeleted from common.public_primary_keys import generate_public_primary_key, increase_public_primary_key_length @@ -67,9 +67,11 @@ def validate_channel_choice(value): class UserNotificationPolicyQuerySet(models.QuerySet): - def create_default_policies_for_user(self, user: User) -> "QuerySet[UserNotificationPolicy]": - model = self.model + def create_default_policies_for_user(self, user: User) -> None: + if user.notification_policies.filter(important=False).exists(): + return + model = self.model policies_to_create = ( model( user=user, @@ -81,12 +83,16 @@ class UserNotificationPolicyQuerySet(models.QuerySet): model(user=user, step=model.Step.NOTIFY, notify_by=model.NotificationChannel.PHONE_CALL, order=2), ) - super().bulk_create(policies_to_create) - return user.notification_policies.filter(important=False) + try: + super().bulk_create(policies_to_create) + except IntegrityError: + pass + + def create_important_policies_for_user(self, user: User) -> None: + if user.notification_policies.filter(important=True).exists(): + return - def create_important_policies_for_user(self, user: User) -> "QuerySet[UserNotificationPolicy]": model = self.model - policies_to_create = ( model( user=user, @@ -97,13 +103,15 @@ class UserNotificationPolicyQuerySet(models.QuerySet): ), ) - super().bulk_create(policies_to_create) - return user.notification_policies.filter(important=True) + try: + super().bulk_create(policies_to_create) + except IntegrityError: + pass class UserNotificationPolicy(OrderedModel): objects = UserNotificationPolicyQuerySet.as_manager() - order_with_respect_to = ("user", "important") + order_with_respect_to = ("user_id", "important") public_primary_key = models.CharField( max_length=20, @@ -145,6 +153,11 @@ class UserNotificationPolicy(OrderedModel): class Meta: ordering = ("order",) + constraints = [ + models.UniqueConstraint( + fields=["user_id", "important", "order"], name="unique_user_notification_policy_order" + ) + ] def __str__(self): return f"{self.pk}: {self.short_verbal}" diff --git a/engine/apps/base/tests/test_ordered_model.py b/engine/apps/base/tests/test_ordered_model.py new file mode 100644 index 00000000..b312d238 --- /dev/null +++ b/engine/apps/base/tests/test_ordered_model.py @@ -0,0 +1,404 @@ +import random +import threading + +import pytest +from django.db import models + +from apps.base.models.ordered_model import OrderedModel + + +class TestOrderedModel(OrderedModel): + test_field = models.CharField(max_length=255) + extra_field = models.IntegerField(null=True, default=None) + order_with_respect_to = ["test_field"] + + class Meta: + app_label = "base" + ordering = ["order"] + constraints = [ + models.UniqueConstraint(fields=["test_field", "order"], name="unique_test_field_order"), + ] + + +def _get_ids(): + return list(TestOrderedModel.objects.filter(test_field="test").values_list("id", flat=True)) + + +def _get_orders(): + return list(TestOrderedModel.objects.filter(test_field="test").values_list("order", flat=True)) + + +def _orders_are_sequential(): + orders = _get_orders() + return orders == list(range(len(orders))) + + +@pytest.mark.django_db +def test_ordered_model_create(): + first = TestOrderedModel.objects.create(test_field="test") + second = TestOrderedModel.objects.create(test_field="test") + + assert first.order == 0 + assert second.order == 1 + + +@pytest.mark.django_db +def test_ordered_model_delete(): + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(3)] + + instances[1].delete() + assert instances[1].pk is None + assert _get_ids() == [instances[0].id, instances[2].id] + assert _get_orders() == [0, 2] + + +@pytest.mark.django_db +def test_ordered_model_to(): + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(5)] + + def _ids(indices): + return [instances[i].id for i in indices] + + # move to the end + instances[0].to(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([1, 2, 3, 4, 0]) + assert _orders_are_sequential() + + # move to the beginning + instances[0].to(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + # move to the middle + instances[0].to(2) + assert instances[0].order == 2 + assert _get_ids() == _ids([1, 2, 0, 3, 4]) + assert _orders_are_sequential() + + # move from the middle to the end + instances[0].to(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([1, 2, 3, 4, 0]) + assert _orders_are_sequential() + + # move from the end to the second position + instances[0].to(1) + assert instances[0].order == 1 + assert _get_ids() == _ids([1, 0, 2, 3, 4]) + assert _orders_are_sequential() + + # move from the second position to the beginning + instances[0].to(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + # don't move if the order is the same + for instance in instances: + instance.to(instance.order) + assert instance.order == instance.order + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + +@pytest.mark.django_db +def test_ordered_model_to_index(): + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(5)] + + def _ids(indices): + return [instances[i].id for i in indices] + + # move to the end + instances[0].to_index(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([1, 2, 3, 4, 0]) + assert _orders_are_sequential() + + # move to the beginning + instances[0].to_index(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + # move to the middle + instances[0].to_index(2) + assert instances[0].order == 2 + assert _get_ids() == _ids([1, 2, 0, 3, 4]) + assert _orders_are_sequential() + + # move from the middle to the end + instances[0].to_index(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([1, 2, 3, 4, 0]) + assert _orders_are_sequential() + + # move from the end to the second position + instances[0].to_index(1) + assert instances[0].order == 1 + assert _get_ids() == _ids([1, 0, 2, 3, 4]) + assert _orders_are_sequential() + + # move from the second position to the beginning + instances[0].to_index(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + # don't move if the order is the same + for instance in instances: + instance.to_index(instance.order) + assert instance.order == instance.order + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + +@pytest.mark.django_db +def test_ordered_model_swap(): + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(5)] + + def _ids(indices): + return [instances[i].id for i in indices] + + # swap with last + instances[0].swap(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([4, 1, 2, 3, 0]) + assert _orders_are_sequential() + + # swap with first + instances[0].swap(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + # swap with middle + instances[0].swap(2) + assert instances[0].order == 2 + assert _get_ids() == _ids([2, 1, 0, 3, 4]) + assert _orders_are_sequential() + + # swap from the middle to the end + instances[0].swap(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([2, 1, 4, 3, 0]) + assert _orders_are_sequential() + + # swap from the end to the second position + instances[0].swap(1) + assert instances[0].order == 1 + assert _get_ids() == _ids([2, 0, 4, 3, 1]) + assert _orders_are_sequential() + + # swap from the second position to the beginning + instances[0].swap(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 2, 4, 3, 1]) + assert _orders_are_sequential() + + # swap with itself + for instance in instances: + instance.refresh_from_db(fields=["order"]) + instance.swap(instance.order) + assert instance.order == instance.order + assert _get_ids() == _ids([0, 2, 4, 3, 1]) + assert _orders_are_sequential() + + +@pytest.mark.django_db +def test_order_with_respect_to_isolation(): + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(5)] + other_instances = [TestOrderedModel.objects.create(test_field="test1") for _ in range(5)] + + assert [i.order for i in instances] == [0, 1, 2, 3, 4] + assert [i.order for i in other_instances] == [0, 1, 2, 3, 4] + + assert instances[-1].next() is None + assert instances[-1].max_order() == 4 + + instances[0].to(8) + instances[1].swap(7) + + for idx, instance in enumerate(other_instances): + instance.refresh_from_db() + assert instance.order == idx + + with pytest.raises(IndexError): + instances[0].to_index(6) + + +# Tests below are for checking that concurrent operations are performed correctly. +# They are skipped by default because they might take a lot of time to run. +# It could be useful to run them manually when making changes to the code, making sure +# that the changes don't break concurrent operations. To run the tests, set SKIP_CONCURRENT to False. +SKIP_CONCURRENT = True + + +@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.django_db(transaction=True) +def test_ordered_model_create_concurrent(): + LOOPS = 30 + THREADS = 10 + exceptions = [] + + def create(): + for loop in range(LOOPS): + try: + TestOrderedModel.objects.create(test_field="test") + except Exception as e: + exceptions.append(e) + + threads = [threading.Thread(target=create) for _ in range(THREADS)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert not exceptions + assert TestOrderedModel.objects.count() == LOOPS * THREADS + assert _orders_are_sequential() + + +@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.django_db(transaction=True) +def test_ordered_model_to_concurrent(): + THREADS = 300 + exceptions = [] + + TestOrderedModel.objects.all().delete() # clear table + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(THREADS)] + + random.seed(42) + positions = [random.randint(0, THREADS - 1) for _ in range(THREADS)] + + def to(idx): + try: + instance = instances[idx] + instance.to(positions[idx]) # swap with next + except Exception as e: + exceptions.append(e) + + threads = [threading.Thread(target=to, args=(idx,)) for idx in range(THREADS - 1)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # can only check that orders are still sequential and that there are no exceptions + # can't check the exact order because it changes depending on the order of execution + assert not exceptions + assert _orders_are_sequential() + + +@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.django_db(transaction=True) +def test_ordered_model_swap_concurrent(): + THREADS = 300 + exceptions = [] + + TestOrderedModel.objects.all().delete() # clear table + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(THREADS)] + + # generate random unique orders + random.seed(42) + unique_orders = list(range(THREADS)) + random.shuffle(unique_orders) + + def swap(idx): + try: + instance = instances[idx] + instance.swap(unique_orders[idx]) + except Exception as e: + exceptions.append(e) + + threads = [threading.Thread(target=swap, args=(idx,)) for idx in range(THREADS)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert not exceptions + assert _orders_are_sequential() + + # in case of unique orders, the final order is deterministic + assert list(TestOrderedModel.objects.order_by("id").values_list("order", flat=True)) == unique_orders + + +@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.django_db(transaction=True) +def test_ordered_model_swap_non_unique_orders_concurrent(): + THREADS = 300 + exceptions = [] + + TestOrderedModel.objects.all().delete() # clear table + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(THREADS)] + + # generate random non-unique orders + random.seed(42) + positions = [random.randint(0, THREADS - 1) for _ in range(THREADS)] + + def swap(idx): + try: + instance = instances[idx] + instance.swap(positions[idx]) + except Exception as e: + exceptions.append(e) + + threads = [threading.Thread(target=swap, args=(idx,)) for idx in range(THREADS)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # can only check that orders are still sequential and that there are no exceptions + # can't check the exact order because it changes depending on the order of execution + assert not exceptions + assert _orders_are_sequential() + + +@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.django_db(transaction=True) +def test_ordered_model_create_swap_and_delete_concurrent(): + """Check that create+swap, swap and delete operations are performed correctly when run concurrently.""" + + THREADS = 100 + exceptions = [] + + instances = [TestOrderedModel.objects.create(test_field="test", extra_field=idx) for idx in range(THREADS * 3)] + + def create_swap(idx): + try: + instance = TestOrderedModel.objects.create(test_field="test", extra_field=idx + 1000) + instance.swap(idx) + except Exception as e: + exceptions.append(("create_swap", e)) + + def swap(idx): + try: + instances[idx].swap(idx + 1) + except Exception as e: + exceptions.append(("swap", e)) + + def delete(idx): + try: + instances[idx].delete() + except Exception as e: + exceptions.append(("delete", e)) + + threads = [threading.Thread(target=create_swap, args=(idx,)) for idx in list(range(THREADS))] + threads += [threading.Thread(target=delete, args=(idx,)) for idx in range(THREADS)] + threads += [threading.Thread(target=swap, args=(idx,)) for idx in range(THREADS, THREADS * 2 - 1)] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + expected_extra_field_values = list(range(1000, 1000 + THREADS)) + expected_extra_field_values += [THREADS * 2 - 1] + list(range(THREADS, THREADS * 2 - 1)) + expected_extra_field_values += [instance.extra_field for instance in instances[THREADS * 2 : THREADS * 3]] + + assert not exceptions + assert _orders_are_sequential() + assert list(TestOrderedModel.objects.values_list("extra_field", flat=True)) == expected_extra_field_values diff --git a/engine/apps/public_api/serializers/personal_notification_rules.py b/engine/apps/public_api/serializers/personal_notification_rules.py index 8d915da7..3d0e7df3 100644 --- a/engine/apps/public_api/serializers/personal_notification_rules.py +++ b/engine/apps/public_api/serializers/personal_notification_rules.py @@ -43,14 +43,16 @@ class PersonalNotificationRuleSerializer(EagerLoadingMixin, serializers.ModelSer # that is why step key is used instead of type below if "wait_delay" in validated_data and validated_data["step"] != UserNotificationPolicy.Step.WAIT: raise exceptions.ValidationError({"duration": "Duration can't be set"}) - user = validated_data.pop("user") + + # Remove "manual_order" and "order" fields from validated_data, so they are not passed to create method. + # Policies are always created at the end of the list, and then moved to the desired position by _adjust_order. manual_order = validated_data.pop("manual_order") - if not manual_order: - order = validated_data.pop("order", None) - instance = UserNotificationPolicy.objects.create(**validated_data, user=user) - self._change_position(order, instance) - else: - instance = UserNotificationPolicy.objects.create(**validated_data, user=user) + order = validated_data.pop("order", None) + + instance = UserNotificationPolicy.objects.create(**validated_data) + + if order is not None: + self._adjust_order(instance, manual_order, order, created=True) return instance @@ -117,14 +119,32 @@ class PersonalNotificationRuleSerializer(EagerLoadingMixin, serializers.ModelSer raise exceptions.ValidationError({"type": "Invalid type"}) - def _change_position(self, order, instance): - if order is not None: - if order >= 0: - instance.to(order) - elif order == -1: - instance.bottom() - else: - raise BadRequest(detail="Invalid value for position field") + @staticmethod + def _adjust_order(instance, manual_order, order, created): + # Passing order=-1 means that the policy should be moved to the end of the list. + if order == -1: + if created: + # The policy was just created, so it is already at the end of the list. + return + + order = instance.max_order() + # max_order() can't be None here because at least one instance exists – the one we are moving. + assert order is not None + + # Negative order is not allowed. + if order < 0: + raise BadRequest(detail="Invalid value for position field") + + # manual_order=True is intended for use by Terraform provider only, and is not a documented feature. + # Orders are swapped instead of moved when using Terraform, because Terraform may issue concurrent requests + # to create / update / delete multiple policies. "Move to" operation is not deterministic in this case, and + # final order of policies may be different depending on the order in which requests are processed. On the other + # hand, the result of concurrent "swap" operations is deterministic and does not depend on the order in + # which requests are processed. + if manual_order: + instance.swap(order) + else: + instance.to(order) class PersonalNotificationRuleUpdateSerializer(PersonalNotificationRuleSerializer): @@ -145,10 +165,10 @@ class PersonalNotificationRuleUpdateSerializer(PersonalNotificationRuleSerialize if "wait_delay" in validated_data and instance.step != UserNotificationPolicy.Step.WAIT: raise exceptions.ValidationError({"duration": "Duration can't be set"}) + # Remove "manual_order" and "order" fields from validated_data, so they are not passed to update method. manual_order = validated_data.pop("manual_order") - - if not manual_order: - order = validated_data.pop("order", None) - self._change_position(order, instance) + order = validated_data.pop("order", None) + if order is not None: + self._adjust_order(instance, manual_order, order, created=False) return super().update(instance, validated_data) diff --git a/engine/tox.ini b/engine/tox.ini index 7cabc843..330b6eb9 100644 --- a/engine/tox.ini +++ b/engine/tox.ini @@ -9,6 +9,6 @@ banned-modules = [pytest] # https://pytest-django.readthedocs.io/en/latest/configuring_django.html#order-of-choosing-settings # https://pytest-django.readthedocs.io/en/latest/database.html -addopts = --color=yes --showlocals +addopts = --no-migrations --color=yes --showlocals # https://pytest-django.readthedocs.io/en/latest/faq.html#my-tests-are-not-being-found-why python_files = tests.py test_*.py *_tests.py