Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions rest_framework/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1454,20 +1454,26 @@ def get_extra_kwargs(self):

def get_unique_together_constraints(self, model):
"""
Returns iterator of (fields, queryset, condition_fields, condition),
Returns iterator of (fields, queryset, condition_fields, condition, nulls_distinct),
each entry describes an unique together constraint on `fields` in `queryset`
with respect of constraint's `condition`.
with respect of constraint's `condition` and `nulls_distinct` option.
"""
for parent_class in [model] + list(model._meta.parents):
for unique_together in parent_class._meta.unique_together:
yield unique_together, model._default_manager, [], None
yield unique_together, model._default_manager, [], None, None
for constraint in parent_class._meta.constraints:
if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1:
if constraint.condition is None:
condition_fields = []
else:
condition_fields = list(get_referenced_base_fields_from_q(constraint.condition))
yield (constraint.fields, model._default_manager, condition_fields, constraint.condition)
yield (
constraint.fields,
model._default_manager,
condition_fields,
constraint.condition,
getattr(constraint, 'nulls_distinct', None),
)

def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
"""
Expand Down Expand Up @@ -1500,7 +1506,7 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs

# Include each of the `unique_together` and `UniqueConstraint` field names,
# so long as all the field names are included on the serializer.
for unique_together_list, queryset, condition_fields, condition in self.get_unique_together_constraints(model):
for unique_together_list, queryset, condition_fields, condition, nulls_distinct in self.get_unique_together_constraints(model):
unique_together_list_and_condition_fields = set(unique_together_list) | set(condition_fields)
if model_fields_names.issuperset(unique_together_list_and_condition_fields):
unique_constraint_names |= unique_together_list_and_condition_fields
Expand Down Expand Up @@ -1643,7 +1649,7 @@ def get_unique_together_validators(self):
# Note that we make sure to check `unique_together` both on the
# base model class, but also on any parent classes.
validators = []
for unique_together, queryset, condition_fields, condition in self.get_unique_together_constraints(self.Meta.model):
for unique_together, queryset, condition_fields, condition, nulls_distinct in self.get_unique_together_constraints(self.Meta.model):
# Skip if serializer does not map to all unique together sources
unique_together_and_condition_fields = set(unique_together) | set(condition_fields)
if not set(source_map).issuperset(unique_together_and_condition_fields):
Expand Down Expand Up @@ -1677,6 +1683,7 @@ def get_unique_together_validators(self):
condition=condition,
message=violation_error_message,
code=getattr(constraint, 'violation_error_code', None),
nulls_distinct=nulls_distinct,
)
validators.append(validator)
return validators
Expand Down
18 changes: 12 additions & 6 deletions rest_framework/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,14 @@ class UniqueTogetherValidator:
requires_context = True
code = 'unique'

def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None, code=None):
def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None, code=None, nulls_distinct=None):
self.queryset = queryset
self.fields = fields
self.message = message or self.message
self.condition_fields = [] if condition_fields is None else condition_fields
self.condition = condition
self.code = code or self.code
self.nulls_distinct = nulls_distinct

def enforce_required_fields(self, attrs, serializer):
"""
Expand Down Expand Up @@ -197,17 +198,21 @@ def __call__(self, attrs, serializer):
else getattr(serializer.instance, source)
for source in condition_sources
}
if checked_values and None not in checked_values and qs_exists_with_condition(queryset, self.condition, condition_kwargs):
field_names = ', '.join(self.fields)
message = self.message.format(field_names=field_names)
raise ValidationError(message, code=self.code)
if checked_values:
# Skip validation for None values unless nulls_distinct is False
if self.nulls_distinct is not False and None in checked_values:
return
if qs_exists_with_condition(queryset, self.condition, condition_kwargs):
field_names = ', '.join(self.fields)
message = self.message.format(field_names=field_names)
raise ValidationError(message, code=self.code)

def __repr__(self):
return '<{}({})>'.format(
self.__class__.__name__,
', '.join(
f'{attr}={smart_repr(getattr(self, attr))}'
for attr in ('queryset', 'fields', 'condition')
for attr in ('queryset', 'fields', 'condition', 'nulls_distinct')
if getattr(self, attr) is not None)
)

Expand All @@ -220,6 +225,7 @@ def __eq__(self, other):
and self.queryset == other.queryset
and self.fields == other.fields
and self.code == other.code
and self.nulls_distinct == other.nulls_distinct
)


Expand Down
151 changes: 151 additions & 0 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,23 @@ class Meta:
]


