diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 1d74013..39df7e9 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -21,7 +21,7 @@ jobs: pip install django~=5.2 pytest pytest-django pytest-cov drf-spectacular django-filter pytest --cov --cov-report=xml - name: Upload coverage to Codecov - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: true diff --git a/drf_standardized_errors/openapi_utils.py b/drf_standardized_errors/openapi_utils.py index cc2b7f1..a032f0d 100644 --- a/drf_standardized_errors/openapi_utils.py +++ b/drf_standardized_errors/openapi_utils.py @@ -1,6 +1,8 @@ from dataclasses import dataclass, field as dataclass_field from typing import Any, Dict, List, Optional, Set, Type, Union +import django +import rest_framework from django import forms from django.core.validators import ( DecimalValidator, @@ -230,12 +232,18 @@ def add_unique_together_error_codes( sfields_with_error_codes: "List[InputDataField]", ) -> None: for sfield in sfields_with_unique_together_validators: - sfield.error_codes.add("unique") unique_together_validators = [ validator for validator in sfield.field.validators if isinstance(validator, UniqueTogetherValidator) ] + if _drf_version() >= (3, 17) and django.VERSION >= (5, 0): + # drf 3.17 passes the `custom_violation_error` added in django 5.0 + # to `drf.UniqueTogetherValidator`. Before that, the error code was + # hardcoded as `"unique"` + sfield.error_codes.update(v.code for v in unique_together_validators) + else: # pragma: no cover + sfield.error_codes.add("unique") # fields involved in a unique together constraint have an implied # "required" state, so we're adding the "required" error code to them implicitly_required_fields = set() @@ -501,3 +509,9 @@ def get_example_from_exception(exc: exceptions.APIException) -> OpenApiExample: response_only=True, status_codes=[str(exc.status_code)], ) + + +def _drf_version(): + # we just care about major and minor drf versions + parts = rest_framework.VERSION.split(".") + return int(parts[0]), int(parts[1]) diff --git a/tests/test_openapi_utils.py b/tests/test_openapi_utils.py index d3f8c6d..dd0c523 100644 --- a/tests/test_openapi_utils.py +++ b/tests/test_openapi_utils.py @@ -16,6 +16,7 @@ from drf_standardized_errors.openapi_utils import ( InputDataField, + _drf_version, get_django_filter_backends, get_error_serializer, get_filter_forms, @@ -372,6 +373,43 @@ def test_unique_together_error_codes(unique_together): assert "required" in model.error_codes +@pytest.fixture +def unique_together_with_violation_code(): + from django.db import models + + class SomeModel(models.Model): + app_label = models.CharField(max_length=100) + model = models.CharField(max_length=100) + + class Meta: + constraints = [ + models.UniqueConstraint( + fields=["app_label", "model"], + name="unique_model", + violation_error_code="custom_violation_code", + ) + ] + + class SomeSerializer(serializers.ModelSerializer): + class Meta: + model = SomeModel + fields = ["app_label", "model"] + + return get_flat_serializer_fields(SomeSerializer()) + + +@pytest.mark.skipif( + _drf_version() < (3, 17) or django.VERSION < (5, 0), + reason="django added violation_error_code in v5 and drf supported it in v3.17", +) +def test_unique_together_new(unique_together_with_violation_code): + non_field_errors, _, __ = get_serializer_fields_with_error_codes( + unique_together_with_violation_code + ) + + assert "custom_violation_code" in non_field_errors.error_codes + + class PostSerializer(serializers.ModelSerializer): """ Intentional required=False to test that the 'required' error code is added