Fix duplicate orders for user notification policies (#2278)

# What this PR does

Fixes an issue when multiple user notification policies have duplicated
order values, leading to the following unexpected behaviours:
1. Not possible to rearrange notification policies that have duplicated
orders.
2. The notification system only executes the first policy from each
order group. For example, if there are policies with orders `[0, 0, 0,
0]`, only the first policy will be executed, and all others will be
skipped. So the user will see four policies in the UI, while only one of
them will be actually executed.

This PR fixes the issue by adding a unique index on `(user_id,
important, order)` for `UserNotificationPolicy` model. However, it's not
possible to add that unique index using the ordering library that we use
due to it's implementation details.
I added a new abstract Django model `OrderedModel` that's able to work
with such unique indices + under concurrent load.

Important info on this new `OrderedModel` abstract model:
- Orders are unique on the DB level
- Orders are allowed to be non-consecutive, for example order sequence
`[100, 150, 400]` is valid
- When deleting an instance, orders of other instances don't change.
This is a notable difference from the library we use. I think it's
better to only delete the instance without changing any other orders,
because it reduces the number of dependencies between instances (e.g.
Terraform drift will be much smaller this way if a policy is deleted via
the web UI).

## Which issue(s) this PR fixes

Related to https://github.com/grafana/oncall-private/issues/1680

## Checklist

- [x] Unit, integration, and e2e (if applicable) tests updated
- [x] Documentation added (or `pr:no public docs` PR label added if not
required)
- [x] `CHANGELOG.md` updated (or `pr:no changelog` PR label added if not
required)
This commit is contained in:
Vadim Stepanov 2023-06-21 12:13:56 +01:00 committed by GitHub
parent 0c46b41498
commit eada4a4355
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 816 additions and 75 deletions

View file

@ -22,6 +22,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Change mobile shift notifications title and subtitle by @imtoori ([#2288](https://github.com/grafana/oncall/pull/2288))
## Fixed
- Fix duplicate orders for user notification policies by @vadimkerr ([#2278](https://github.com/grafana/oncall/pull/2278))
## v1.2.45 (2023-06-19)
### Changed

View file

@ -208,7 +208,7 @@ services:
container_name: mysql
labels: *oncall-labels
image: mysql:8.0.32
command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci --max_connections=1024
restart: always
environment:
MYSQL_ROOT_PASSWORD: empty

View file

@ -7,7 +7,7 @@ from apps.base.models import UserNotificationPolicy
from apps.base.models.user_notification_policy import NotificationChannelAPIOptions
from apps.user_management.models import User
from common.api_helpers.custom_fields import OrganizationFilteredPrimaryKeyRelatedField
from common.api_helpers.exceptions import BadRequest, Forbidden
from common.api_helpers.exceptions import Forbidden
from common.api_helpers.mixins import EagerLoadingMixin
@ -34,6 +34,7 @@ class UserNotificationPolicyBaseSerializer(EagerLoadingMixin, serializers.ModelS
class Meta:
model = UserNotificationPolicy
fields = ["id", "step", "order", "notify_by", "wait_delay", "important", "user"]
read_only_fields = ["order"]
def to_internal_value(self, data):
if data.get("wait_delay", None):
@ -67,7 +68,6 @@ class UserNotificationPolicyBaseSerializer(EagerLoadingMixin, serializers.ModelS
class UserNotificationPolicySerializer(UserNotificationPolicyBaseSerializer):
prev_step = serializers.CharField(required=False, write_only=True, allow_null=True)
user = OrganizationFilteredPrimaryKeyRelatedField(
queryset=User.objects,
required=False,
@ -80,36 +80,16 @@ class UserNotificationPolicySerializer(UserNotificationPolicyBaseSerializer):
default=NotificationChannelAPIOptions.DEFAULT_NOTIFICATION_CHANNEL,
)
class Meta(UserNotificationPolicyBaseSerializer.Meta):
fields = [*UserNotificationPolicyBaseSerializer.Meta.fields, "prev_step"]
read_only_fields = ("order",)
def create(self, validated_data):
prev_step = validated_data.pop("prev_step", None)
user = validated_data.get("user")
user = validated_data.get("user") or self.context["request"].user
organization = self.context["request"].auth.organization
if not user:
user = self.context["request"].user
self_or_admin = user.self_or_admin(user_to_check=self.context["request"].user, organization=organization)
if not self_or_admin:
raise Forbidden()
if prev_step is not None:
try:
prev_step = UserNotificationPolicy.objects.get(public_primary_key=prev_step)
except UserNotificationPolicy.DoesNotExist:
raise BadRequest(detail="Prev step does not exist")
if prev_step.user != user or prev_step.important != validated_data.get("important", False):
raise BadRequest(detail="UserNotificationPolicy can be created only with the same user and importance")
instance = UserNotificationPolicy.objects.create(**validated_data)
instance.to(prev_step.order + 1)
return instance
else:
instance = UserNotificationPolicy.objects.create(**validated_data)
return instance
instance = UserNotificationPolicy.objects.create(**validated_data)
return instance
class UserNotificationPolicyUpdateSerializer(UserNotificationPolicyBaseSerializer):

View file

@ -110,7 +110,7 @@ def test_user_cant_create_notification_policy_for_user(
@pytest.mark.django_db
def test_create_notification_policy_from_step(
def test_create_notification_policy_order_is_ignored(
user_notification_policy_internal_api_setup,
make_user_auth_headers,
):
@ -121,7 +121,7 @@ def test_create_notification_policy_from_step(
url = reverse("api-internal:notification_policy-list")
data = {
"prev_step": wait_notification_step.public_primary_key,
"position": 2023,
"step": UserNotificationPolicy.Step.NOTIFY,
"notify_by": UserNotificationPolicy.NotificationChannel.SLACK,
"wait_delay": None,
@ -130,26 +130,19 @@ def test_create_notification_policy_from_step(
}
response = client.post(url, data, format="json", **make_user_auth_headers(admin, token))
assert response.status_code == status.HTTP_201_CREATED
assert response.data["order"] == 1
assert response.data["order"] == 2
@pytest.mark.django_db
def test_create_invalid_notification_policy(user_notification_policy_internal_api_setup, make_user_auth_headers):
def test_move_to_position_position_error(user_notification_policy_internal_api_setup, make_user_auth_headers):
token, steps, users = user_notification_policy_internal_api_setup
wait_notification_step, _, _, _ = steps
admin, _ = users
step = steps[0]
client = APIClient()
url = reverse("api-internal:notification_policy-list")
url = reverse("api-internal:notification_policy-move-to-position", kwargs={"pk": step.public_primary_key})
data = {
"prev_step": wait_notification_step.public_primary_key,
"step": UserNotificationPolicy.Step.NOTIFY,
"notify_by": UserNotificationPolicy.NotificationChannel.SLACK,
"wait_delay": None,
"important": True,
"user": admin.public_primary_key,
}
response = client.post(url, data, format="json", **make_user_auth_headers(admin, token))
# position value only can be 0 or 1 for this test setup, because there are only 2 steps
response = client.put(f"{url}?position=2", content_type="application/json", **make_user_auth_headers(admin, token))
assert response.status_code == status.HTTP_400_BAD_REQUEST
@ -221,7 +214,7 @@ def test_admin_can_move_user_step(user_notification_policy_internal_api_setup, m
"api-internal:notification_policy-move-to-position", kwargs={"pk": second_user_step.public_primary_key}
)
response = client.put(f"{url}?position=1", content_type="application/json", **make_user_auth_headers(admin, token))
response = client.put(f"{url}?position=0", content_type="application/json", **make_user_auth_headers(admin, token))
assert response.status_code == status.HTTP_200_OK

View file

@ -142,7 +142,12 @@ class UserNotificationPolicyView(UpdateSerializerMixin, ModelViewSet):
def move_to_position(self, request, pk):
instance = self.get_object()
position = get_move_to_position_param(request)
instance.to(position)
try:
instance.to_index(position)
except IndexError:
raise BadRequest(detail="Invalid position")
return Response(status=status.HTTP_200_OK)
@action(detail=False, methods=["get"])

View file

@ -0,0 +1,50 @@
# Generated by Django 3.2.19 on 2023-06-16 15:10
from django.db import migrations, models
from django.db.models import Count
from common.database import get_random_readonly_database_key_if_present_otherwise_default
import django_migration_linter as linter
def fix_duplicate_order_user_notification_policy(apps, schema_editor):
UserNotificationPolicy = apps.get_model('base', 'UserNotificationPolicy')
# it should be safe to use a readonly database because duplicates are pretty infrequent
db = get_random_readonly_database_key_if_present_otherwise_default()
# find all (user_id, important, order) tuples that have more than one entry (meaning duplicates)
items_with_duplicate_orders = UserNotificationPolicy.objects.using(db).values(
"user_id", "important", "order"
).annotate(count=Count("order")).order_by().filter(count__gt=1) # use order_by() to reset any existing ordering
# make sure we don't fix the same (user_id, important) pair more than once
values_to_fix = set((item["user_id"], item["important"]) for item in items_with_duplicate_orders)
for user_id, important in values_to_fix:
policies = UserNotificationPolicy.objects.filter(user_id=user_id, important=important).order_by("order", "id")
# assign correct sequential order for each policy starting from 0
for idx, policy in enumerate(policies):
policy.order = idx
UserNotificationPolicy.objects.bulk_update(policies, fields=["order"])
class Migration(migrations.Migration):
dependencies = [
('base', '0003_delete_organizationlogrecord'),
]
operations = [
linter.IgnoreMigration(), # adding a unique constraint after fixing duplicates should be fine
migrations.AlterField(
model_name='usernotificationpolicy',
name='order',
field=models.PositiveIntegerField(db_index=True, editable=False, null=True),
),
migrations.RunPython(fix_duplicate_order_user_notification_policy, migrations.RunPython.noop),
migrations.AddConstraint(
model_name='usernotificationpolicy',
constraint=models.UniqueConstraint(fields=('user_id', 'important', 'order'), name='unique_user_notification_policy_order'),
),
]

View file

@ -0,0 +1,272 @@
import random
import time
import typing
from functools import wraps
from django.db import IntegrityError, OperationalError, models, transaction
def _retry(exc: typing.Type[Exception] | tuple[typing.Type[Exception], ...], max_attempts: int = 5) -> typing.Callable:
"""
A utility decorator for retrying a function on a given exception(s) up to max_attempts times.
"""
def _retry_with_params(f):
@wraps(f)
def wrapper(*args, **kwargs):
attempts = 0
while attempts < max_attempts:
try:
return f(*args, **kwargs)
except exc:
if attempts == max_attempts - 1:
raise
attempts += 1
time.sleep(random.random())
return wrapper
return _retry_with_params
Self = typing.TypeVar("Self", bound="OrderedModel")
class OrderedModel(models.Model):
"""
This class is intended to be used as a mixin for models that need to be ordered.
It's similar to django-ordered-model: https://github.com/django-ordered-model/django-ordered-model.
The key difference of this implementation is that it allows orders to be unique at the database level and
is designed to work correctly under concurrent load.
Notable differences compared to django-ordered-model:
- order can be unique at the database level;
- order can temporarily be set to NULL while performing moving operations;
- instance.delete() only deletes the instance, and doesn't shift other instances' orders;
- some methods are not implemented because they're not used in the codebase;
Example usage:
class Step(OrderedModel):
user = models.ForeignKey(User, on_delete=models.CASCADE)
order_with_respect_to = ["user_id"] # steps are ordered per user
class Meta:
ordering = ["order"] # to make queryset ordering correct and consistent
unique_together = ["user_id", "order"] # orders are unique per user at the database level
It's possible for orders to be non-sequential, e.g. order sequence [100, 150, 400] is totally possible and valid.
"""
order = models.PositiveIntegerField(editable=False, db_index=True, null=True)
order_with_respect_to: list[str] = []
class Meta:
abstract = True
ordering = ["order"]
constraints = [
models.UniqueConstraint(fields=["order"], name="unique_order"),
]
def save(self, *args, **kwargs) -> None:
if self.order is None:
self._save_no_order_provided()
else:
super().save()
@_retry(OperationalError) # retry on deadlock
def delete(self, *args, **kwargs) -> None:
with transaction.atomic():
# lock ordering queryset to prevent deleting instances that are used by other transactions
self._lock_ordering_queryset()
super().delete(*args, **kwargs)
@_retry((IntegrityError, OperationalError)) # retry on duplicate order or deadlock
def _save_no_order_provided(self) -> None:
"""
Save self to DB without an order provided (e.g on creation).
Order is set to the next available order, or 0 if there are no other instances.
Example:
a = OrderedModel.objects.create()
b = OrderedModel.objects.create()
c = OrderedModel.objects.create(order=10)
d = OrderedModel.objects.create()
assert (a.order, b.order, c.order, d.order) == (0, 1, 10, 11)
"""
with transaction.atomic():
instances = self._lock_ordering_queryset() # lock ordering queryset to prevent reading inconsistent data
max_order = max(instance.order for instance in instances) if instances else -1
self.order = max_order + 1
super().save()
@_retry(OperationalError) # retry on deadlock
def to(self, order: int) -> None:
"""
Move self to a given order, adjusting other instances' orders if necessary.
Example:
a = OrderedModel(order=1)
b = OrderedModel(order=2)
c = OrderedModel(order=3)
a.to(3) # move the first element to the last order
assert (a.order, b.order, c.order) == (3, 1, 2) # [a, b, c] -> [b, c, a]
"""
self._validate_positive_integer(order)
with transaction.atomic():
instances = self._lock_ordering_queryset()
self._move_instances_to_order(instances, order)
@_retry(OperationalError) # retry on deadlock
def to_index(self, index: int) -> None:
"""
Move self to a given index, adjusting other instances' orders if necessary.
Similar with to(), but accepts an index instead of an order.
This might be handy as orders might be non-sequential, but most clients assume that they are sequential.
Example:
a = OrderedModel(order=1)
b = OrderedModel(order=5)
c = OrderedModel(order=10)
a.to_index(2) # move the first element to the second index (where c is)
assert (a.order, b.order, c.order) == (10, 4, 9) # [a, b, c] -> [b, c, a]
"""
self._validate_positive_integer(index)
with transaction.atomic():
instances = self._lock_ordering_queryset()
order = instances[index].order # get order of the instance at the given index
self._move_instances_to_order(instances, order)
def _move_instances_to_order(self, instances: list[Self], order: int) -> None:
"""
Helper method for moving self to a given order, adjusting other instances' orders if necessary.
Must be called within a transaction that locks the ordering queryset.
"""
# Get the up-to-date instance from the database, because it might have been updated by another transaction.
try:
_self = next(instance for instance in instances if instance.pk == self.pk)
self.order = _self.order
assert self.order is not None
except StopIteration:
raise self.DoesNotExist()
# If the order is already correct, do nothing.
if self.order == order:
return
# Figure out instances that need to be moved and their new orders.
instances_to_move = []
if self.order < order:
for instance in instances:
if instance.order is not None and self.order < instance.order <= order:
instance.order -= 1
instances_to_move.append(instance)
else:
for instance in instances:
if instance.order is not None and order <= instance.order < self.order:
instance.order += 1
instances_to_move.append(instance)
# If there's nothing to move, just update self.order and return.
if not instances_to_move:
self.order = order
self.save(update_fields=["order"])
return
# Temporarily set order values to NULL to avoid unique constraint violations.
pks = [self.pk] + [instance.pk for instance in instances_to_move]
self._manager.filter(pk__in=pks).update(order=None)
# Update orders to appropriate unique values.
self.order = order
self._manager.filter(pk__in=pks).bulk_update([self] + instances_to_move, fields=["order"])
@_retry(OperationalError) # retry on deadlock
def swap(self, order: int) -> None:
"""
Swap self with an instance at a given order.
Example:
a = OrderedModel(order=1)
b = OrderedModel(order=2)
c = OrderedModel(order=3)
d = OrderedModel(order=4)
a.swap(4) # swap the first element with the last element
assert (a.order, b.order, c.order, d.order) == (4, 2, 3, 1) # [a, b, c, d] -> [d, b, c, a]
"""
self._validate_positive_integer(order)
with transaction.atomic():
instances = self._lock_ordering_queryset()
# Get the up-to-date instance from the database, because it might have been updated by another transaction.
try:
_self = next(instance for instance in instances if instance.pk == self.pk)
self.order = _self.order
assert self.order is not None
except StopIteration:
raise self.DoesNotExist()
# If the order is already correct, do nothing.
if self.order == order:
return
# Get the instance to swap with.
try:
other = next(instance for instance in instances if instance.order == order)
except StopIteration:
other = None
# If there's no instance to swap with, just update self.order and return.
if not other:
self.order = order
self.save(update_fields=["order"])
return
# Temporarily set order values to NULL to avoid unique constraint violations.
self._manager.filter(pk__in=[self.pk, other.pk]).update(order=None)
# Swap order values.
self.order, other.order = other.order, self.order
self._manager.filter(pk__in=[self.pk, other.pk]).bulk_update([self, other], fields=["order"])
def next(self) -> Self | None:
"""
Return the next instance in the ordering queryset, or None if there's no next instance.
Example:
a = OrderedModel(order=1)
b = OrderedModel(order=2)
assert a.next() == b
assert b.next() is None
"""
return self._get_ordering_queryset().filter(order__gt=self.order).first()
def max_order(self) -> int | None:
"""
Return the maximum order value in the ordering queryset or None if there are no instances.
"""
return self._get_ordering_queryset().aggregate(models.Max("order"))["order__max"]
@staticmethod
def _validate_positive_integer(value: int | None) -> None:
if value is None or not isinstance(value, int) or value < 0:
raise ValueError("Value must be a positive integer.")
def _get_ordering_queryset(self) -> models.QuerySet[Self]:
return self._manager.filter(**self._ordering_params)
def _lock_ordering_queryset(self) -> list[Self]:
"""
Locks the ordering queryset with SELECT FOR UPDATE and returns the queryset as a list.
This allows to prevent concurrent updates from different transactions.
"""
return list(self._get_ordering_queryset().select_for_update().only("pk", "order"))
@property
def _manager(self):
return self._meta.default_manager
@property
def _ordering_params(self) -> dict[str, typing.Any]:
return {field: getattr(self, field) for field in self.order_with_respect_to}

View file

@ -5,11 +5,11 @@ from typing import Tuple
from django.conf import settings
from django.core.exceptions import ValidationError
from django.core.validators import MinLengthValidator
from django.db import models
from django.db.models import Q, QuerySet
from ordered_model.models import OrderedModel
from django.db import IntegrityError, models
from django.db.models import Q
from apps.base.messaging import get_messaging_backends
from apps.base.models.ordered_model import OrderedModel
from apps.user_management.models import User
from common.exceptions import UserNotificationPolicyCouldNotBeDeleted
from common.public_primary_keys import generate_public_primary_key, increase_public_primary_key_length
@ -67,9 +67,11 @@ def validate_channel_choice(value):
class UserNotificationPolicyQuerySet(models.QuerySet):
def create_default_policies_for_user(self, user: User) -> "QuerySet[UserNotificationPolicy]":
model = self.model
def create_default_policies_for_user(self, user: User) -> None:
if user.notification_policies.filter(important=False).exists():
return
model = self.model
policies_to_create = (
model(
user=user,
@ -81,12 +83,16 @@ class UserNotificationPolicyQuerySet(models.QuerySet):
model(user=user, step=model.Step.NOTIFY, notify_by=model.NotificationChannel.PHONE_CALL, order=2),
)
super().bulk_create(policies_to_create)
return user.notification_policies.filter(important=False)
try:
super().bulk_create(policies_to_create)
except IntegrityError:
pass
def create_important_policies_for_user(self, user: User) -> None:
if user.notification_policies.filter(important=True).exists():
return
def create_important_policies_for_user(self, user: User) -> "QuerySet[UserNotificationPolicy]":
model = self.model
policies_to_create = (
model(
user=user,
@ -97,13 +103,15 @@ class UserNotificationPolicyQuerySet(models.QuerySet):
),
)
super().bulk_create(policies_to_create)
return user.notification_policies.filter(important=True)
try:
super().bulk_create(policies_to_create)
except IntegrityError:
pass
class UserNotificationPolicy(OrderedModel):
objects = UserNotificationPolicyQuerySet.as_manager()
order_with_respect_to = ("user", "important")
order_with_respect_to = ("user_id", "important")
public_primary_key = models.CharField(
max_length=20,
@ -145,6 +153,11 @@ class UserNotificationPolicy(OrderedModel):
class Meta:
ordering = ("order",)
constraints = [
models.UniqueConstraint(
fields=["user_id", "important", "order"], name="unique_user_notification_policy_order"
)
]
def __str__(self):
return f"{self.pk}: {self.short_verbal}"

View file

@ -0,0 +1,404 @@
import random
import threading
import pytest
from django.db import models
from apps.base.models.ordered_model import OrderedModel
class TestOrderedModel(OrderedModel):
test_field = models.CharField(max_length=255)
extra_field = models.IntegerField(null=True, default=None)
order_with_respect_to = ["test_field"]
class Meta:
app_label = "base"
ordering = ["order"]
constraints = [
models.UniqueConstraint(fields=["test_field", "order"], name="unique_test_field_order"),
]
def _get_ids():
return list(TestOrderedModel.objects.filter(test_field="test").values_list("id", flat=True))
def _get_orders():
return list(TestOrderedModel.objects.filter(test_field="test").values_list("order", flat=True))
def _orders_are_sequential():
orders = _get_orders()
return orders == list(range(len(orders)))
@pytest.mark.django_db
def test_ordered_model_create():
first = TestOrderedModel.objects.create(test_field="test")
second = TestOrderedModel.objects.create(test_field="test")
assert first.order == 0
assert second.order == 1
@pytest.mark.django_db
def test_ordered_model_delete():
instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(3)]
instances[1].delete()
assert instances[1].pk is None
assert _get_ids() == [instances[0].id, instances[2].id]
assert _get_orders() == [0, 2]
@pytest.mark.django_db
def test_ordered_model_to():
instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(5)]
def _ids(indices):
return [instances[i].id for i in indices]
# move to the end
instances[0].to(4)
assert instances[0].order == 4
assert _get_ids() == _ids([1, 2, 3, 4, 0])
assert _orders_are_sequential()
# move to the beginning
instances[0].to(0)
assert instances[0].order == 0
assert _get_ids() == _ids([0, 1, 2, 3, 4])
assert _orders_are_sequential()
# move to the middle
instances[0].to(2)
assert instances[0].order == 2
assert _get_ids() == _ids([1, 2, 0, 3, 4])
assert _orders_are_sequential()
# move from the middle to the end
instances[0].to(4)
assert instances[0].order == 4
assert _get_ids() == _ids([1, 2, 3, 4, 0])
assert _orders_are_sequential()
# move from the end to the second position
instances[0].to(1)
assert instances[0].order == 1
assert _get_ids() == _ids([1, 0, 2, 3, 4])
assert _orders_are_sequential()
# move from the second position to the beginning
instances[0].to(0)
assert instances[0].order == 0
assert _get_ids() == _ids([0, 1, 2, 3, 4])
assert _orders_are_sequential()
# don't move if the order is the same
for instance in instances:
instance.to(instance.order)
assert instance.order == instance.order
assert _get_ids() == _ids([0, 1, 2, 3, 4])
assert _orders_are_sequential()
@pytest.mark.django_db
def test_ordered_model_to_index():
instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(5)]
def _ids(indices):
return [instances[i].id for i in indices]
# move to the end
instances[0].to_index(4)
assert instances[0].order == 4
assert _get_ids() == _ids([1, 2, 3, 4, 0])
assert _orders_are_sequential()
# move to the beginning
instances[0].to_index(0)
assert instances[0].order == 0
assert _get_ids() == _ids([0, 1, 2, 3, 4])
assert _orders_are_sequential()
# move to the middle
instances[0].to_index(2)
assert instances[0].order == 2
assert _get_ids() == _ids([1, 2, 0, 3, 4])
assert _orders_are_sequential()
# move from the middle to the end
instances[0].to_index(4)
assert instances[0].order == 4
assert _get_ids() == _ids([1, 2, 3, 4, 0])
assert _orders_are_sequential()
# move from the end to the second position
instances[0].to_index(1)
assert instances[0].order == 1
assert _get_ids() == _ids([1, 0, 2, 3, 4])
assert _orders_are_sequential()
# move from the second position to the beginning
instances[0].to_index(0)
assert instances[0].order == 0
assert _get_ids() == _ids([0, 1, 2, 3, 4])
assert _orders_are_sequential()
# don't move if the order is the same
for instance in instances:
instance.to_index(instance.order)
assert instance.order == instance.order
assert _get_ids() == _ids([0, 1, 2, 3, 4])
assert _orders_are_sequential()
@pytest.mark.django_db
def test_ordered_model_swap():
instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(5)]
def _ids(indices):
return [instances[i].id for i in indices]
# swap with last
instances[0].swap(4)
assert instances[0].order == 4
assert _get_ids() == _ids([4, 1, 2, 3, 0])
assert _orders_are_sequential()
# swap with first
instances[0].swap(0)
assert instances[0].order == 0
assert _get_ids() == _ids([0, 1, 2, 3, 4])
assert _orders_are_sequential()
# swap with middle
instances[0].swap(2)
assert instances[0].order == 2
assert _get_ids() == _ids([2, 1, 0, 3, 4])
assert _orders_are_sequential()
# swap from the middle to the end
instances[0].swap(4)
assert instances[0].order == 4
assert _get_ids() == _ids([2, 1, 4, 3, 0])
assert _orders_are_sequential()
# swap from the end to the second position
instances[0].swap(1)
assert instances[0].order == 1
assert _get_ids() == _ids([2, 0, 4, 3, 1])
assert _orders_are_sequential()
# swap from the second position to the beginning
instances[0].swap(0)
assert instances[0].order == 0
assert _get_ids() == _ids([0, 2, 4, 3, 1])
assert _orders_are_sequential()
# swap with itself
for instance in instances:
instance.refresh_from_db(fields=["order"])
instance.swap(instance.order)
assert instance.order == instance.order
assert _get_ids() == _ids([0, 2, 4, 3, 1])
assert _orders_are_sequential()
@pytest.mark.django_db
def test_order_with_respect_to_isolation():
instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(5)]
other_instances = [TestOrderedModel.objects.create(test_field="test1") for _ in range(5)]
assert [i.order for i in instances] == [0, 1, 2, 3, 4]
assert [i.order for i in other_instances] == [0, 1, 2, 3, 4]
assert instances[-1].next() is None
assert instances[-1].max_order() == 4
instances[0].to(8)
instances[1].swap(7)
for idx, instance in enumerate(other_instances):
instance.refresh_from_db()
assert instance.order == idx
with pytest.raises(IndexError):
instances[0].to_index(6)
# Tests below are for checking that concurrent operations are performed correctly.
# They are skipped by default because they might take a lot of time to run.
# It could be useful to run them manually when making changes to the code, making sure
# that the changes don't break concurrent operations. To run the tests, set SKIP_CONCURRENT to False.
SKIP_CONCURRENT = True
@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests")
@pytest.mark.django_db(transaction=True)
def test_ordered_model_create_concurrent():
LOOPS = 30
THREADS = 10
exceptions = []
def create():
for loop in range(LOOPS):
try:
TestOrderedModel.objects.create(test_field="test")
except Exception as e:
exceptions.append(e)
threads = [threading.Thread(target=create) for _ in range(THREADS)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert not exceptions
assert TestOrderedModel.objects.count() == LOOPS * THREADS
assert _orders_are_sequential()
@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests")
@pytest.mark.django_db(transaction=True)
def test_ordered_model_to_concurrent():
THREADS = 300
exceptions = []
TestOrderedModel.objects.all().delete() # clear table
instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(THREADS)]
random.seed(42)
positions = [random.randint(0, THREADS - 1) for _ in range(THREADS)]
def to(idx):
try:
instance = instances[idx]
instance.to(positions[idx]) # swap with next
except Exception as e:
exceptions.append(e)
threads = [threading.Thread(target=to, args=(idx,)) for idx in range(THREADS - 1)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
# can only check that orders are still sequential and that there are no exceptions
# can't check the exact order because it changes depending on the order of execution
assert not exceptions
assert _orders_are_sequential()
@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests")
@pytest.mark.django_db(transaction=True)
def test_ordered_model_swap_concurrent():
THREADS = 300
exceptions = []
TestOrderedModel.objects.all().delete() # clear table
instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(THREADS)]
# generate random unique orders
random.seed(42)
unique_orders = list(range(THREADS))
random.shuffle(unique_orders)
def swap(idx):
try:
instance = instances[idx]
instance.swap(unique_orders[idx])
except Exception as e:
exceptions.append(e)
threads = [threading.Thread(target=swap, args=(idx,)) for idx in range(THREADS)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert not exceptions
assert _orders_are_sequential()
# in case of unique orders, the final order is deterministic
assert list(TestOrderedModel.objects.order_by("id").values_list("order", flat=True)) == unique_orders
@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests")
@pytest.mark.django_db(transaction=True)
def test_ordered_model_swap_non_unique_orders_concurrent():
THREADS = 300
exceptions = []
TestOrderedModel.objects.all().delete() # clear table
instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(THREADS)]
# generate random non-unique orders
random.seed(42)
positions = [random.randint(0, THREADS - 1) for _ in range(THREADS)]
def swap(idx):
try:
instance = instances[idx]
instance.swap(positions[idx])
except Exception as e:
exceptions.append(e)
threads = [threading.Thread(target=swap, args=(idx,)) for idx in range(THREADS)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
# can only check that orders are still sequential and that there are no exceptions
# can't check the exact order because it changes depending on the order of execution
assert not exceptions
assert _orders_are_sequential()
@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests")
@pytest.mark.django_db(transaction=True)
def test_ordered_model_create_swap_and_delete_concurrent():
"""Check that create+swap, swap and delete operations are performed correctly when run concurrently."""
THREADS = 100
exceptions = []
instances = [TestOrderedModel.objects.create(test_field="test", extra_field=idx) for idx in range(THREADS * 3)]
def create_swap(idx):
try:
instance = TestOrderedModel.objects.create(test_field="test", extra_field=idx + 1000)
instance.swap(idx)
except Exception as e:
exceptions.append(("create_swap", e))
def swap(idx):
try:
instances[idx].swap(idx + 1)
except Exception as e:
exceptions.append(("swap", e))
def delete(idx):
try:
instances[idx].delete()
except Exception as e:
exceptions.append(("delete", e))
threads = [threading.Thread(target=create_swap, args=(idx,)) for idx in list(range(THREADS))]
threads += [threading.Thread(target=delete, args=(idx,)) for idx in range(THREADS)]
threads += [threading.Thread(target=swap, args=(idx,)) for idx in range(THREADS, THREADS * 2 - 1)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
expected_extra_field_values = list(range(1000, 1000 + THREADS))
expected_extra_field_values += [THREADS * 2 - 1] + list(range(THREADS, THREADS * 2 - 1))
expected_extra_field_values += [instance.extra_field for instance in instances[THREADS * 2 : THREADS * 3]]
assert not exceptions
assert _orders_are_sequential()
assert list(TestOrderedModel.objects.values_list("extra_field", flat=True)) == expected_extra_field_values

View file

@ -43,14 +43,16 @@ class PersonalNotificationRuleSerializer(EagerLoadingMixin, serializers.ModelSer
# that is why step key is used instead of type below
if "wait_delay" in validated_data and validated_data["step"] != UserNotificationPolicy.Step.WAIT:
raise exceptions.ValidationError({"duration": "Duration can't be set"})
user = validated_data.pop("user")
# Remove "manual_order" and "order" fields from validated_data, so they are not passed to create method.
# Policies are always created at the end of the list, and then moved to the desired position by _adjust_order.
manual_order = validated_data.pop("manual_order")
if not manual_order:
order = validated_data.pop("order", None)
instance = UserNotificationPolicy.objects.create(**validated_data, user=user)
self._change_position(order, instance)
else:
instance = UserNotificationPolicy.objects.create(**validated_data, user=user)
order = validated_data.pop("order", None)
instance = UserNotificationPolicy.objects.create(**validated_data)
if order is not None:
self._adjust_order(instance, manual_order, order, created=True)
return instance
@ -117,14 +119,32 @@ class PersonalNotificationRuleSerializer(EagerLoadingMixin, serializers.ModelSer
raise exceptions.ValidationError({"type": "Invalid type"})
def _change_position(self, order, instance):
if order is not None:
if order >= 0:
instance.to(order)
elif order == -1:
instance.bottom()
else:
raise BadRequest(detail="Invalid value for position field")
@staticmethod
def _adjust_order(instance, manual_order, order, created):
# Passing order=-1 means that the policy should be moved to the end of the list.
if order == -1:
if created:
# The policy was just created, so it is already at the end of the list.
return
order = instance.max_order()
# max_order() can't be None here because at least one instance exists the one we are moving.
assert order is not None
# Negative order is not allowed.
if order < 0:
raise BadRequest(detail="Invalid value for position field")
# manual_order=True is intended for use by Terraform provider only, and is not a documented feature.
# Orders are swapped instead of moved when using Terraform, because Terraform may issue concurrent requests
# to create / update / delete multiple policies. "Move to" operation is not deterministic in this case, and
# final order of policies may be different depending on the order in which requests are processed. On the other
# hand, the result of concurrent "swap" operations is deterministic and does not depend on the order in
# which requests are processed.
if manual_order:
instance.swap(order)
else:
instance.to(order)
class PersonalNotificationRuleUpdateSerializer(PersonalNotificationRuleSerializer):
@ -145,10 +165,10 @@ class PersonalNotificationRuleUpdateSerializer(PersonalNotificationRuleSerialize
if "wait_delay" in validated_data and instance.step != UserNotificationPolicy.Step.WAIT:
raise exceptions.ValidationError({"duration": "Duration can't be set"})
# Remove "manual_order" and "order" fields from validated_data, so they are not passed to update method.
manual_order = validated_data.pop("manual_order")
if not manual_order:
order = validated_data.pop("order", None)
self._change_position(order, instance)
order = validated_data.pop("order", None)
if order is not None:
self._adjust_order(instance, manual_order, order, created=False)
return super().update(instance, validated_data)

View file

@ -9,6 +9,6 @@ banned-modules =
[pytest]
# https://pytest-django.readthedocs.io/en/latest/configuring_django.html#order-of-choosing-settings
# https://pytest-django.readthedocs.io/en/latest/database.html
addopts = --color=yes --showlocals
addopts = --no-migrations --color=yes --showlocals
# https://pytest-django.readthedocs.io/en/latest/faq.html#my-tests-are-not-being-found-why
python_files = tests.py test_*.py *_tests.py