diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/schedule/schedule.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/schedule/schedule.py index fbde3e9b2cea..13c0592306ab 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/schedule/schedule.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/schedule/schedule.py @@ -4,7 +4,7 @@ from marshmallow import fields -from azure.ai.ml._schema.core.fields import ArmStr, NestedField, UnionField +from azure.ai.ml._schema.core.fields import ArmStr, NestedField, TypeSensitiveUnionField, UnionField from azure.ai.ml._schema.core.resource import ResourceSchema from azure.ai.ml._schema.job import CreationContextSchema from azure.ai.ml._schema.schedule.create_job import ( @@ -32,13 +32,21 @@ class ScheduleSchema(ResourceSchema): properties = fields.Dict(keys=fields.Str(), values=fields.Str(allow_none=True)) +class ScheduleCreateJobField(TypeSensitiveUnionField): + # Keep legacy dump behavior so full scheduled jobs continue to serialize via CreateJobFileRefField. + def _serialize(self, value, attr, obj, **kwargs): + return super(TypeSensitiveUnionField, self)._serialize(value, attr, obj, **kwargs) + + class JobScheduleSchema(ScheduleSchema): - create_job = UnionField( - [ + create_job = ScheduleCreateJobField( + { + "pipeline": [NestedField(PipelineCreateJobSchema)], + "command": [NestedField(CommandCreateJobSchema)], + "spark": [NestedField(SparkCreateJobSchema)], + }, + plain_union_fields=[ ArmStr(azureml_type=AzureMLResourceType.JOB), CreateJobFileRefField, - NestedField(PipelineCreateJobSchema), - NestedField(CommandCreateJobSchema), - NestedField(SparkCreateJobSchema), - ] + ], ) diff --git a/sdk/ml/azure-ai-ml/tests/schedule/unittests/test_schedule_schema.py b/sdk/ml/azure-ai-ml/tests/schedule/unittests/test_schedule_schema.py index 2bf677c8503a..3c3fe56bc6d7 100644 --- a/sdk/ml/azure-ai-ml/tests/schedule/unittests/test_schedule_schema.py +++ b/sdk/ml/azure-ai-ml/tests/schedule/unittests/test_schedule_schema.py @@ -157,3 +157,28 @@ def test_load_invalid_schedule_missing_type(self): with pytest.raises(ValidationError) as e: load_schedule(test_path) assert "'type' must be specified when scheduling a remote job with updates." in e.value.messages[0] + + def test_load_invalid_schedule_pipeline_file_not_found_error_simplified(self, tmp_path): + test_path = tmp_path / "invalid_pipeline_schedule.yml" + test_path.write_text( + """ +$schema: https://azuremlschemas.azureedge.net/latest/schedule.schema.json +name: weekly_retrain_2022_cron_pipeline_file_not_found +trigger: + type: cron + expression: "15 10 * * 1" +create_job: + type: pipeline + job: ../pipeline.yml +""".strip(), + encoding="utf-8", + ) + + with pytest.raises(ValidationError) as e: + load_schedule(str(test_path)) + + error_message = str(e.value.messages) + assert "No such file or directory" in error_message + assert "Not supporting non file for create_job" not in error_message + assert "Value 'pipeline' passed is not in set ['command']" not in error_message + assert "Value 'pipeline' passed is not in set ['spark']" not in error_message