diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py b/sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py index 0b84c10af5..ea20ab362b 100644 --- a/sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py +++ b/sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py @@ -32,6 +32,8 @@ def _submit_service_job( timeout: Optional[Dict] = None, share_identifier: Optional[str] = None, tags: Optional[Dict] = None, + quota_share_name: Optional[str] = None, + preemption_config: Optional[Dict] = None, ) -> Dict: """Batch submit_service_job API helper function. @@ -44,6 +46,8 @@ def _submit_service_job( timeout: Set with value of timeout if specified, else default to 1 day. share_identifier: value of shareIdentifier if specified. tags: A dict of string to string representing Batch tags. + quota_share_name: Quota Share name for the Batch job. + preemption_config: Preemption configuration. Returns: A dict containing jobArn, jobName and jobId. 
@@ -68,6 +72,10 @@ def _submit_service_job( payload["shareIdentifier"] = share_identifier if tags or training_payload_tags: payload["tags"] = __merge_tags(tags, training_payload_tags) + if quota_share_name: + payload["quotaShareName"] = quota_share_name + if preemption_config: + payload["preemptionConfiguration"] = preemption_config return client.submit_service_job(**payload) @@ -96,21 +104,45 @@ def _describe_service_job(job_id: str) -> Dict: 'jobId': 'string', 'jobName': 'string', 'jobQueue': 'string', + 'latestAttempt': { + 'serviceResourceId': { + 'name': 'string', + 'value': 'string' + } + }, + 'preemptionSummary': { + 'preemptedAttemptCount': 123, + 'recentPreemptedAttempts': [ + { + 'serviceResourceId': { + 'name': 'string', + 'value': 'string' + }, + 'startedAt': 123, + 'stoppedAt': 123, + 'statusReason': 'string' + }, + ] + }, 'retryStrategy': { 'attempts': 123 }, 'schedulingPriority': 123, 'serviceRequestPayload': 'string', - 'serviceJobType': 'EKS'|'ECS'|'ECS_FARGATE'|'SAGEMAKER_TRAINING', + 'serviceJobType': 'SAGEMAKER_TRAINING', 'shareIdentifier': 'string', + 'quotaShareName': 'string', + 'preemptionConfiguration': { + 'preemptionRetriesBeforeTermination': 123 + }, 'startedAt': 123, - 'status': 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED', + 'status': 'SUBMITTED'|'PENDING'|'RUNNABLE'|'SCHEDULED'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED', 'statusReason': 'string', 'stoppedAt': 123, 'tags': { 'string': 'string' }, - 'timeout': { + 'timeoutConfig': { 'attemptDurationSeconds': 123 } } @@ -132,6 +164,19 @@ def _terminate_service_job(job_id: str, reason: Optional[str] = "default termina return client.terminate_service_job(jobId=job_id, reason=reason) +def _update_service_job(job_id: str, scheduling_priority: int) -> Dict: + """Batch update_service_job API helper function. + + Args: + job_id: Job ID or Job Arn + scheduling_priority: An integer representing scheduling priority. 
+ + Returns: a dict containing jobArn, jobId and jobName. + """ + client = get_batch_boto_client() + return client.update_service_job(jobId=job_id, schedulingPriority=scheduling_priority) + + def _list_service_job( job_queue: str, job_status: Optional[str] = None, diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py b/sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py index bfbc1fb0da..d3b5f78940 100644 --- a/sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py +++ b/sagemaker-train/src/sagemaker/train/aws_batch/training_queue.py @@ -41,6 +41,8 @@ def submit( share_identifier: Optional[str] = None, timeout: Optional[Dict] = None, tags: Optional[Dict] = None, + quota_share_name: Optional[str] = None, + preemption_config: Optional[Dict] = None, ) -> TrainingQueuedJob: """Submit a queued job and return a QueuedJob object. @@ -53,6 +55,8 @@ def submit( share_identifier: Share identifier for Batch job. timeout: Timeout configuration for Batch job. tags: Tags apply to Batch job. These tags are for Batch job only. + quota_share_name: Quota Share name for the Batch job. + preemption_config: Preemption configuration. Returns: a TrainingQueuedJob object with Batch job ARN and job name. 
@@ -85,6 +89,8 @@ def submit( timeout, share_identifier, tags, + quota_share_name, + preemption_config, ) if "jobArn" not in resp or "jobName" not in resp: raise MissingRequiredArgument( diff --git a/sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py b/sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py index 98f57069e1..0b8d73eebc 100644 --- a/sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py +++ b/sagemaker-train/src/sagemaker/train/aws_batch/training_queued_job.py @@ -29,7 +29,7 @@ SourceCode, TrainingImageConfig, ) -from .batch_api_helper import _terminate_service_job, _describe_service_job +from .batch_api_helper import _terminate_service_job, _describe_service_job, _update_service_job from .exception import NoTrainingJob, MissingRequiredArgument from ..utils import _get_training_job_name_from_training_job_arn from .constants import JOB_STATUS_COMPLETED, JOB_STATUS_FAILED, POLL_IN_SECONDS @@ -85,6 +85,17 @@ def terminate(self, reason: Optional[str] = "Default terminate reason") -> None: """ _terminate_service_job(self.job_arn, reason) + def update(self, scheduling_priority: int) -> Dict: + """Update Batch job. + + Args: + scheduling_priority: An integer representing scheduling priority. + + Returns: A dict which includes jobArn, jobName and jobId. + + """ + return _update_service_job(self.job_arn, scheduling_priority) + def describe(self) -> Dict: """Describe Batch job. 
diff --git a/sagemaker-train/tests/integ/train/aws_batch/manager.py b/sagemaker-train/tests/integ/train/aws_batch/manager.py index b417f86b53..5fb1c4719a 100644 --- a/sagemaker-train/tests/integ/train/aws_batch/manager.py +++ b/sagemaker-train/tests/integ/train/aws_batch/manager.py @@ -16,16 +16,21 @@ class BatchTestResourceManager: + CAPACITY_UNIT = "ml.m5.2xlarge" def __init__( self, batch_client, - queue_name="pysdk-test-queue", - service_env_name="pysdk-test-queue-service-environment", + queue_name="pysdk-test-qm-queue", + service_env_name="pysdk-test-qm-queue-service-environment", + scheduling_policy_name="pysdk-test-qm-scheduling-policy", + quota_share_name="pysdk-test-quota-share", ): self.batch_client = batch_client self.queue_name = queue_name self.service_environment_name = service_env_name + self.scheduling_policy_name = scheduling_policy_name + self.quota_share_name = quota_share_name def _create_or_get_service_environment(self, service_environment_name): print(f"Creating service environment: {service_environment_name}") @@ -33,7 +38,7 @@ def _create_or_get_service_environment(self, service_environment_name): response = self.batch_client.create_service_environment( serviceEnvironmentName=service_environment_name, serviceEnvironmentType="SAGEMAKER_TRAINING", - capacityLimits=[{"maxCapacity": 10, "capacityUnit": "NUM_INSTANCES"}], + capacityLimits=[{"maxCapacity": 10, "capacityUnit": BatchTestResourceManager.CAPACITY_UNIT}], ) print(f"Service environment {service_environment_name} created successfully.") return response @@ -48,22 +53,24 @@ def _create_or_get_service_environment(self, service_environment_name): print(f"Error creating service environment: {e}") raise - def _create_or_get_queue(self, queue_name, service_environment_arn): - + def _create_or_get_queue(self, queue_name, service_environment_arn, scheduling_policy_arn=None): print(f"Creating job queue: {queue_name}") try: - response = self.batch_client.create_job_queue( - jobQueueName=queue_name, 
- priority=1, - computeEnvironmentOrder=[], - serviceEnvironmentOrder=[ + create_params = { + "jobQueueName": queue_name, + "priority": 1, + "computeEnvironmentOrder": [], + "serviceEnvironmentOrder": [ { "order": 1, "serviceEnvironment": service_environment_arn, }, ], - jobQueueType="SAGEMAKER_TRAINING", - ) + "jobQueueType": "SAGEMAKER_TRAINING", + } + if scheduling_policy_arn: + create_params["schedulingPolicyArn"] = scheduling_policy_arn + response = self.batch_client.create_job_queue(**create_params) print(f"Job queue {queue_name} created successfully.") return response except Exception as e: @@ -75,6 +82,88 @@ def _create_or_get_queue(self, queue_name, service_environment_arn): print(f"Error creating job queue: {e}") raise + def _find_scheduling_policy(self, scheduling_policy_name): + paginator = self.batch_client.get_paginator("list_scheduling_policies") + for page in paginator.paginate(): + for sp in page.get("schedulingPolicies", []): + if scheduling_policy_name in sp["arn"]: + return sp + return None + + def _create_or_get_scheduling_policy(self, scheduling_policy_name): + print(f"Creating scheduling policy: {scheduling_policy_name}") + try: + response = self.batch_client.create_scheduling_policy( + name=scheduling_policy_name, + quotaSharePolicy={"idleResourceAssignmentStrategy": "FIFO"}, + ) + print(f"Scheduling policy {scheduling_policy_name} created successfully.") + return response + except Exception as e: + if "Object already exists" in str(e): + print("Resource already exists. 
Fetching existing resource.") + sp = self._find_scheduling_policy(scheduling_policy_name) + if not sp: + raise + return sp + else: + print(f"Error creating scheduling policy: {e}") + raise + + def _create_or_get_quota_share(self, quota_share_name, queue_name): + print(f"Creating quota share: {quota_share_name}") + try: + response = self.batch_client.create_quota_share( + quotaShareName=quota_share_name, + jobQueue=queue_name, + capacityLimits=[{"maxCapacity": 10, "capacityUnit": BatchTestResourceManager.CAPACITY_UNIT}], + resourceSharingConfiguration={"strategy": "RESERVE"}, + preemptionConfiguration={"inSharePreemption": "DISABLED"}, + state="ENABLED", + ) + print(f"Quota share {quota_share_name} created successfully.") + return response + except Exception as e: + if "already exists" in str(e): + print("Resource already exists. Fetching existing resource.") + desc_jq = self.batch_client.describe_job_queues(jobQueues=[queue_name]) + jq_arn = desc_jq["jobQueues"][0]["jobQueueArn"] + return self.batch_client.describe_quota_share(quotaShareArn=f"{jq_arn}/quota-share/{quota_share_name}") + else: + print(f"Error creating quota share: {e}") + raise + + def _update_quota_share_state(self, quota_share_arn, state): + print(f"Updating quota share {quota_share_arn} to state {state}") + try: + response = self.batch_client.update_quota_share(quotaShareArn=quota_share_arn, state=state) + return response + except Exception as e: + print(f"Error updating quota share: {e}") + + def _wait_for_quota_share_state(self, quota_share_arn, expected_status, expected_state, timeout=300): + print(f"Waiting for quota share to be {expected_status}...") + start = time.time() + while time.time() - start < timeout: + try: + response = self.batch_client.describe_quota_share(quotaShareArn=quota_share_arn) + except Exception as e: + if expected_status == "DELETED" and "does not exist" in str(e): + return + raise e + + state = response.get("state") + status = response.get("status") + + if status == 
expected_status and state == expected_state: + print(f"Quota share is now {expected_state}.") + return + if status == "INVALID": + raise ValueError(f"Something went wrong!") + + time.sleep(5) + raise TimeoutError(f"Quota share did not reach {expected_state} within {timeout}s") + def _update_queue_state(self, queue_name, state): try: print(f"Updating queue {queue_name} to state {state}") @@ -93,41 +182,160 @@ def _update_service_environment_state(self, service_environment_name, state): except Exception as e: print(f"Error updating service environment: {e}") - def _wait_for_queue_state(self, queue_name, state): - print(f"Waiting for queue {queue_name} to be {state}...") - while True: - response = self.batch_client.describe_job_queues(jobQueues=[queue_name]) - print(f"Current state: {response}") - if response["jobQueues"][0]["state"] == state: - break + def _wait_for_queue_state(self, job_queue_name, expected_status, expected_state, timeout=300): + print(f"Waiting for queue {job_queue_name} to be {expected_status}...") + start = time.time() + while time.time() - start < timeout: + describe_jq_response = self.batch_client.describe_job_queues( + jobQueues=[job_queue_name] + ) + if describe_jq_response["jobQueues"]: + jq = describe_jq_response["jobQueues"][0] + + state = jq["state"] + status = jq["status"] + + if status == expected_status and state == expected_state: + print(f"Queue {job_queue_name} is now {state}.") + return + if status == "INVALID": + raise ValueError(f"Something went wrong!") + elif expected_status == "DELETED": + print(f"JobQueue {job_queue_name} has been deleted") + return + time.sleep(5) - print(f"Queue {queue_name} is now {state}.") + raise TimeoutError(f"Queue {job_queue_name} did not reach {expected_state} within {timeout}s") - def _wait_for_service_environment_state(self, service_environment_name, state): - print(f"Waiting for service environment {service_environment_name} to be {state}...") - while True: - response = 
self.batch_client.describe_service_environments( + def _wait_for_service_environment_state(self, service_environment_name, expected_status, expected_state, timeout=300): + print(f"Waiting for service environment {service_environment_name} to be {expected_status}...") + start = time.time() + while time.time() - start < timeout: + describe_response = self.batch_client.describe_service_environments( serviceEnvironments=[service_environment_name] ) - print(f"Current state: {response}") - if response["serviceEnvironments"][0]["state"] == state: - break + if describe_response["serviceEnvironments"]: + se = describe_response["serviceEnvironments"][0] + + state = se["state"] + status = se["status"] + + if status == expected_status and state == expected_state: + print(f"Service environment {service_environment_name} is now {expected_state}.") + return + if status == "INVALID": + raise ValueError(f"Something went wrong!") + elif expected_status == "DELETED": + print(f"ServiceEnvironment {service_environment_name} has been deleted") + return + time.sleep(5) - print(f"Service environment {service_environment_name} is now {state}.") + raise TimeoutError(f"Service environment {service_environment_name} did not reach {expected_state} within {timeout}s") + + def _delete_service_environment(self, service_environment_name: str): + print(f"Setting ServiceEnvironment {service_environment_name} to DISABLED") + self.batch_client.update_service_environment( + serviceEnvironment=service_environment_name, state="DISABLED" + ) + + print("Waiting for ServiceEnvironment update to finish...") + self._wait_for_service_environment_state(service_environment_name, "VALID", "DISABLED") + + print(f"Deleting ServiceEnvironment {service_environment_name}") + self.batch_client.delete_service_environment(serviceEnvironment=service_environment_name) + + print("Waiting for ServiceEnvironment update to finish...") + self._wait_for_service_environment_state(service_environment_name, "DELETED", "DISABLED") + 
+ def _delete_job_queue(self, job_queue_name: str): + print(f"Disabling JobQueue {job_queue_name}") + self.batch_client.update_job_queue(jobQueue=job_queue_name, state="DISABLED") + + print("Waiting for JobQueue update to finish...") + self._wait_for_queue_state(job_queue_name, "VALID", "DISABLED") + + print(f"Deleting JobQueue {job_queue_name}") + self.batch_client.delete_job_queue(jobQueue=job_queue_name) + + print("Waiting for JobQueue update to finish...") + self._wait_for_queue_state(job_queue_name, "DELETED", "DISABLED") + + def _delete_scheduling_policy(self, scheduling_policy_arn: str): + print(f"Deleting SchedulingPolicy {scheduling_policy_arn}") + self.batch_client.delete_scheduling_policy(arn=scheduling_policy_arn) + + def _delete_quota_share(self, quota_share_arn: str): + print(f"Disabling QuotaShare {quota_share_arn}") + self.batch_client.update_quota_share(quotaShareArn=quota_share_arn, state="DISABLED") + + print("Waiting for QuotaShare update to finish...") + self._wait_for_quota_share_state(quota_share_arn, "VALID", "DISABLED") + + print(f"Deleting QuotaShare {quota_share_arn}") + self.batch_client.delete_quota_share(quotaShareArn=quota_share_arn) - def get_or_create_resources(self, queue_name=None, service_environment_name=None): + print("Waiting for QuotaShare deletion to finish...") + self._wait_for_quota_share_state(quota_share_arn, "DELETED", "DISABLED") + + def get_or_create_resources( + self, + queue_name=None, + service_environment_name=None, + scheduling_policy_name=None, + quota_share_name=None + ): queue_name = queue_name or self.queue_name service_environment_name = service_environment_name or self.service_environment_name + scheduling_policy_name = scheduling_policy_name or self.scheduling_policy_name + quota_share_name = quota_share_name or self.quota_share_name service_environment = self._create_or_get_service_environment(service_environment_name) if service_environment.get("state") != "ENABLED": 
self._update_service_environment_state(service_environment_name, "ENABLED") - self._wait_for_service_environment_state(service_environment_name, "ENABLED") + self._wait_for_service_environment_state(service_environment_name, "VALID", "ENABLED") time.sleep(10) - queue = self._create_or_get_queue(queue_name, service_environment["serviceEnvironmentArn"]) + scheduling_policy = self._create_or_get_scheduling_policy(scheduling_policy_name) + scheduling_policy_arn = scheduling_policy.get("arn") + + queue = self._create_or_get_queue(queue_name, service_environment["serviceEnvironmentArn"], + scheduling_policy_arn) if queue.get("state") != "ENABLED": self._update_queue_state(queue_name, "ENABLED") - self._wait_for_queue_state(queue_name, "ENABLED") + self._wait_for_queue_state(queue_name, "VALID", "ENABLED") time.sleep(10) - return queue, service_environment + + quota_share = self._create_or_get_quota_share(quota_share_name, queue_name) + if quota_share.get("state") != "ENABLED": + self._update_quota_share_state(quota_share["quotaShareArn"], "ENABLED") + self._wait_for_quota_share_state(quota_share["quotaShareArn"], "VALID", "ENABLED") + time.sleep(10) + + return queue, service_environment, scheduling_policy, quota_share + + def delete_resources( + self, + queue_name=None, + service_environment_name=None, + scheduling_policy_name=None, + quota_share_name=None + ): + queue_name = queue_name or self.queue_name + service_environment_name = service_environment_name or self.service_environment_name + scheduling_policy_name = scheduling_policy_name or self.scheduling_policy_name + quota_share_name = quota_share_name or self.quota_share_name + + # Get ARNs needed for deletion + desc_jq = self.batch_client.describe_job_queues(jobQueues=[queue_name]) + if desc_jq["jobQueues"]: + jq_arn = desc_jq["jobQueues"][0]["jobQueueArn"] + quota_share_arn = f"{jq_arn}/quota-share/{quota_share_name}" + self._delete_quota_share(quota_share_arn) + + self._delete_job_queue(queue_name) + + sp = 
self._find_scheduling_policy(scheduling_policy_name) + if sp: + self._delete_scheduling_policy(sp["arn"]) + + self._delete_service_environment(service_environment_name) diff --git a/sagemaker-train/tests/integ/train/aws_batch/test_queue.py b/sagemaker-train/tests/integ/train/aws_batch/test_queue.py index 7333acddca..9706fec977 100644 --- a/sagemaker-train/tests/integ/train/aws_batch/test_queue.py +++ b/sagemaker-train/tests/integ/train/aws_batch/test_queue.py @@ -38,56 +38,69 @@ def batch_client(): def batch_test_resource_manager(batch_client): resource_manager = BatchTestResourceManager(batch_client=batch_client) resource_manager.get_or_create_resources() - return resource_manager + yield resource_manager + resource_manager.delete_resources() def test_model_trainer_submit(batch_test_resource_manager, sagemaker_session): # noqa: F811 queue_name = batch_test_resource_manager.queue_name - source_code = SourceCode( - source_dir=f"{DATA_DIR}/train/script_mode/", - requirements="requirements.txt", - entry_script="custom_script.py", - ) - hyperparameters = { - "batch-size": 32, - "epochs": 1, - "learning-rate": 0.01, - } + source_code = SourceCode(command="echo 'Hello World'") compute = Compute(instance_type="ml.m5.2xlarge") model_trainer = ModelTrainer( sagemaker_session=sagemaker_session, training_image=DEFAULT_CPU_IMAGE, source_code=source_code, compute=compute, - hyperparameters=hyperparameters, base_job_name="test-batch-model-trainer", ) - train_data = InputData( - channel_name="train", - data_source=f"{DATA_DIR}/train/script_mode/data/train/", - ) - test_data = InputData( - channel_name="test", - data_source=f"{DATA_DIR}/train/script_mode/data/test/", - ) training_queue = TrainingQueue(queue_name=queue_name) try: queued_job = training_queue.submit( training_job=model_trainer, - inputs=[train_data, test_data], + inputs=None, + job_name="pysdk_integ_test_job", + retry_config={ + "attempts": 1, + "evaluateOnExit": [ + { + "action": "Retry", + "onStatusReason": 
"Received status from SageMaker: AlgorithmError: *" + }, + { + "action": "EXIT", + "onStatusReason": "*" + } + ] + }, + priority=1, + tags={"pysdk-integ-test-tag-key": "pysdk-integ-test-tag-value"}, + quota_share_name=batch_test_resource_manager.quota_share_name, + preemption_config={"preemptionRetriesBeforeTermination": 0} ) except botocore.exceptions.ClientError as e: print(e.response["ResponseMetadata"]) print(e.response["Error"]["Message"]) raise e + res = queued_job.describe() assert res is not None - assert res["status"] == "SUBMITTED" + assert res["status"] in {"SUBMITTED", "RUNNABLE", "SCHEDULED"} - queued_job.wait(timeout=1800) - res = queued_job.describe() + res = queued_job.update(2) assert res is not None - assert res["status"] == "SUCCEEDED" + assert res["jobArn"] == queued_job.job_arn + + # Job termination results in FAILED + queued_job.terminate() + + res = queued_job.wait(timeout=900) + assert res is not None + assert res["status"] == "FAILED" + + list_by_job_name = training_queue.list_jobs(queued_job.job_name) + list_by_job_status = training_queue.list_jobs(status="FAILED") + assert queued_job.job_arn in [job.job_arn for job in list_by_job_name] + assert queued_job.job_name in [job.job_name for job in list_by_job_status] diff --git a/sagemaker-train/tests/unit/train/aws_batch/conftest.py b/sagemaker-train/tests/unit/train/aws_batch/conftest.py index e02ba33c33..02d883f6e4 100644 --- a/sagemaker-train/tests/unit/train/aws_batch/conftest.py +++ b/sagemaker-train/tests/unit/train/aws_batch/conftest.py @@ -182,3 +182,7 @@ "TrainingInputMode": "File", }, } + +# Quota Management +QUOTA_SHARE_NAME = "test-quota-share" +PREEMPTION_CONFIG = {"preemptionRetriesBeforeTermination": 10} diff --git a/sagemaker-train/tests/unit/train/aws_batch/test_batch_api_helper.py b/sagemaker-train/tests/unit/train/aws_batch/test_batch_api_helper.py index b769918341..9e90edfb39 100644 --- a/sagemaker-train/tests/unit/train/aws_batch/test_batch_api_helper.py +++ 
b/sagemaker-train/tests/unit/train/aws_batch/test_batch_api_helper.py @@ -21,11 +21,13 @@ _describe_service_job, _terminate_service_job, _list_service_job, + _update_service_job, ) from .conftest import ( JOB_NAME, JOB_QUEUE, JOB_ID, + JOB_ARN, REASON, BATCH_TAGS, TRAINING_TAGS, @@ -43,6 +45,8 @@ TRAINING_JOB_PAYLOAD, NEXT_TOKEN, JOB_STATUS_RUNNING, + QUOTA_SHARE_NAME, + PREEMPTION_CONFIG, ) @@ -258,3 +262,56 @@ def test_list_service_job_with_status(self, mock_get_client): call_kwargs = mock_client.list_service_jobs.call_args[1] assert call_kwargs["jobStatus"] == JOB_STATUS_RUNNING + + +class TestSubmitServiceJobWithQuotaManagement: + """Tests for submit_service_job with quota management parameters""" + + @patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client") + def test_submit_service_job_with_quota_share_name_and_preemption_config(self, mock_get_client): + """Test submit_service_job with quota_share_name and preemption_config""" + mock_client = Mock() + mock_client.submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + mock_get_client.return_value = mock_client + + result = _submit_service_job( + TRAINING_JOB_PAYLOAD.copy(), + JOB_NAME, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + TIMEOUT_CONFIG, + None, + BATCH_TAGS, + QUOTA_SHARE_NAME, + PREEMPTION_CONFIG, + ) + + assert result["jobName"] == JOB_NAME + call_kwargs = mock_client.submit_service_job.call_args[1] + assert call_kwargs["quotaShareName"] == QUOTA_SHARE_NAME + assert call_kwargs["preemptionConfiguration"] == PREEMPTION_CONFIG + + +class TestUpdateServiceJob: + """Tests for update_service_job function""" + + @patch("sagemaker.train.aws_batch.batch_api_helper.get_batch_boto_client") + def test_update_service_job(self, mock_get_client): + """Test update_service_job calls update API""" + mock_client = Mock() + mock_client.update_service_job.return_value = { + "jobArn": JOB_ARN, + "jobName": JOB_NAME, + "jobId": JOB_ID, + } + 
mock_get_client.return_value = mock_client + + result = _update_service_job(JOB_ID, SCHEDULING_PRIORITY) + + assert result["jobArn"] == JOB_ARN + assert result["jobName"] == JOB_NAME + assert result["jobId"] == JOB_ID + mock_client.update_service_job.assert_called_once_with( + jobId=JOB_ID, schedulingPriority=SCHEDULING_PRIORITY + ) diff --git a/sagemaker-train/tests/unit/train/aws_batch/test_training_queue.py b/sagemaker-train/tests/unit/train/aws_batch/test_training_queue.py index c917b71f1f..3c1084ff58 100644 --- a/sagemaker-train/tests/unit/train/aws_batch/test_training_queue.py +++ b/sagemaker-train/tests/unit/train/aws_batch/test_training_queue.py @@ -32,6 +32,8 @@ LIST_SERVICE_JOB_BY_SHARE_RESP_WITH_JOBS, LIST_SERVICE_JOB_RESP_EMPTY, TRAINING_JOB_PAYLOAD, + QUOTA_SHARE_NAME, + PREEMPTION_CONFIG, ) @@ -377,3 +379,45 @@ def test_get_job_not_found(self, mock_list_service_job): with pytest.raises(ValueError, match="Cannot find job"): queue.get_job(JOB_NAME) + + +class TestTrainingQueueSubmitWithQuotaManagement: + """Tests for TrainingQueue.submit with quota management parameters""" + + @patch("sagemaker.train.aws_batch.training_queue._submit_service_job") + def test_submit_with_quota_share_name_and_preemption_config(self, mock_submit_service_job): + """Test submit with quota_share_name and preemption_config""" + mock_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + trainer = Mock(spec=ModelTrainer) + trainer.training_mode = Mode.SAGEMAKER_TRAINING_JOB + trainer._create_training_job_args.return_value = TRAINING_JOB_PAYLOAD + + queue = TrainingQueue(JOB_QUEUE) + queued_job = queue.submit( + trainer, + [], + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + quota_share_name=QUOTA_SHARE_NAME, + preemption_config=PREEMPTION_CONFIG, + ) + + assert queued_job.job_name == JOB_NAME + assert queued_job.job_arn == JOB_ARN + mock_submit_service_job.assert_called_once_with( + 
TRAINING_JOB_PAYLOAD, + JOB_NAME, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + TIMEOUT_CONFIG, + SHARE_IDENTIFIER, + BATCH_TAGS, + QUOTA_SHARE_NAME, + PREEMPTION_CONFIG, + ) diff --git a/sagemaker-train/tests/unit/train/aws_batch/test_training_queued_job.py b/sagemaker-train/tests/unit/train/aws_batch/test_training_queued_job.py index babdfcd585..2ba61f4471 100644 --- a/sagemaker-train/tests/unit/train/aws_batch/test_training_queued_job.py +++ b/sagemaker-train/tests/unit/train/aws_batch/test_training_queued_job.py @@ -22,6 +22,7 @@ from .conftest import ( JOB_NAME, JOB_ARN, + JOB_ID, REASON, TRAINING_JOB_NAME, TRAINING_JOB_ARN, @@ -33,6 +34,7 @@ DESCRIBE_SERVICE_JOB_RESP_SUCCEEDED, DESCRIBE_SERVICE_JOB_RESP_FAILED, DESCRIBE_SERVICE_JOB_RESP_PENDING, + SCHEDULING_PRIORITY, ) @@ -86,6 +88,27 @@ def test_terminate_default_reason(self, mock_terminate_service_job): assert call_kwargs[0] == JOB_ARN +class TestTrainingQueuedJobUpdate: + """Tests for TrainingQueuedJob.update method""" + + @patch("sagemaker.train.aws_batch.training_queued_job._update_service_job") + def test_update(self, mock_update_service_job): + """Test update calls update API""" + mock_update_service_job.return_value = { + "jobArn": JOB_ARN, + "jobName": JOB_NAME, + "jobId": JOB_ID, + } + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = queued_job.update(SCHEDULING_PRIORITY) + + mock_update_service_job.assert_called_once_with(JOB_ARN, SCHEDULING_PRIORITY) + assert result["jobArn"] == JOB_ARN + assert result["jobName"] == JOB_NAME + assert result["jobId"] == JOB_ID + + class TestTrainingQueuedJobWait: """Tests for TrainingQueuedJob.wait method""" diff --git a/v3-examples/training-examples/aws_batch/sm-training-queues_quota-management.ipynb b/v3-examples/training-examples/aws_batch/sm-training-queues_quota-management.ipynb new file mode 100644 index 0000000000..272592853f --- /dev/null +++ 
b/v3-examples/training-examples/aws_batch/sm-training-queues_quota-management.ipynb @@ -0,0 +1,978 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f6d21c4c-e5fe-4992-9fd9-a33f36e4db2d", + "metadata": {}, + "source": [ + "# Getting Started with Quota Management\n", + "Quota management enables administrators to efficiently allocate shared compute resources between teams and projects by defining compute quotas and strategies for sharing capacity between quota shares. Each quota share operates as a virtual queue. When scheduling jobs for a job queue, AWS Batch will iterate through all attached quota shares to dispatch jobs that fit within their configured capacity and borrowing limits.\n", + "\n", + "This notebook shows how to create quota management resources in [AWS Batch for SageMaker Training jobs](https://aws.amazon.com/blogs/machine-learning/introducing-aws-batch-support-for-amazon-sagemaker-training-jobs/), and illustrates how the [AWS Batch](https://aws.amazon.com/batch/) scheduler enables resource sharing between quota shares, leveraging preemption to restore borrowed idle capacity when jobs arrive.\n", + "\n", + "---\n", + "\n", + "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.\n", + "\n", + "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/build_and_train_models|sm-training-queues|sm-training-queues_quota-management.ipynb)\n", + "\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "10e12b35-3dc2-4376-b90e-54c00c70a607", + "metadata": { + "tags": [] + }, + "source": [ + "## Setup and Configure Training Job Variables\n", + "We will need a few instances for a short duration for the sample jobs. Change any of the constant variables below to adjust the example to your liking." 
+ ] + }, + { + "cell_type": "code", + "id": "6316085c-262d-4437-8987-9ca7eca94965", + "metadata": { + "tags": [] + }, + "source": [ + "INSTANCE_TYPE = \"ml.g5.xlarge\"\n", + "INSTANCE_COUNT = 1\n", + "MAX_RUN_TIME = 300\n", + "TRAINING_JOB_NAME = \"hello-world-simple-job\"" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "id": "b4edef56-49f3-4729-afdd-5345c5710363", + "metadata": { + "tags": [] + }, + "source": [ + "import logging\n", + "\n", + "logging.basicConfig(\n", + " level=logging.INFO, format=\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\"\n", + ")\n", + "logging.getLogger(\"botocore.client\").setLevel(level=logging.WARN)\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "from sagemaker.core.helper.session_helper import Session\n", + "from sagemaker.core import image_uris\n", + "\n", + "session = Session()\n", + "\n", + "image_uri = image_uris.retrieve(\n", + " framework=\"pytorch\",\n", + " region=session.boto_session.region_name,\n", + " version=\"2.5\",\n", + " instance_type=INSTANCE_TYPE,\n", + " image_scope=\"training\",\n", + ")" + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Create Sample Resources\n", + "Here we create the AWS Batch service environment, job queue and quota shares that we will use to enqueue our training jobs. Each quota share is configured with its own dedicated capacity limits, and may be configured to lend idle capacity to or borrow idle capacity from other quota shares. Cross-share preemption is always on, and allows a given quota share to take back any capacity it has lent to other quota shares when needed. In-share preemption can be enabled to allow high priority jobs to preempt low priority jobs within a given quota share.\n", + "- QS1: configured with the `LEND_AND_BORROW` resource sharing strategy, and a borrow limit of 200%. 
This allows QS1 to both lend its own idle capacity and borrow idle capacity from any other quota share that is configured with a `LEND` or `LEND_AND_BORROW` resource sharing strategy. QS1 is also configured with in-share preemption, which allows jobs within QS1 to preempt each other based on priority.\n", + "- QS2: configured with the `LEND` resource sharing strategy. This configuration allows QS2 to lend its own idle capacity but not borrow any other quota share's idle capacity.\n", + "- QS3: configured with the `RESERVE` resource sharing strategy. This configuration prevents QS3 from borrowing idle capacity from, and lending idle capacity to other quota shares.\n", + "\n", + "You can use [Batch Console](https://console.aws.amazon.com/batch) to create these resources, or you can run the cell below. The ```create_quota_management_resources``` function below will skip creating any resources that already exist." + ], + "id": "75fbf8b8827c7eef" + }, + { + "cell_type": "code", + "id": "e325ddb0-aa86-4f3b-9820-753f4bdadb19", + "metadata": { + "tags": [] + }, + "source": [ + "from sagemaker.train.aws_batch.boto_client import get_batch_boto_client\n", + "from utils.aws_batch_resource_management import AwsBatchResourceManager, QuotaShareConfig, create_quota_management_resources\n", + "\n", + "SCHEDULING_POLICY_NAME = \"my-qm-scheduling-policy\"\n", + "JOB_QUEUE_NAME = \"my-sm-training-qm-jq\"\n", + "SERVICE_ENVIRONMENT_NAME = \"my-sm-training-qm-se\"\n", + "\n", + "# Create SchedulingPolicy, ServiceEnvironment, JobQueue, and QuotaShares\n", + "resource_manager = AwsBatchResourceManager(get_batch_boto_client())\n", + "qs1 = QuotaShareConfig(\n", + " name=\"QS1\",\n", + " capacity_unit=INSTANCE_TYPE,\n", + " max_capacity=1,\n", + " in_share_preemption=True,\n", + " sharing_strategy=\"LEND_AND_BORROW\",\n", + " borrow_limit=200\n", + ")\n", + "qs2 = QuotaShareConfig(\n", + " name=\"QS2\",\n", + " capacity_unit=INSTANCE_TYPE,\n", + " max_capacity=1,\n", + " 
in_share_preemption=False,\n", + " sharing_strategy=\"LEND\"\n", + ")\n", + "qs3 = QuotaShareConfig(\n", + " name=\"QS3\",\n", + " capacity_unit=INSTANCE_TYPE,\n", + " max_capacity=1,\n", + " in_share_preemption=False,\n", + " sharing_strategy=\"RESERVE\"\n", + ")\n", + "resources = create_quota_management_resources(\n", + " resource_manager=resource_manager,\n", + " scheduling_policy_name=SCHEDULING_POLICY_NAME,\n", + " job_queue_name=JOB_QUEUE_NAME,\n", + " service_environment_name=SERVICE_ENVIRONMENT_NAME,\n", + " capacity_unit=INSTANCE_TYPE,\n", + " max_capacity=3,\n", + " quota_share_configs=[qs1, qs2, qs3]\n", + ")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "2a5d5e2e-b266-41c4-9d17-ce6c93a42db3", + "metadata": {}, + "source": [ + "## Create Hello World Model Trainer\n", + "Now that our resources are created, we'll construct a simple [ModelTrainer](https://sagemaker.readthedocs.io/en/stable/api/training/model_trainer.html). Any model trainer may be used, you may import your own instead of constructing a new one here if you wish!" 
+ ] + }, + { + "cell_type": "code", + "id": "d71f7b99-63a3-4fd0-b735-6140fe1489f6", + "metadata": {}, + "source": [ + "from sagemaker.train.model_trainer import ModelTrainer\n", + "from sagemaker.train.configs import SourceCode, Compute, StoppingCondition\n", + "\n", + "source_code = SourceCode(command=\"echo 'Hello World'\")\n", + "\n", + "model_trainer = ModelTrainer(\n", + " training_image=image_uri,\n", + " source_code=source_code,\n", + " base_job_name=TRAINING_JOB_NAME,\n", + " compute=Compute(instance_type=INSTANCE_TYPE, instance_count=INSTANCE_COUNT),\n", + " stopping_condition=StoppingCondition(max_runtime_in_seconds=MAX_RUN_TIME),\n", + ")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "48f030c1-3e32-4265-a1fe-cdd6927e82ac", + "metadata": {}, + "source": [ + "## Create TrainingQueue object\n", + "Using our queue is as easy as referring to it by name in the TrainingQueue constructor. The TrainingQueue class within the SageMaker Python SDK provides built-in support for working with Batch queues." + ] + }, + { + "cell_type": "code", + "id": "5d90c9a4-ff38-492e-b446-61674701d9ca", + "metadata": {}, + "source": [ + "from sagemaker.train.aws_batch.training_queue import TrainingQueue, TrainingQueuedJob\n", + "\n", + "# Construct the queue object using the SageMaker Python SDK\n", + "queue = TrainingQueue(JOB_QUEUE_NAME)\n", + "logger.info(f\"Using queue: {queue.queue_name}\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "734b7fe6-0bf7-460e-aa95-7608421e900c", + "metadata": {}, + "source": [ + "## Submit Some Training Jobs to Quota Share QS1\n", + "Submitting your job to a quota share on a queue is done by calling queue.submit and passing in the quota share name. These jobs don't require any data, but in general, data should be provided by specifying inputs." 
+ ] + }, + { + "cell_type": "code", + "id": "60f2ca52-b1ea-4af2-a143-8202ce34d5e6", + "metadata": {}, + "source": [ + "from utils.aws_batch_resource_management import JobStatus, await_jobs, list_jobs_by_quota_share\n", + "\n", + "# Submit three jobs to QS1, all with different priorities to guarantee dispatch ordering.\n", + "qs1_job_high: TrainingQueuedJob = queue.submit(job_name=\"qs1_job_high\", training_job=model_trainer, quota_share_name=qs1.name, priority=3, inputs=None)\n", + "qs1_job_med: TrainingQueuedJob = queue.submit(job_name=\"qs1_job_med\", training_job=model_trainer, quota_share_name=qs1.name, priority=2, inputs=None)\n", + "qs1_job_low: TrainingQueuedJob = queue.submit(job_name=\"qs1_job_low\", training_job=model_trainer, quota_share_name=qs1.name, priority=1, inputs=None)\n", + "\n", + "logger.info(f\"Submitted jobs: {qs1_job_high.job_name}, {qs1_job_med.job_name}, {qs1_job_low.job_name}\")\n", + "logger.info(f\"Waiting for jobs to transition...\")\n", + "\n", + "# qs1_job_high is dispatched first using capacity from QS1, and qs1_job_med is dispatched second borrowing idle capacity from QS2\n", + "await_jobs([([qs1_job_high, qs1_job_med], JobStatus.dispatched() | JobStatus.terminal())])\n", + "list_jobs_by_quota_share(queue, [qs1.name, qs2.name, qs3.name], JobStatus.active())\n" + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Submit a Job to Quota Share QS2\n", + "This will trigger cross-share job preemption. QS2 will preempt its own capacity back from QS1, which QS1 was borrowing to run job `qs1_job_med`." 
+ ], + "id": "e0509ae56500b3f9" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "# When a job is submitted to QS2, it will preempt its capacity back from QS1\n", + "qs2_job_1: TrainingQueuedJob = queue.submit(job_name=\"qs2_job_1\", training_job=model_trainer, quota_share_name=qs2.name, priority=1, inputs=None)\n", + "\n", + "logger.info(f\"Submitted job {qs2_job_1.job_name}\")\n", + "logger.info(f\"Waiting for jobs to transition...\")\n", + "\n", + "await_jobs([\n", + " ([qs2_job_1], JobStatus.dispatched() | JobStatus.terminal()),\n", + " ([qs1_job_med], {JobStatus.RUNNABLE}) # Lowest priority QS1 job will be preempted\n", + "])\n", + "list_jobs_by_quota_share(queue, [qs1.name, qs2.name, qs3.name], JobStatus.active())" + ], + "id": "dd27ee554ccac87c", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Update job priority\n", + "This will trigger in-share job preemption. QS1 has in-share job preemption enabled, which allows a high priority job to preempt a low priority job." 
+ ], + "id": "3ba56c738982ef8e" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "# Updating qs1_job_low to be the highest priority job in QS1\n", + "qs1_job_low.update(scheduling_priority=4)\n", + "\n", + "logger.info(f\"Updated job {qs1_job_low.job_name} to increase its priority\")\n", + "logger.info(f\"Waiting for jobs to transition...\")\n", + "\n", + "await_jobs([\n", + " ([qs1_job_low], JobStatus.dispatched() | JobStatus.terminal()),\n", + " ([qs1_job_high], {JobStatus.RUNNABLE}) # High priority job will preempt lower priority jobs within the same quota share\n", + "])\n", + "list_jobs_by_quota_share(queue, [qs1.name, qs2.name, qs3.name], JobStatus.active())" + ], + "id": "c6dd613ec1e373e", + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "3ca39fca-6bb6-4e0b-841e-7f66d97b6074", + "metadata": {}, + "source": [ + "# Optional: Delete AWS Batch Resources\n", + "This shows how to delete the AWS Batch ServiceEnvironment and JobQueue. This step is completely optional, uncomment the code below to delete the resources created a few steps above." + ] + }, + { + "cell_type": "code", + "id": "8d745e2d-40a8-45ef-b231-e8acd9b5e8eb", + "metadata": { + "tags": [] + }, + "source": [ + "from utils.aws_batch_resource_management import delete_resources\n", + "\n", + "# delete_resources(resource_manager, resources)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "7070dcc9", + "metadata": {}, + "source": [ + "## Notebook CI Test Results\n", + "\n", + "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n", + "\n", + "\n", + "![This us-east-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/build_and_train_models|sm-training-queues|sm-training-queues_quota-management.ipynb)\n", + "\n", + "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/build_and_train_models|sm-training-queues|sm-training-queues_quota-management.ipynb)\n", + "\n", + "![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/build_and_train_models|sm-training-queues|sm-training-queues_quota-management.ipynb)\n", + "\n", + "![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/build_and_train_models|sm-training-queues|sm-training-queues_quota-management.ipynb)\n", + "\n", + "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/build_and_train_models|sm-training-queues|sm-training-queues_quota-management.ipynb)\n", + "\n", + "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/build_and_train_models|sm-training-queues|sm-training-queues_quota-management.ipynb)\n", + "\n", + "![This eu-west-2 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/build_and_train_models|sm-training-queues|sm-training-queues_quota-management.ipynb)\n", + "\n", + "![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/build_and_train_models|sm-training-queues|sm-training-queues_quota-management.ipynb)\n", + "\n", + "![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/build_and_train_models|sm-training-queues|sm-training-queues_quota-management.ipynb)\n", + "\n", + "![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/build_and_train_models|sm-training-queues|sm-training-queues_quota-management.ipynb)\n", + "\n", + "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/build_and_train_models|sm-training-queues|sm-training-queues_quota-management.ipynb)\n", + "\n", + "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/build_and_train_models|sm-training-queues|sm-training-queues_quota-management.ipynb)\n", + "\n", + "![This ap-northeast-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/build_and_train_models|sm-training-queues|sm-training-queues_quota-management.ipynb)\n", + "\n", + "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/build_and_train_models|sm-training-queues|sm-training-queues_quota-management.ipynb)\n", + "\n", + "![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/build_and_train_models|sm-training-queues|sm-training-queues_quota-management.ipynb)\n" + ] + } + ], + "metadata": { + "availableInstances": [ + { + "_defaultOrder": 0, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 4, + "name": "ml.t3.medium", + "vcpuNum": 2 + }, + { + "_defaultOrder": 1, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.t3.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 2, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.t3.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 3, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.t3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 4, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.m5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 5, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + 
"hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.m5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 6, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.m5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 7, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.m5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 8, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.m5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 9, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.m5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 10, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.m5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 11, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.m5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 12, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.m5d.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 13, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.m5d.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 14, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.m5d.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 15, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": 
false, + "memoryGiB": 64, + "name": "ml.m5d.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 16, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.m5d.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 17, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.m5d.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 18, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.m5d.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 19, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.m5d.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 20, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": true, + "memoryGiB": 0, + "name": "ml.geospatial.interactive", + "supportedImageNames": [ + "sagemaker-geospatial-v1-0" + ], + "vcpuNum": 0 + }, + { + "_defaultOrder": 21, + "_isFastLaunch": true, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 4, + "name": "ml.c5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 22, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.c5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 23, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.c5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 24, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.c5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 25, + "_isFastLaunch": false, + 
"category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 72, + "name": "ml.c5.9xlarge", + "vcpuNum": 36 + }, + { + "_defaultOrder": 26, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 96, + "name": "ml.c5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 27, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 144, + "name": "ml.c5.18xlarge", + "vcpuNum": 72 + }, + { + "_defaultOrder": 28, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.c5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 29, + "_isFastLaunch": true, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.g4dn.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 30, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.g4dn.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 31, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.g4dn.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 32, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.g4dn.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 33, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.g4dn.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 34, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.g4dn.16xlarge", + "vcpuNum": 64 + }, + { + 
"_defaultOrder": 35, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 61, + "name": "ml.p3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 36, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 244, + "name": "ml.p3.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 37, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 488, + "name": "ml.p3.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 38, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.p3dn.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 39, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.r5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 40, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.r5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 41, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.r5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 42, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.r5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 43, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.r5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 44, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.r5.12xlarge", + "vcpuNum": 
48 + }, + { + "_defaultOrder": 45, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 512, + "name": "ml.r5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 46, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.r5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 47, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.g5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 48, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.g5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 49, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.g5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 50, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.g5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 51, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.g5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 52, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.g5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 53, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.g5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 54, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 768, 
+ "name": "ml.g5.48xlarge", + "vcpuNum": 192 + }, + { + "_defaultOrder": 55, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 1152, + "name": "ml.p4d.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 56, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 1152, + "name": "ml.p4de.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 57, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.trn1.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 58, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 512, + "name": "ml.trn1.32xlarge", + "vcpuNum": 128 + }, + { + "_defaultOrder": 59, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 512, + "name": "ml.trn1n.32xlarge", + "vcpuNum": 128 + } + ], + "instance_type": "ml.t3.medium", + "kernelspec": { + "display_name": "venv-test", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/v3-examples/training-examples/aws_batch/utils/aws_batch_resource_management.py b/v3-examples/training-examples/aws_batch/utils/aws_batch_resource_management.py index 19f51e0bdc..c77b222023 100644 --- a/v3-examples/training-examples/aws_batch/utils/aws_batch_resource_management.py +++ b/v3-examples/training-examples/aws_batch/utils/aws_batch_resource_management.py @@ -1,10 +1,17 @@ import json import logging import time +from concurrent.futures import 
as_completed +from concurrent.futures.thread import ThreadPoolExecutor from dataclasses import dataclass +from enum import Enum import boto3 from botocore.exceptions import ClientError +from typing import List, Set, Tuple + +from sagemaker.train.aws_batch.training_queue import TrainingQueue +from sagemaker.train.aws_batch.training_queued_job import TrainingQueuedJob # Configure logging logger = logging.getLogger(__name__) @@ -36,8 +43,142 @@ class Resources: job_queue: Resource = None service_environment: Resource = None + scheduling_policy: Resource = None batch_role: Resource = None - sagemaker_exeuction_role: Resource = None + sagemaker_execution_role: Resource = None + quota_shares: list = None + + +@dataclass +class QuotaShareConfig: + """ + Configuration for an AWS Batch quota share. + + Attributes: + name (str): Name of the quota share. + capacity_unit (str): Unit of capacity (e.g., "ml.g5.xlarge"). + max_capacity (int): Maximum capacity for this quota share. + in_share_preemption (bool): Whether in-share preemption is enabled. + sharing_strategy (str): Resource sharing strategy (e.g., "RESERVE", "LEND", "LEND_AND_BORROW"). + borrow_limit (int, optional): Maximum capacity that can be borrowed, only applicable for LEND_AND_BORROW + sharing strategy. + """ + + name: str + capacity_unit: str + max_capacity: int + in_share_preemption: bool + sharing_strategy: str + borrow_limit: int = None + + def create_quota_share_request(self, job_queue_name: str): + """ + Build the request dictionary for creating a quota share. + + Returns: + dict: Request parameters for the create_quota_share API call. 
+ """ + return { + "quotaShareName": self.name, + "jobQueue": job_queue_name, + "capacityLimits": [{ + "capacityUnit": self.capacity_unit, + "maxCapacity": self.max_capacity, + }], + "resourceSharingConfiguration": { + "strategy": self.sharing_strategy, + **({"borrowLimit": self.borrow_limit} if self.borrow_limit else {}) + }, + "preemptionConfiguration": { + "inSharePreemption": "ENABLED" if self.in_share_preemption else "DISABLED" + }, + "state": "ENABLED" + } + + +class JobStatus(str, Enum): + """ + Enumeration of AWS Batch job statuses. + + Provides helper methods to get sets of statuses for common filtering operations. + """ + + SUBMITTED = "SUBMITTED" + PENDING = "PENDING" + RUNNABLE = "RUNNABLE" + SCHEDULED = "SCHEDULED" + STARTING = "STARTING" + RUNNING = "RUNNING" + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + + @staticmethod + def active() -> set: + """ + Returns the set of job statuses indicating a job is still active (not yet completed). + + Returns: + set: Job statuses for jobs that are queued or running (SUBMITTED, PENDING, + RUNNABLE, SCHEDULED, STARTING, RUNNING). + """ + return { + JobStatus.SUBMITTED, + JobStatus.PENDING, + JobStatus.RUNNABLE, + JobStatus.SCHEDULED, + JobStatus.STARTING, + JobStatus.RUNNING + } + + @staticmethod + def dispatched() -> set: + """ + Returns the set of job statuses indicating a job has been dispatched to compute resources. + + Returns: + set: Job statuses for jobs that have been scheduled or completed (SCHEDULED, + STARTING, RUNNING, SUCCEEDED, FAILED). + """ + return { + JobStatus.SCHEDULED, + JobStatus.STARTING, + JobStatus.RUNNING, + JobStatus.SUCCEEDED, + JobStatus.FAILED + } + + @staticmethod + def terminal() -> set: + """ + Returns the set of job statuses indicating a job has reached a final state. + + Returns: + set: Job statuses for completed jobs (SUCCEEDED, FAILED). 
+ """ + return { + JobStatus.SUCCEEDED, + JobStatus.FAILED + } + + @staticmethod + def all() -> set: + """ + Returns the set of all possible job statuses. + + Returns: + set: All job statuses (SUBMITTED, PENDING, RUNNABLE, SCHEDULED, STARTING, + RUNNING, SUCCEEDED, FAILED). + """ + return { + JobStatus.SUBMITTED, + JobStatus.PENDING, + JobStatus.RUNNABLE, + JobStatus.SCHEDULED, + JobStatus.STARTING, + JobStatus.RUNNING, + JobStatus.SUCCEEDED, + JobStatus.FAILED + } class AwsBatchResourceManager: @@ -126,9 +267,9 @@ def await_service_environment_update( or is deleted if that's the expected status. Args: - se_name (str): Name of the service environment to monitor. + service_environment_name (str): Name of the service environment to monitor. expected_status (str): The expected status to wait for (e.g., "VALID", "DELETED"). - expected_state (str, optional): The expected state to wait for (e.g., "ENABLED", "DISABLED"). + expected_state (str): The expected state to wait for (e.g., "ENABLED", "DISABLED"). Returns: dict: The describe service environments response when the expected state is reached. @@ -146,7 +287,7 @@ def await_service_environment_update( if status == expected_status and state == expected_state: break if status == "INVALID": - raise ValueError(f"Something went wrong! {json.dumps(jq, indent=4)}") + raise ValueError(f"Something went wrong! {json.dumps(describe_response, indent=4)}") elif expected_status == "DELETED": logger.info(f"ServiceEnvironment {service_environment_name} has been deleted") break @@ -164,7 +305,7 @@ def delete_service_environment(self, service_environment_name: str): 4. Wait for the deletion to complete Args: - se_name (str): Name of the service environment to delete. + service_environment_name (str): Name of the service environment to delete. 
""" logger.info(f"Setting ServiceEnvironment {service_environment_name} to DISABLED") self._batch_client.update_service_environment( @@ -223,7 +364,7 @@ def delete_job_queue(self, job_queue_name: str): 4. Wait for the deletion to complete Args: - jq_name (str): Name of the job queue to delete. + job_queue_name (str): Name of the job queue to delete. """ logger.info(f"Disabling JobQueue {job_queue_name}") self._batch_client.update_job_queue(jobQueue=job_queue_name, state="DISABLED") @@ -247,9 +388,9 @@ def await_job_queue_update( or is deleted if that's the expected status. Args: - jq_name (str): Name of the job queue to monitor. + job_queue_name (str): Name of the job queue to monitor. expected_status (str): The expected status to wait for (e.g., "VALID", "DELETED"). - expected_state (str, optional): The expected state to wait for (e.g., "ENABLED", "DISABLED"). + expected_state (str): The expected state to wait for (e.g., "ENABLED", "DISABLED"). Raises: ValueError: If the job queue enters an INVALID status. @@ -274,6 +415,135 @@ def await_job_queue_update( time.sleep(5) + def create_scheduling_policy(self, create_sp_request: dict): + """ + Create a new AWS Batch scheduling policy. + + If the scheduling policy already exists, returns the existing policy details. + + Args: + create_sp_request (dict): Request parameters for creating a scheduling policy. + Must contain 'name' key. Optional: 'quotaSharePolicy', 'tags'. + + Returns: + dict: Response containing the scheduling policy name and ARN. + + Raises: + ClientError: If there's an error creating the scheduling policy. 
def create_quota_share(self, create_qs_request: dict):
    """
    Create a new AWS Batch quota share.

    If the service reports that the quota share already exists, creation is skipped
    and the quota share ARN is rebuilt from the ARN of the owning job queue.

    Args:
        create_qs_request (dict): Request parameters, forwarded as keyword arguments
            to the Batch ``create_quota_share`` call. ``quotaShareName`` and
            ``jobQueue`` are additionally read by the "already exists" fallback;
            other keys (e.g. capacityLimits, resourceSharingConfiguration,
            preemptionConfiguration) are passed through untouched.

    Returns:
        dict: The ``create_quota_share`` response, or a dict holding
            ``quotaShareName`` and ``quotaShareArn`` when the quota share
            already existed.

    Raises:
        ClientError: If creation fails for any reason other than the quota share
            already existing.
    """
    try:
        return self._batch_client.create_quota_share(**create_qs_request)
    except ClientError as error:
        # Read the message defensively: a response without a top-level "message"
        # must surface the original ClientError, not a KeyError that hides it.
        message = (
            error.response.get("message")
            or error.response.get("Error", {}).get("Message", "")
        )
        if "already exists" not in message:
            logger.error(f"Error: {json.dumps(error.response, indent=4)}")
            raise

        logger.info("QuotaShare already exists, skipping creation")
        quota_share_name = create_qs_request["quotaShareName"]
        desc_jqs_resp = self._batch_client.describe_job_queues(
            jobQueues=[create_qs_request["jobQueue"]]
        )
        jq_arn = desc_jqs_resp["jobQueues"][0]["jobQueueArn"]
        # The name is bound to a local first: reusing the enclosing quote character
        # inside an f-string replacement field is a SyntaxError before Python 3.12.
        return {
            "quotaShareName": quota_share_name,
            "quotaShareArn": f"{jq_arn}/quota-share/{quota_share_name}",
        }
def await_jobs(job_groups: List[Tuple[List[TrainingQueuedJob], Set[JobStatus]]]):
    """
    Wait for jobs to reach desired statuses, polling in parallel.

    Each job is polled on its own worker thread: ``job.describe()`` is called every
    five seconds until the reported status is contained in the status set of the
    job's group. An exception raised while polling a job surfaces to the caller
    through ``future.result()``.

    Args:
        job_groups: List of tuples, each containing a list of jobs and the set of
            statuses to wait for those jobs to reach.

    Returns:
        None. Returns immediately when the groups contain no jobs at all.
    """
    all_tasks = [(job, statuses) for jobs, statuses in job_groups for job in jobs]
    if not all_tasks:
        # ThreadPoolExecutor rejects max_workers=0 with a ValueError, so an empty
        # workload must be handled before the pool is created.
        return

    def poll(job: TrainingQueuedJob, desired_status_set: Set[JobStatus]):
        while True:
            job_status = job.describe().get("status", "")

            if job_status in desired_status_set:
                logger.info(f"Job: {job.job_name} is {job_status}")
                break

            time.sleep(5)

    # One worker per job so that no poll loop has to wait behind another one.
    with ThreadPoolExecutor(max_workers=len(all_tasks)) as executor:
        futures = [executor.submit(poll, job, statuses) for job, statuses in all_tasks]
        for future in as_completed(futures):
            future.result()


def _status_message(job_detail) -> str:
    """
    Extract a status message from job details for logging.

    The ``statusReason`` of the first entry in
    ``preemptionSummary.recentPreemptedAttempts`` takes precedence, unless the job
    is currently SCHEDULED or STARTING; otherwise the top-level ``statusReason`` is
    used. NOTE(review): the first entry is treated as the most recent preempted
    attempt - confirm the ordering against the Batch API.

    Args:
        job_detail: Job detail dictionary (e.g. the result of describing a job).

    Returns:
        str: The chosen reason formatted as `` (reason)``, or an empty string if
            no reason is available.
    """
    status_reason = job_detail.get("statusReason")

    # A missing summary, a None value and an empty attempt list all mean
    # "no preempted attempt"; none of them may be indexed.
    preemption_summary = job_detail.get("preemptionSummary") or {}
    recent_preempted_attempts = preemption_summary.get("recentPreemptedAttempts") or []
    preemption_status_reason = (
        recent_preempted_attempts[0].get("statusReason") if recent_preempted_attempts else None
    )

    # Return most recent preempted attempt statusReason if it exists
    if preemption_status_reason and job_detail.get("status", "") not in {JobStatus.SCHEDULED, JobStatus.STARTING}:
        return f" ({preemption_status_reason})"

    # Return top level statusReason if there is no preemption reason to show
    if status_reason:
        return f" ({status_reason})"

    # Return empty string if no statusReasons are set
    return ""
def create_quota_management_resources(
    resource_manager: AwsBatchResourceManager,
    job_queue_name: str,
    service_environment_name: str,
    capacity_unit: str,
    max_capacity: int,
    scheduling_policy_name: str,
    quota_share_configs: List[QuotaShareConfig],
):
    """
    Provision a scheduling policy, service environment, job queue and quota shares.

    Resources are created in that order. The service environment, the job queue and
    every quota share are awaited until they report VALID / ENABLED before the next
    step starts.

    Args:
        resource_manager (AwsBatchResourceManager): The resource manager used to create
            and await the resources.
        job_queue_name (str): Name for the job queue.
        service_environment_name (str): Name for the service environment.
        capacity_unit (str): The capacity unit for the service environment
            (e.g., "NUM_INSTANCES").
        max_capacity (int): Maximum capacity for the service environment.
        scheduling_policy_name (str): Name for the scheduling policy.
        quota_share_configs (List[QuotaShareConfig]): One QuotaShareConfig per quota
            share to create on the job queue.

    Returns:
        Resources: A Resources object containing the created service environment,
            job queue, scheduling policy, and quota shares.
    """
    # Scheduling policy: idle resources are assigned FIFO.
    logger.info(f"Creating SchedulingPolicy: {scheduling_policy_name}")
    sp_request = {
        "name": scheduling_policy_name,
        "quotaSharePolicy": {"idleResourceAssignmentStrategy": "FIFO"},
    }
    sp_response = resource_manager.create_scheduling_policy(sp_request)
    scheduling_policy = Resource(name=sp_response["name"], arn=sp_response["arn"])

    # Service environment: a single capacity limit expressed in the requested unit.
    logger.info(f"Creating ServiceEnvironment: {service_environment_name}")
    capacity_limit = {"maxCapacity": max_capacity, "capacityUnit": capacity_unit}
    se_request = {
        "serviceEnvironmentName": service_environment_name,
        "serviceEnvironmentType": "SAGEMAKER_TRAINING",
        "state": "ENABLED",
        "capacityLimits": [capacity_limit],
    }
    se_response = resource_manager.create_service_environment(se_request)
    logger.info("Waiting for ServiceEnvironment to transition to VALID...")
    resource_manager.await_service_environment_update(service_environment_name, "VALID", "ENABLED")

    # Job queue: backed by the service environment and governed by the scheduling policy.
    logger.info(f"Creating JobQueue: {job_queue_name}")
    se_order_entry = {"order": 1, "serviceEnvironment": se_response["serviceEnvironmentName"]}
    jq_request = {
        "jobQueueName": job_queue_name,
        "jobQueueType": "SAGEMAKER_TRAINING",
        "state": "ENABLED",
        "priority": 1,
        "serviceEnvironmentOrder": [se_order_entry],
        "schedulingPolicyArn": scheduling_policy.arn,
    }
    jq_response = resource_manager.create_job_queue(jq_request)
    logger.info("Waiting for JobQueue to transition to VALID...")
    resource_manager.await_job_queue_update(job_queue_name, "VALID", "ENABLED")

    # Quota shares: created one at a time, each awaited before the next is requested.
    quota_shares = []
    for config in quota_share_configs:
        logger.info(f"Creating QuotaShare: {config.name}")
        qs_request = config.create_quota_share_request(job_queue_name)
        qs_response = resource_manager.create_quota_share(qs_request)
        qs_arn = qs_response["quotaShareArn"]
        resource_manager.await_quota_share_update(qs_arn, "VALID", "ENABLED")
        quota_shares.append(Resource(name=qs_response["quotaShareName"], arn=qs_arn))

    service_environment = Resource(
        name=se_response["serviceEnvironmentName"],
        arn=se_response["serviceEnvironmentArn"],
    )
    job_queue = Resource(name=jq_response["jobQueueName"], arn=jq_response["jobQueueArn"])
    resources = Resources(
        service_environment=service_environment,
        job_queue=job_queue,
        scheduling_policy=scheduling_policy,
        quota_shares=quota_shares,
    )

    logger.info(f"Resource creation complete: {resources}")
    return resources
""" + if resources.quota_shares: + for qs in resources.quota_shares: + logger.info(f"Deleting QuotaShare: {qs.name}") + resource_manager.delete_quota_share(qs.arn) + if resources.job_queue: logger.info(f"Deleting JobQueue: {resources.job_queue.name}") resource_manager.delete_job_queue(resources.job_queue.name) + if resources.scheduling_policy: + logger.info(f"Deleting SchedulingPolicy: {resources.scheduling_policy.name}") + resource_manager.delete_scheduling_policy(resources.scheduling_policy.arn) + if resources.service_environment: logger.info(f"Deleting ServiceEnvironment: {resources.service_environment.name}") resource_manager.delete_service_environment(resources.service_environment.name)