Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/rules_validation_api/validators/iteration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,46 @@ def transform_actions_mapper(cls, action_mapper: ActionsMapper) -> ActionsMapper
action_mapper.root = new_root
return action_mapper

@model_validator(mode="after")
def validate_rule_cohort_labels_against_iteration_cohorts(self) -> typing.Self:
allowed_labels = {c.cohort_label for c in self.iteration_cohorts}
line_errors: list[InitErrorDetails] = []
Comment on lines +100 to +103
Copy link

Copilot AI Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this is a separate @model_validator(mode="after") that raises immediately, any subsequent mode="after" validations (e.g. action_mapper_validation) will not run when cohort-label errors exist. This is a regression in error aggregation/UX compared to the existing pattern that collects multiple validation errors before raising. Consider folding this check into the existing action_mapper_validation aggregator (or otherwise combining model-level checks) so callers can receive all validation issues in one response.

Copilot uses AI. Check for mistakes.

# Pre compute allowed label string once
allowed_str = ", ".join(sorted(allowed_labels)) if allowed_labels else None

for idx, rule in enumerate(self.iteration_rules):
if not rule.cohort_label:
continue

for label in rule.parsed_cohort_labels:
if label in allowed_labels:
continue

# Build error message
error_message = (
f"Invalid cohort_label value '{label}'. Allowed values: {allowed_str}."
if allowed_str
else (
f"Invalid cohort_label value '{label}'. "
"No iteration cohorts are defined, so no labels are allowed."
)
)

line_errors.append(
InitErrorDetails(
type="value_error",
loc=("iteration_rules", idx, "cohort_label"),
input=rule.cohort_label,
ctx={"error": error_message},
)
)

if line_errors:
raise ValidationError.from_exception_data(title="IterationValidation", line_errors=line_errors)

Comment on lines +104 to +137
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

validate_rule_cohort_labels_against_iteration_cohorts raises a plain ValueError, which will produce a model-level error location and stops at the first invalid rule. This makes it hard to identify which IterationRules[idx].CohortLabel is wrong and prevents surfacing multiple invalid rules in one response. Consider collecting all invalid rules into InitErrorDetails (with a loc that includes the rule index and CohortLabel) and raising a single ValidationError.from_exception_data, consistent with the other validators in this file.

Suggested change
for idx, rule in enumerate(self.iteration_rules):
if rule.cohort_label is None:
continue
if not all(label in allowed_labels for label in rule.parsed_cohort_labels):
allowed_str = ", ".join(sorted(allowed_labels))
msg = (
f"Invalid cohort_label value: {rule.cohort_label}. Allowed values: {allowed_str}. Rule index: {idx}"
)
raise ValueError(msg)
line_errors: list[InitErrorDetails] = []
for idx, rule in enumerate(self.iteration_rules):
if rule.cohort_label is None:
continue
for label in rule.parsed_cohort_labels:
if label not in allowed_labels:
allowed_str = ", ".join(sorted(allowed_labels))
error = InitErrorDetails(
type="value_error",
loc=("iteration_rules", idx, "cohort_label"),
input=rule.cohort_label,
ctx={
"error": f"Invalid cohort_label value '{label}'. Allowed values: {allowed_str}."
},
)
line_errors.append(error)
if line_errors:
raise ValidationError.from_exception_data(title="IterationValidation", line_errors=line_errors)

Copilot uses AI. Check for mistakes.
return self

@model_validator(mode="after")
def action_mapper_validation(self) -> typing.Self:
all_errors = []
Expand Down
5 changes: 1 addition & 4 deletions tests/unit/validation/test_campaign_config_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@ class TestMandatoryFieldsSchemaValidations:
def test_campaign_config_with_only_mandatory_fields_configuration(
self, valid_campaign_config_with_only_mandatory_fields
):
try:
CampaignConfigValidation(**valid_campaign_config_with_only_mandatory_fields)
except ValidationError as e:
pytest.fail(f"Unexpected error during model instantiation: {e}")
CampaignConfigValidation(**valid_campaign_config_with_only_mandatory_fields)

@pytest.mark.parametrize(
"mandatory_field",
Expand Down
5 changes: 1 addition & 4 deletions tests/unit/validation/test_iteration_rules_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@ class TestMandatoryFieldsSchemaValidations:
def test_campaign_config_with_only_mandatory_fields_configuration(
self, valid_iteration_rule_with_only_mandatory_fields
):
try:
IterationRuleValidation(**valid_iteration_rule_with_only_mandatory_fields)
except ValidationError as e:
pytest.fail(f"Unexpected error during model instantiation: {e}")
IterationRuleValidation(**valid_iteration_rule_with_only_mandatory_fields)

@pytest.mark.parametrize(
"mandatory_field",
Expand Down
25 changes: 20 additions & 5 deletions tests/unit/validation/test_iteration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ class TestMandatoryFieldsSchemaValidations:
def test_campaign_config_with_only_mandatory_fields_configuration(
self, valid_campaign_config_with_only_mandatory_fields
):
try:
IterationValidation(**(valid_campaign_config_with_only_mandatory_fields["Iterations"][0]))
except ValidationError as e:
pytest.fail(f"Unexpected error during model instantiation: {e}")
IterationValidation(**(valid_campaign_config_with_only_mandatory_fields["Iterations"][0]))

@pytest.mark.parametrize(
"mandatory_field",
Expand Down Expand Up @@ -556,7 +553,7 @@ def test_iteration_full_datetime_validation( # noqa : PLR0913
data = valid_campaign_config_with_only_mandatory_fields.copy()

if default_time_iteration_input:
data["iteration_time"] = default_time_iteration_input
data["IterationTime"] = default_time_iteration_input

data["Iterations"] = [iteration_data]

Expand All @@ -570,3 +567,21 @@ def test_iteration_full_datetime_validation( # noqa : PLR0913
f"Failed! Input: {iteration_time_input}, Default: {default_time_iteration_input}. "
f"Expected {expected_date_time} but got {result}"
)

def test_iteration_rules_having_invalid_cohort_labels_throws_error(
self,
valid_iteration_with_only_mandatory_fields,
valid_iteration_rule_with_only_mandatory_fields,
valid_iteration_cohorts,
):
data = valid_iteration_with_only_mandatory_fields.copy()
data["IterationRules"] = [valid_iteration_rule_with_only_mandatory_fields]
data["IterationCohorts"] = [valid_iteration_cohorts()]
data["IterationRules"][0]["CohortLabel"] = "label_2"

with pytest.raises(ValidationError) as exc_info:
IterationValidation(**data)

errors = exc_info.value.errors()
# Ensure at least one error is specifically about the invalid CohortLabel in IterationRules[0]
assert any(err.get("loc", [])[:3] == ("iteration_rules", 0, "cohort_label") for err in errors)
Loading