diff --git a/src/rules_validation_api/validators/iteration_validator.py b/src/rules_validation_api/validators/iteration_validator.py index d4934fd25..1c0f713fc 100644 --- a/src/rules_validation_api/validators/iteration_validator.py +++ b/src/rules_validation_api/validators/iteration_validator.py @@ -97,6 +97,46 @@ def transform_actions_mapper(cls, action_mapper: ActionsMapper) -> ActionsMapper action_mapper.root = new_root return action_mapper + @model_validator(mode="after") + def validate_rule_cohort_labels_against_iteration_cohorts(self) -> typing.Self: + allowed_labels = {c.cohort_label for c in self.iteration_cohorts} + line_errors: list[InitErrorDetails] = [] + + # Pre compute allowed label string once + allowed_str = ", ".join(sorted(allowed_labels)) if allowed_labels else None + + for idx, rule in enumerate(self.iteration_rules): + if not rule.cohort_label: + continue + + for label in rule.parsed_cohort_labels: + if label in allowed_labels: + continue + + # Build error message + error_message = ( + f"Invalid cohort_label value '{label}'. Allowed values: {allowed_str}." + if allowed_str + else ( + f"Invalid cohort_label value '{label}'. " + "No iteration cohorts are defined, so no labels are allowed." + ) + ) + + line_errors.append( + InitErrorDetails( + type="value_error", + loc=("iteration_rules", idx, "cohort_label"), + input=rule.cohort_label, + ctx={"error": error_message}, + ) + ) + + if line_errors: + raise ValidationError.from_exception_data(title="IterationValidation", line_errors=line_errors) + + return self + @model_validator(mode="after") def action_mapper_validation(self) -> typing.Self: all_errors = [] diff --git a/tests/unit/validation/test_campaign_config_validator.py b/tests/unit/validation/test_campaign_config_validator.py index c0e47e981..7782621b1 100644 --- a/tests/unit/validation/test_campaign_config_validator.py +++ b/tests/unit/validation/test_campaign_config_validator.py @@ -13,10 +13,7 @@ class TestMandatoryFieldsSchemaValidations: def test_campaign_config_with_only_mandatory_fields_configuration( self, valid_campaign_config_with_only_mandatory_fields ): - try: - CampaignConfigValidation(**valid_campaign_config_with_only_mandatory_fields) - except ValidationError as e: - pytest.fail(f"Unexpected error during model instantiation: {e}") + CampaignConfigValidation(**valid_campaign_config_with_only_mandatory_fields) @pytest.mark.parametrize( "mandatory_field", diff --git a/tests/unit/validation/test_iteration_rules_validator.py b/tests/unit/validation/test_iteration_rules_validator.py index 1c33f4623..da2e92f43 100644 --- a/tests/unit/validation/test_iteration_rules_validator.py +++ b/tests/unit/validation/test_iteration_rules_validator.py @@ -9,10 +9,7 @@ class TestMandatoryFieldsSchemaValidations: def test_campaign_config_with_only_mandatory_fields_configuration( self, valid_iteration_rule_with_only_mandatory_fields ): - try: - IterationRuleValidation(**valid_iteration_rule_with_only_mandatory_fields) - except ValidationError as e: - pytest.fail(f"Unexpected error during model instantiation: {e}") + IterationRuleValidation(**valid_iteration_rule_with_only_mandatory_fields) @pytest.mark.parametrize( "mandatory_field", diff --git a/tests/unit/validation/test_iteration_validator.py b/tests/unit/validation/test_iteration_validator.py index df63cfc11..a9ad48e91 100644 --- a/tests/unit/validation/test_iteration_validator.py +++ b/tests/unit/validation/test_iteration_validator.py @@ -14,10 +14,7 @@ class TestMandatoryFieldsSchemaValidations: def test_campaign_config_with_only_mandatory_fields_configuration( self, valid_campaign_config_with_only_mandatory_fields ): - try: - IterationValidation(**(valid_campaign_config_with_only_mandatory_fields["Iterations"][0])) - except ValidationError as e: - pytest.fail(f"Unexpected error during model instantiation: {e}") + IterationValidation(**(valid_campaign_config_with_only_mandatory_fields["Iterations"][0])) @pytest.mark.parametrize( "mandatory_field", @@ -556,7 +553,7 @@ def test_iteration_full_datetime_validation( # noqa : PLR0913 data = valid_campaign_config_with_only_mandatory_fields.copy() if default_time_iteration_input: - data["iteration_time"] = default_time_iteration_input + data["IterationTime"] = default_time_iteration_input data["Iterations"] = [iteration_data] @@ -570,3 +567,21 @@ def test_iteration_full_datetime_validation( # noqa : PLR0913 f"Failed! Input: {iteration_time_input}, Default: {default_time_iteration_input}. " f"Expected {expected_date_time} but got {result}" ) + + def test_iteration_rules_having_invalid_cohort_labels_throws_error( + self, + valid_iteration_with_only_mandatory_fields, + valid_iteration_rule_with_only_mandatory_fields, + valid_iteration_cohorts, + ): + data = valid_iteration_with_only_mandatory_fields.copy() + data["IterationRules"] = [valid_iteration_rule_with_only_mandatory_fields] + data["IterationCohorts"] = [valid_iteration_cohorts()] + data["IterationRules"][0]["CohortLabel"] = "label_2" + + with pytest.raises(ValidationError) as exc_info: + IterationValidation(**data) + + errors = exc_info.value.errors() + # Ensure at least one error is specifically about the invalid CohortLabel in IterationRules[0] + assert any(err.get("loc", [])[:3] == ("iteration_rules", 0, "cohort_label") for err in errors)