diff --git a/CHANGELOG.md b/CHANGELOG.md index 61b10cff..c5d90b53 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Change mobile shift notifications title and subtitle by @imtoori ([#2288](https://github.com/grafana/oncall/pull/2288)) +## Fixed + +- Fix duplicate orders for user notification policies by @vadimkerr ([#2278](https://github.com/grafana/oncall/pull/2278)) + ## v1.2.45 (2023-06-19) ### Changed diff --git a/docker-compose-developer.yml b/docker-compose-developer.yml index 69ab2e3b..c44633e4 100644 --- a/docker-compose-developer.yml +++ b/docker-compose-developer.yml @@ -208,7 +208,7 @@ services: container_name: mysql labels: *oncall-labels image: mysql:8.0.32 - command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci + command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci --max_connections=1024 restart: always environment: MYSQL_ROOT_PASSWORD: empty diff --git a/engine/apps/api/serializers/user_notification_policy.py b/engine/apps/api/serializers/user_notification_policy.py index 79eb845f..ba3e7172 100644 --- a/engine/apps/api/serializers/user_notification_policy.py +++ b/engine/apps/api/serializers/user_notification_policy.py @@ -7,7 +7,7 @@ from apps.base.models import UserNotificationPolicy from apps.base.models.user_notification_policy import NotificationChannelAPIOptions from apps.user_management.models import User from common.api_helpers.custom_fields import OrganizationFilteredPrimaryKeyRelatedField -from common.api_helpers.exceptions import BadRequest, Forbidden +from common.api_helpers.exceptions import Forbidden from common.api_helpers.mixins import EagerLoadingMixin @@ -34,6 +34,7 @@ class UserNotificationPolicyBaseSerializer(EagerLoadingMixin, serializers.ModelS class Meta: model = UserNotificationPolicy fields = ["id", "step", "order", "notify_by", "wait_delay", "important", "user"] + read_only_fields = ["order"] def to_internal_value(self, data): if data.get("wait_delay", None): @@ -67,7 +68,6 @@ class UserNotificationPolicyBaseSerializer(EagerLoadingMixin, serializers.ModelS class UserNotificationPolicySerializer(UserNotificationPolicyBaseSerializer): - prev_step = serializers.CharField(required=False, write_only=True, allow_null=True) user = OrganizationFilteredPrimaryKeyRelatedField( queryset=User.objects, required=False, @@ -80,36 +80,16 @@ class UserNotificationPolicySerializer(UserNotificationPolicyBaseSerializer): default=NotificationChannelAPIOptions.DEFAULT_NOTIFICATION_CHANNEL, ) - class Meta(UserNotificationPolicyBaseSerializer.Meta): - fields = [*UserNotificationPolicyBaseSerializer.Meta.fields, "prev_step"] - read_only_fields = ("order",) - def create(self, validated_data): - prev_step = validated_data.pop("prev_step", None) - - user = validated_data.get("user") + user = validated_data.get("user") or self.context["request"].user organization = self.context["request"].auth.organization - if not user: - user = self.context["request"].user - self_or_admin = user.self_or_admin(user_to_check=self.context["request"].user, organization=organization) if not self_or_admin: raise Forbidden() - if prev_step is not None: - try: - prev_step = UserNotificationPolicy.objects.get(public_primary_key=prev_step) - except UserNotificationPolicy.DoesNotExist: - raise BadRequest(detail="Prev step does not exist") - if prev_step.user != user or prev_step.important != validated_data.get("important", False): - raise BadRequest(detail="UserNotificationPolicy can be created only with the same user and importance") - instance = UserNotificationPolicy.objects.create(**validated_data) - instance.to(prev_step.order + 1) - return instance - else: - instance = UserNotificationPolicy.objects.create(**validated_data) - return instance + instance = UserNotificationPolicy.objects.create(**validated_data) + return instance class UserNotificationPolicyUpdateSerializer(UserNotificationPolicyBaseSerializer): diff --git a/engine/apps/api/tests/test_user_notification_policy.py b/engine/apps/api/tests/test_user_notification_policy.py index 3cda1f0d..996775cc 100644 --- a/engine/apps/api/tests/test_user_notification_policy.py +++ b/engine/apps/api/tests/test_user_notification_policy.py @@ -110,7 +110,7 @@ def test_user_cant_create_notification_policy_for_user( @pytest.mark.django_db -def test_create_notification_policy_from_step( +def test_create_notification_policy_order_is_ignored( user_notification_policy_internal_api_setup, make_user_auth_headers, ): @@ -121,7 +121,7 @@ def test_create_notification_policy_from_step( url = reverse("api-internal:notification_policy-list") data = { - "prev_step": wait_notification_step.public_primary_key, + "position": 2023, "step": UserNotificationPolicy.Step.NOTIFY, "notify_by": UserNotificationPolicy.NotificationChannel.SLACK, "wait_delay": None, @@ -130,26 +130,19 @@ def test_create_notification_policy_from_step( } response = client.post(url, data, format="json", **make_user_auth_headers(admin, token)) assert response.status_code == status.HTTP_201_CREATED - assert response.data["order"] == 1 + assert response.data["order"] == 2 @pytest.mark.django_db -def test_create_invalid_notification_policy(user_notification_policy_internal_api_setup, make_user_auth_headers): +def test_move_to_position_position_error(user_notification_policy_internal_api_setup, make_user_auth_headers): token, steps, users = user_notification_policy_internal_api_setup - wait_notification_step, _, _, _ = steps admin, _ = users + step = steps[0] client = APIClient() - url = reverse("api-internal:notification_policy-list") + url = reverse("api-internal:notification_policy-move-to-position", kwargs={"pk": step.public_primary_key}) - data = { - "prev_step": wait_notification_step.public_primary_key, - "step": UserNotificationPolicy.Step.NOTIFY, - "notify_by": UserNotificationPolicy.NotificationChannel.SLACK, - "wait_delay": None, - "important": True, - "user": admin.public_primary_key, - } - response = client.post(url, data, format="json", **make_user_auth_headers(admin, token)) + # position value only can be 0 or 1 for this test setup, because there are only 2 steps + response = client.put(f"{url}?position=2", content_type="application/json", **make_user_auth_headers(admin, token)) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -221,7 +214,7 @@ def test_admin_can_move_user_step(user_notification_policy_internal_api_setup, m "api-internal:notification_policy-move-to-position", kwargs={"pk": second_user_step.public_primary_key} ) - response = client.put(f"{url}?position=1", content_type="application/json", **make_user_auth_headers(admin, token)) + response = client.put(f"{url}?position=0", content_type="application/json", **make_user_auth_headers(admin, token)) assert response.status_code == status.HTTP_200_OK diff --git a/engine/apps/api/views/user_notification_policy.py b/engine/apps/api/views/user_notification_policy.py index e05fc121..a7efb87f 100644 --- a/engine/apps/api/views/user_notification_policy.py +++ b/engine/apps/api/views/user_notification_policy.py @@ -142,7 +142,12 @@ class UserNotificationPolicyView(UpdateSerializerMixin, ModelViewSet): def move_to_position(self, request, pk): instance = self.get_object() position = get_move_to_position_param(request) - instance.to(position) + + try: + instance.to_index(position) + except IndexError: + raise BadRequest(detail="Invalid position") + return Response(status=status.HTTP_200_OK) @action(detail=False, methods=["get"]) diff --git a/engine/apps/base/migrations/0004_auto_20230616_1510.py b/engine/apps/base/migrations/0004_auto_20230616_1510.py new file mode 100644 index 00000000..09878106 --- /dev/null +++ b/engine/apps/base/migrations/0004_auto_20230616_1510.py @@ -0,0 +1,50 @@ +# Generated by Django 3.2.19 on 2023-06-16 15:10 + +from django.db import migrations, models +from django.db.models import Count + +from common.database import get_random_readonly_database_key_if_present_otherwise_default +import django_migration_linter as linter + + +def fix_duplicate_order_user_notification_policy(apps, schema_editor): + UserNotificationPolicy = apps.get_model('base', 'UserNotificationPolicy') + + # it should be safe to use a readonly database because duplicates are pretty infrequent + db = get_random_readonly_database_key_if_present_otherwise_default() + + # find all (user_id, important, order) tuples that have more than one entry (meaning duplicates) + items_with_duplicate_orders = UserNotificationPolicy.objects.using(db).values( + "user_id", "important", "order" + ).annotate(count=Count("order")).order_by().filter(count__gt=1) # use order_by() to reset any existing ordering + + # make sure we don't fix the same (user_id, important) pair more than once + values_to_fix = set((item["user_id"], item["important"]) for item in items_with_duplicate_orders) + + for user_id, important in values_to_fix: + policies = UserNotificationPolicy.objects.filter(user_id=user_id, important=important).order_by("order", "id") + # assign correct sequential order for each policy starting from 0 + for idx, policy in enumerate(policies): + policy.order = idx + UserNotificationPolicy.objects.bulk_update(policies, fields=["order"]) + + +class Migration(migrations.Migration): + + dependencies = [ + ('base', '0003_delete_organizationlogrecord'), + ] + + operations = [ + linter.IgnoreMigration(), # adding a unique constraint after fixing duplicates should be fine + migrations.AlterField( + model_name='usernotificationpolicy', + name='order', + field=models.PositiveIntegerField(db_index=True, editable=False, null=True), + ), + migrations.RunPython(fix_duplicate_order_user_notification_policy, migrations.RunPython.noop), + migrations.AddConstraint( + model_name='usernotificationpolicy', + constraint=models.UniqueConstraint(fields=('user_id', 'important', 'order'), name='unique_user_notification_policy_order'), + ), + ] diff --git a/engine/apps/base/models/ordered_model.py b/engine/apps/base/models/ordered_model.py new file mode 100644 index 00000000..b286520d --- /dev/null +++ b/engine/apps/base/models/ordered_model.py @@ -0,0 +1,272 @@ +import random +import time +import typing +from functools import wraps + +from django.db import IntegrityError, OperationalError, models, transaction + + +def _retry(exc: typing.Type[Exception] | tuple[typing.Type[Exception], ...], max_attempts: int = 5) -> typing.Callable: + """ + A utility decorator for retrying a function on a given exception(s) up to max_attempts times. + """ + + def _retry_with_params(f): + @wraps(f) + def wrapper(*args, **kwargs): + attempts = 0 + while attempts < max_attempts: + try: + return f(*args, **kwargs) + except exc: + if attempts == max_attempts - 1: + raise + attempts += 1 + time.sleep(random.random()) + + return wrapper + + return _retry_with_params + + +Self = typing.TypeVar("Self", bound="OrderedModel") + + +class OrderedModel(models.Model): + """ + This class is intended to be used as a mixin for models that need to be ordered. + It's similar to django-ordered-model: https://github.com/django-ordered-model/django-ordered-model. + The key difference of this implementation is that it allows orders to be unique at the database level and + is designed to work correctly under concurrent load. + + Notable differences compared to django-ordered-model: + - order can be unique at the database level; + - order can temporarily be set to NULL while performing moving operations; + - instance.delete() only deletes the instance, and doesn't shift other instances' orders; + - some methods are not implemented because they're not used in the codebase; + + Example usage: + class Step(OrderedModel): + user = models.ForeignKey(User, on_delete=models.CASCADE) + order_with_respect_to = ["user_id"] # steps are ordered per user + + class Meta: + ordering = ["order"] # to make queryset ordering correct and consistent + unique_together = ["user_id", "order"] # orders are unique per user at the database level + + It's possible for orders to be non-sequential, e.g. order sequence [100, 150, 400] is totally possible and valid. + """ + + order = models.PositiveIntegerField(editable=False, db_index=True, null=True) + order_with_respect_to: list[str] = [] + + class Meta: + abstract = True + ordering = ["order"] + constraints = [ + models.UniqueConstraint(fields=["order"], name="unique_order"), + ] + + def save(self, *args, **kwargs) -> None: + if self.order is None: + self._save_no_order_provided() + else: + super().save() + + @_retry(OperationalError) # retry on deadlock + def delete(self, *args, **kwargs) -> None: + with transaction.atomic(): + # lock ordering queryset to prevent deleting instances that are used by other transactions + self._lock_ordering_queryset() + super().delete(*args, **kwargs) + + @_retry((IntegrityError, OperationalError)) # retry on duplicate order or deadlock + def _save_no_order_provided(self) -> None: + """ + Save self to DB without an order provided (e.g on creation). + Order is set to the next available order, or 0 if there are no other instances. + Example: + a = OrderedModel.objects.create() + b = OrderedModel.objects.create() + c = OrderedModel.objects.create(order=10) + d = OrderedModel.objects.create() + + assert (a.order, b.order, c.order, d.order) == (0, 1, 10, 11) + """ + with transaction.atomic(): + instances = self._lock_ordering_queryset() # lock ordering queryset to prevent reading inconsistent data + max_order = max(instance.order for instance in instances) if instances else -1 + self.order = max_order + 1 + super().save() + + @_retry(OperationalError) # retry on deadlock + def to(self, order: int) -> None: + """ + Move self to a given order, adjusting other instances' orders if necessary. + Example: + a = OrderedModel(order=1) + b = OrderedModel(order=2) + c = OrderedModel(order=3) + + a.to(3) # move the first element to the last order + assert (a.order, b.order, c.order) == (3, 1, 2) # [a, b, c] -> [b, c, a] + """ + self._validate_positive_integer(order) + with transaction.atomic(): + instances = self._lock_ordering_queryset() + self._move_instances_to_order(instances, order) + + @_retry(OperationalError) # retry on deadlock + def to_index(self, index: int) -> None: + """ + Move self to a given index, adjusting other instances' orders if necessary. + Similar with to(), but accepts an index instead of an order. + This might be handy as orders might be non-sequential, but most clients assume that they are sequential. + + Example: + a = OrderedModel(order=1) + b = OrderedModel(order=5) + c = OrderedModel(order=10) + + a.to_index(2) # move the first element to the second index (where c is) + assert (a.order, b.order, c.order) == (10, 4, 9) # [a, b, c] -> [b, c, a] + """ + self._validate_positive_integer(index) + with transaction.atomic(): + instances = self._lock_ordering_queryset() + order = instances[index].order # get order of the instance at the given index + self._move_instances_to_order(instances, order) + + def _move_instances_to_order(self, instances: list[Self], order: int) -> None: + """ + Helper method for moving self to a given order, adjusting other instances' orders if necessary. + Must be called within a transaction that locks the ordering queryset. + """ + + # Get the up-to-date instance from the database, because it might have been updated by another transaction. + try: + _self = next(instance for instance in instances if instance.pk == self.pk) + self.order = _self.order + assert self.order is not None + except StopIteration: + raise self.DoesNotExist() + + # If the order is already correct, do nothing. + if self.order == order: + return + + # Figure out instances that need to be moved and their new orders. + instances_to_move = [] + if self.order < order: + for instance in instances: + if instance.order is not None and self.order < instance.order <= order: + instance.order -= 1 + instances_to_move.append(instance) + else: + for instance in instances: + if instance.order is not None and order <= instance.order < self.order: + instance.order += 1 + instances_to_move.append(instance) + + # If there's nothing to move, just update self.order and return. + if not instances_to_move: + self.order = order + self.save(update_fields=["order"]) + return + + # Temporarily set order values to NULL to avoid unique constraint violations. + pks = [self.pk] + [instance.pk for instance in instances_to_move] + self._manager.filter(pk__in=pks).update(order=None) + + # Update orders to appropriate unique values. + self.order = order + self._manager.filter(pk__in=pks).bulk_update([self] + instances_to_move, fields=["order"]) + + @_retry(OperationalError) # retry on deadlock + def swap(self, order: int) -> None: + """ + Swap self with an instance at a given order. + Example: + a = OrderedModel(order=1) + b = OrderedModel(order=2) + c = OrderedModel(order=3) + d = OrderedModel(order=4) + + a.swap(4) # swap the first element with the last element + assert (a.order, b.order, c.order, d.order) == (4, 2, 3, 1) # [a, b, c, d] -> [d, b, c, a] + """ + self._validate_positive_integer(order) + with transaction.atomic(): + instances = self._lock_ordering_queryset() + + # Get the up-to-date instance from the database, because it might have been updated by another transaction. + try: + _self = next(instance for instance in instances if instance.pk == self.pk) + self.order = _self.order + assert self.order is not None + except StopIteration: + raise self.DoesNotExist() + + # If the order is already correct, do nothing. + if self.order == order: + return + + # Get the instance to swap with. + try: + other = next(instance for instance in instances if instance.order == order) + except StopIteration: + other = None + + # If there's no instance to swap with, just update self.order and return. + if not other: + self.order = order + self.save(update_fields=["order"]) + return + + # Temporarily set order values to NULL to avoid unique constraint violations. + self._manager.filter(pk__in=[self.pk, other.pk]).update(order=None) + + # Swap order values. + self.order, other.order = other.order, self.order + self._manager.filter(pk__in=[self.pk, other.pk]).bulk_update([self, other], fields=["order"]) + + def next(self) -> Self | None: + """ + Return the next instance in the ordering queryset, or None if there's no next instance. + Example: + a = OrderedModel(order=1) + b = OrderedModel(order=2) + + assert a.next() == b + assert b.next() is None + """ + return self._get_ordering_queryset().filter(order__gt=self.order).first() + + def max_order(self) -> int | None: + """ + Return the maximum order value in the ordering queryset or None if there are no instances. + """ + return self._get_ordering_queryset().aggregate(models.Max("order"))["order__max"] + + @staticmethod + def _validate_positive_integer(value: int | None) -> None: + if value is None or not isinstance(value, int) or value < 0: + raise ValueError("Value must be a positive integer.") + + def _get_ordering_queryset(self) -> models.QuerySet[Self]: + return self._manager.filter(**self._ordering_params) + + def _lock_ordering_queryset(self) -> list[Self]: + """ + Locks the ordering queryset with SELECT FOR UPDATE and returns the queryset as a list. + This allows to prevent concurrent updates from different transactions. + """ + return list(self._get_ordering_queryset().select_for_update().only("pk", "order")) + + @property + def _manager(self): + return self._meta.default_manager + + @property + def _ordering_params(self) -> dict[str, typing.Any]: + return {field: getattr(self, field) for field in self.order_with_respect_to} diff --git a/engine/apps/base/models/user_notification_policy.py b/engine/apps/base/models/user_notification_policy.py index b4393916..7e93af63 100644 --- a/engine/apps/base/models/user_notification_policy.py +++ b/engine/apps/base/models/user_notification_policy.py @@ -5,11 +5,11 @@ from typing import Tuple from django.conf import settings from django.core.exceptions import ValidationError from django.core.validators import MinLengthValidator -from django.db import models -from django.db.models import Q, QuerySet -from ordered_model.models import OrderedModel +from django.db import IntegrityError, models +from django.db.models import Q from apps.base.messaging import get_messaging_backends +from apps.base.models.ordered_model import OrderedModel from apps.user_management.models import User from common.exceptions import UserNotificationPolicyCouldNotBeDeleted from common.public_primary_keys import generate_public_primary_key, increase_public_primary_key_length @@ -67,9 +67,11 @@ def validate_channel_choice(value): class UserNotificationPolicyQuerySet(models.QuerySet): - def create_default_policies_for_user(self, user: User) -> "QuerySet[UserNotificationPolicy]": - model = self.model + def create_default_policies_for_user(self, user: User) -> None: + if user.notification_policies.filter(important=False).exists(): + return + model = self.model policies_to_create = ( model( user=user, @@ -81,12 +83,16 @@ class UserNotificationPolicyQuerySet(models.QuerySet): model(user=user, step=model.Step.NOTIFY, notify_by=model.NotificationChannel.PHONE_CALL, order=2), ) - super().bulk_create(policies_to_create) - return user.notification_policies.filter(important=False) + try: + super().bulk_create(policies_to_create) + except IntegrityError: + pass + + def create_important_policies_for_user(self, user: User) -> None: + if user.notification_policies.filter(important=True).exists(): + return - def create_important_policies_for_user(self, user: User) -> "QuerySet[UserNotificationPolicy]": model = self.model - policies_to_create = ( model( user=user, @@ -97,13 +103,15 @@ class UserNotificationPolicyQuerySet(models.QuerySet): ), ) - super().bulk_create(policies_to_create) - return user.notification_policies.filter(important=True) + try: + super().bulk_create(policies_to_create) + except IntegrityError: + pass class UserNotificationPolicy(OrderedModel): objects = UserNotificationPolicyQuerySet.as_manager() - order_with_respect_to = ("user", "important") + order_with_respect_to = ("user_id", "important") public_primary_key = models.CharField( max_length=20, @@ -145,6 +153,11 @@ class UserNotificationPolicy(OrderedModel): class Meta: ordering = ("order",) + constraints = [ + models.UniqueConstraint( + fields=["user_id", "important", "order"], name="unique_user_notification_policy_order" + ) + ] def __str__(self): return f"{self.pk}: {self.short_verbal}" diff --git a/engine/apps/base/tests/test_ordered_model.py b/engine/apps/base/tests/test_ordered_model.py new file mode 100644 index 00000000..b312d238 --- /dev/null +++ b/engine/apps/base/tests/test_ordered_model.py @@ -0,0 +1,404 @@ +import random +import threading + +import pytest +from django.db import models + +from apps.base.models.ordered_model import OrderedModel + + +class TestOrderedModel(OrderedModel): + test_field = models.CharField(max_length=255) + extra_field = models.IntegerField(null=True, default=None) + order_with_respect_to = ["test_field"] + + class Meta: + app_label = "base" + ordering = ["order"] + constraints = [ + models.UniqueConstraint(fields=["test_field", "order"], name="unique_test_field_order"), + ] + + +def _get_ids(): + return list(TestOrderedModel.objects.filter(test_field="test").values_list("id", flat=True)) + + +def _get_orders(): + return list(TestOrderedModel.objects.filter(test_field="test").values_list("order", flat=True)) + + +def _orders_are_sequential(): + orders = _get_orders() + return orders == list(range(len(orders))) + + +@pytest.mark.django_db +def test_ordered_model_create(): + first = TestOrderedModel.objects.create(test_field="test") + second = TestOrderedModel.objects.create(test_field="test") + + assert first.order == 0 + assert second.order == 1 + + +@pytest.mark.django_db +def test_ordered_model_delete(): + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(3)] + + instances[1].delete() + assert instances[1].pk is None + assert _get_ids() == [instances[0].id, instances[2].id] + assert _get_orders() == [0, 2] + + +@pytest.mark.django_db +def test_ordered_model_to(): + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(5)] + + def _ids(indices): + return [instances[i].id for i in indices] + + # move to the end + instances[0].to(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([1, 2, 3, 4, 0]) + assert _orders_are_sequential() + + # move to the beginning + instances[0].to(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + # move to the middle + instances[0].to(2) + assert instances[0].order == 2 + assert _get_ids() == _ids([1, 2, 0, 3, 4]) + assert _orders_are_sequential() + + # move from the middle to the end + instances[0].to(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([1, 2, 3, 4, 0]) + assert _orders_are_sequential() + + # move from the end to the second position + instances[0].to(1) + assert instances[0].order == 1 + assert _get_ids() == _ids([1, 0, 2, 3, 4]) + assert _orders_are_sequential() + + # move from the second position to the beginning + instances[0].to(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + # don't move if the order is the same + for instance in instances: + instance.to(instance.order) + assert instance.order == instance.order + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + +@pytest.mark.django_db +def test_ordered_model_to_index(): + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(5)] + + def _ids(indices): + return [instances[i].id for i in indices] + + # move to the end + instances[0].to_index(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([1, 2, 3, 4, 0]) + assert _orders_are_sequential() + + # move to the beginning + instances[0].to_index(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + # move to the middle + instances[0].to_index(2) + assert instances[0].order == 2 + assert _get_ids() == _ids([1, 2, 0, 3, 4]) + assert _orders_are_sequential() + + # move from the middle to the end + instances[0].to_index(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([1, 2, 3, 4, 0]) + assert _orders_are_sequential() + + # move from the end to the second position + instances[0].to_index(1) + assert instances[0].order == 1 + assert _get_ids() == _ids([1, 0, 2, 3, 4]) + assert _orders_are_sequential() + + # move from the second position to the beginning + instances[0].to_index(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + # don't move if the order is the same + for instance in instances: + instance.to_index(instance.order) + assert instance.order == instance.order + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + +@pytest.mark.django_db +def test_ordered_model_swap(): + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(5)] + + def _ids(indices): + return [instances[i].id for i in indices] + + # swap with last + instances[0].swap(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([4, 1, 2, 3, 0]) + assert _orders_are_sequential() + + # swap with first + instances[0].swap(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + # swap with middle + instances[0].swap(2) + assert instances[0].order == 2 + assert _get_ids() == _ids([2, 1, 0, 3, 4]) + assert _orders_are_sequential() + + # swap from the middle to the end + instances[0].swap(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([2, 1, 4, 3, 0]) + assert _orders_are_sequential() + + # swap from the end to the second position + instances[0].swap(1) + assert instances[0].order == 1 + assert _get_ids() == _ids([2, 0, 4, 3, 1]) + assert _orders_are_sequential() + + # swap from the second position to the beginning + instances[0].swap(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 2, 4, 3, 1]) + assert _orders_are_sequential() + + # swap with itself + for instance in instances: + instance.refresh_from_db(fields=["order"]) + instance.swap(instance.order) + assert instance.order == instance.order + assert _get_ids() == _ids([0, 2, 4, 3, 1]) + assert _orders_are_sequential() + + +@pytest.mark.django_db +def test_order_with_respect_to_isolation(): + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(5)] + other_instances = [TestOrderedModel.objects.create(test_field="test1") for _ in range(5)] + + assert [i.order for i in instances] == [0, 1, 2, 3, 4] + assert [i.order for i in other_instances] == [0, 1, 2, 3, 4] + + assert instances[-1].next() is None + assert instances[-1].max_order() == 4 + + instances[0].to(8) + instances[1].swap(7) + + for idx, instance in enumerate(other_instances): + instance.refresh_from_db() + assert instance.order == idx + + with pytest.raises(IndexError): + instances[0].to_index(6) + + +# Tests below are for checking that concurrent operations are performed correctly. +# They are skipped by default because they might take a lot of time to run. +# It could be useful to run them manually when making changes to the code, making sure +# that the changes don't break concurrent operations. To run the tests, set SKIP_CONCURRENT to False. +SKIP_CONCURRENT = True + + +@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.django_db(transaction=True) +def test_ordered_model_create_concurrent(): + LOOPS = 30 + THREADS = 10 + exceptions = [] + + def create(): + for loop in range(LOOPS): + try: + TestOrderedModel.objects.create(test_field="test") + except Exception as e: + exceptions.append(e) + + threads = [threading.Thread(target=create) for _ in range(THREADS)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert not exceptions + assert TestOrderedModel.objects.count() == LOOPS * THREADS + assert _orders_are_sequential() + + +@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.django_db(transaction=True) +def test_ordered_model_to_concurrent(): + THREADS = 300 + exceptions = [] + + TestOrderedModel.objects.all().delete() # clear table + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(THREADS)] + + random.seed(42) + positions = [random.randint(0, THREADS - 1) for _ in range(THREADS)] + + def to(idx): + try: + instance = instances[idx] + instance.to(positions[idx]) # swap with next + except Exception as e: + exceptions.append(e) + + threads = [threading.Thread(target=to, args=(idx,)) for idx in range(THREADS - 1)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # can only check that orders are still sequential and that there are no exceptions + # can't check the exact order because it changes depending on the order of execution + assert not exceptions + assert _orders_are_sequential() + + +@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.django_db(transaction=True) +def test_ordered_model_swap_concurrent(): + THREADS = 300 + exceptions = [] + + TestOrderedModel.objects.all().delete() # clear table + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(THREADS)] + + # generate random unique orders + random.seed(42) + unique_orders = list(range(THREADS)) + random.shuffle(unique_orders) + + def swap(idx): + try: + instance = instances[idx] + instance.swap(unique_orders[idx]) + except Exception as e: + exceptions.append(e) + + threads = [threading.Thread(target=swap, args=(idx,)) for idx in range(THREADS)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert not exceptions + assert _orders_are_sequential() + + # in case of unique orders, the final order is deterministic + assert list(TestOrderedModel.objects.order_by("id").values_list("order", flat=True)) == unique_orders + + +@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.django_db(transaction=True) +def test_ordered_model_swap_non_unique_orders_concurrent(): + THREADS = 300 + exceptions = [] + + TestOrderedModel.objects.all().delete() # clear table + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(THREADS)] + + # generate random non-unique orders + random.seed(42) + positions = [random.randint(0, THREADS - 1) for _ in range(THREADS)] + + def swap(idx): + try: + instance = instances[idx] + instance.swap(positions[idx]) + except Exception as e: + exceptions.append(e) + + threads = [threading.Thread(target=swap, args=(idx,)) for idx in range(THREADS)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # can only check that orders are still sequential and that there are no exceptions + # can't check the exact order because it changes depending on the order of execution + assert not exceptions + assert _orders_are_sequential() + + +@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.django_db(transaction=True) +def test_ordered_model_create_swap_and_delete_concurrent(): + """Check that create+swap, swap and delete operations are performed correctly when run concurrently.""" + + THREADS = 100 + exceptions = [] + + instances = [TestOrderedModel.objects.create(test_field="test", extra_field=idx) for idx in range(THREADS * 3)] + + def create_swap(idx): + try: + instance = TestOrderedModel.objects.create(test_field="test", extra_field=idx + 1000) + instance.swap(idx) + except Exception as e: + exceptions.append(("create_swap", e)) + + def swap(idx): + try: + instances[idx].swap(idx + 1) + except Exception as e: + exceptions.append(("swap", e)) + + def delete(idx): + try: + instances[idx].delete() + except Exception as e: + exceptions.append(("delete", e)) + + threads = [threading.Thread(target=create_swap, args=(idx,)) for idx in list(range(THREADS))] + threads += [threading.Thread(target=delete, args=(idx,)) for idx in range(THREADS)] + threads += [threading.Thread(target=swap, args=(idx,)) for idx in range(THREADS, THREADS * 2 - 1)] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + expected_extra_field_values = list(range(1000, 1000 + THREADS)) + expected_extra_field_values += [THREADS * 2 - 1] + list(range(THREADS, THREADS * 2 - 1)) + expected_extra_field_values += [instance.extra_field for instance in instances[THREADS * 2 : THREADS * 3]] + + assert not exceptions + assert _orders_are_sequential() + assert list(TestOrderedModel.objects.values_list("extra_field", flat=True)) == expected_extra_field_values diff --git a/engine/apps/public_api/serializers/personal_notification_rules.py b/engine/apps/public_api/serializers/personal_notification_rules.py index 8d915da7..3d0e7df3 100644 --- a/engine/apps/public_api/serializers/personal_notification_rules.py +++ b/engine/apps/public_api/serializers/personal_notification_rules.py @@ -43,14 +43,16 @@ class PersonalNotificationRuleSerializer(EagerLoadingMixin, serializers.ModelSer # that is why step key is used instead of type below if "wait_delay" in validated_data and validated_data["step"] != UserNotificationPolicy.Step.WAIT: raise exceptions.ValidationError({"duration": "Duration can't be set"}) - user = validated_data.pop("user") + + # Remove "manual_order" and "order" fields from validated_data, so they are not passed to create method. + # Policies are always created at the end of the list, and then moved to the desired position by _adjust_order. manual_order = validated_data.pop("manual_order") - if not manual_order: - order = validated_data.pop("order", None) - instance = UserNotificationPolicy.objects.create(**validated_data, user=user) - self._change_position(order, instance) - else: - instance = UserNotificationPolicy.objects.create(**validated_data, user=user) + order = validated_data.pop("order", None) + + instance = UserNotificationPolicy.objects.create(**validated_data) + + if order is not None: + self._adjust_order(instance, manual_order, order, created=True) return instance @@ -117,14 +119,32 @@ class PersonalNotificationRuleSerializer(EagerLoadingMixin, serializers.ModelSer raise exceptions.ValidationError({"type": "Invalid type"}) - def _change_position(self, order, instance): - if order is not None: - if order >= 0: - instance.to(order) - elif order == -1: - instance.bottom() - else: - raise BadRequest(detail="Invalid value for position field") + @staticmethod + def _adjust_order(instance, manual_order, order, created): + # Passing order=-1 means that the policy should be moved to the end of the list. + if order == -1: + if created: + # The policy was just created, so it is already at the end of the list. + return + + order = instance.max_order() + # max_order() can't be None here because at least one instance exists – the one we are moving. + assert order is not None + + # Negative order is not allowed. + if order < 0: + raise BadRequest(detail="Invalid value for position field") + + # manual_order=True is intended for use by Terraform provider only, and is not a documented feature. + # Orders are swapped instead of moved when using Terraform, because Terraform may issue concurrent requests + # to create / update / delete multiple policies. "Move to" operation is not deterministic in this case, and + # final order of policies may be different depending on the order in which requests are processed. On the other + # hand, the result of concurrent "swap" operations is deterministic and does not depend on the order in + # which requests are processed. + if manual_order: + instance.swap(order) + else: + instance.to(order) class PersonalNotificationRuleUpdateSerializer(PersonalNotificationRuleSerializer): @@ -145,10 +165,10 @@ class PersonalNotificationRuleUpdateSerializer(PersonalNotificationRuleSerialize if "wait_delay" in validated_data and instance.step != UserNotificationPolicy.Step.WAIT: raise exceptions.ValidationError({"duration": "Duration can't be set"}) + # Remove "manual_order" and "order" fields from validated_data, so they are not passed to update method. manual_order = validated_data.pop("manual_order") - - if not manual_order: - order = validated_data.pop("order", None) - self._change_position(order, instance) + order = validated_data.pop("order", None) + if order is not None: + self._adjust_order(instance, manual_order, order, created=False) return super().update(instance, validated_data) diff --git a/engine/tox.ini b/engine/tox.ini index 7cabc843..330b6eb9 100644 --- a/engine/tox.ini +++ b/engine/tox.ini @@ -9,6 +9,6 @@ banned-modules = [pytest] # https://pytest-django.readthedocs.io/en/latest/configuring_django.html#order-of-choosing-settings # https://pytest-django.readthedocs.io/en/latest/database.html -addopts = --color=yes --showlocals +addopts = --no-migrations --color=yes --showlocals # https://pytest-django.readthedocs.io/en/latest/faq.html#my-tests-are-not-being-found-why python_files = tests.py test_*.py *_tests.py