# Only define nulls_distinct model for Django 5.0+
if django_version >= (5, 0):
class UniqueConstraintNullsDistinctModel(models.Model):
name = models.CharField(max_length=100)
code = models.CharField(max_length=100, null=True)
category = models.CharField(max_length=100, null=True)

class Meta:
constraints = [
models.UniqueConstraint(
name='unique_code_category_nulls_not_distinct',
fields=('code', 'category'),
nulls_distinct=False,
),
]


class UniqueConstraintCustomMessageCodeModel(models.Model):
username = models.CharField(max_length=32)
company_id = models.IntegerField()
Expand Down Expand Up @@ -1063,3 +1080,137 @@ def test_equality_operator(self):
assert validator == validator2
validator2.date_field = "bar2"
assert validator != validator2


# Tests for `nulls_distinct` option (Django 5.0+)
# -----------------------------------------------

@pytest.mark.skipif(
django_version < (5, 0),
reason="nulls_distinct requires Django 5.0+"
)
class TestUniqueConstraintNullsDistinct(TestCase):
"""
Tests for UniqueConstraint with nulls_distinct=False option.
When nulls_distinct=False, NULL values should be treated as equal
for uniqueness validation.
"""

def setUp(self):
from tests.test_validators import UniqueConstraintNullsDistinctModel

self.model = UniqueConstraintNullsDistinctModel

class UniqueConstraintNullsDistinctSerializer(serializers.ModelSerializer):
class Meta:
model = UniqueConstraintNullsDistinctModel
fields = ('name', 'code', 'category')

self.serializer_class = UniqueConstraintNullsDistinctSerializer

def test_nulls_distinct_false_validates_null_as_duplicate(self):
"""
When nulls_distinct=False, creating a second record with NULL values
in the constrained fields should fail validation.
"""
self.model.objects.create(name='First', code=None, category=None)

serializer = self.serializer_class(data={
'name': 'Second',
'code': None,
'category': None
})
assert not serializer.is_valid()

def test_nulls_distinct_false_allows_different_non_null_values(self):
"""
Non-NULL values should still work normally with uniqueness validation.
"""
self.model.objects.create(name='First', code='A', category='X')

serializer = self.serializer_class(data={
'name': 'Second',
'code': 'B',
'category': 'Y'
})
assert serializer.is_valid(), serializer.errors

def test_nulls_distinct_false_rejects_duplicate_non_null_values(self):
"""
Duplicate non-NULL values should still fail validation.
"""
self.model.objects.create(name='First', code='A', category='X')

serializer = self.serializer_class(data={
'name': 'Second',
'code': 'A',
'category': 'X'
})
assert not serializer.is_valid()

def test_nulls_distinct_false_update_with_null_values(self):
"""
Updating an existing instance with NULL values should not
raise a uniqueness error against itself.
"""
instance = self.model.objects.create(name='First', code=None, category=None)

serializer = self.serializer_class(instance=instance, data={
'name': 'Updated',
'code': None,
'category': None
})
assert serializer.is_valid(), serializer.errors

def test_nulls_distinct_false_update_to_existing_null(self):
"""
Updating an instance to NULL values that already exist in
another record should fail validation.
"""
self.model.objects.create(name='First', code=None, category=None)
instance = self.model.objects.create(name='Second', code='A', category='X')

serializer = self.serializer_class(instance=instance, data={
'name': 'Second',
'code': None,
'category': None
})
assert not serializer.is_valid()

def test_nulls_distinct_false_partial_null(self):
"""
When only one constrained field is NULL and the other is non-NULL,
validation should still treat NULL as equal for the NULL field.
"""
self.model.objects.create(name='First', code=None, category='X')

serializer = self.serializer_class(data={
'name': 'Second',
'code': None,
'category': 'X'
})
assert not serializer.is_valid()

def test_unique_together_validator_nulls_distinct_equality(self):
"""
Test that UniqueTogetherValidator equality considers nulls_distinct.
"""
mock_queryset = MagicMock()
validator1 = UniqueTogetherValidator(
queryset=mock_queryset,
fields=('a', 'b'),
nulls_distinct=False
)
validator2 = UniqueTogetherValidator(
queryset=mock_queryset,
fields=('a', 'b'),
nulls_distinct=False
)
validator3 = UniqueTogetherValidator(
queryset=mock_queryset,
fields=('a', 'b'),
nulls_distinct=True
)

assert validator1 == validator2
assert validator1 != validator3