Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions sagemaker-train/src/sagemaker/train/aws_batch/batch_api_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def _submit_service_job(
timeout: Optional[Dict] = None,
share_identifier: Optional[str] = None,
tags: Optional[Dict] = None,
quota_share_name: Optional[str] = None,
preemption_config: Optional[Dict] = None,
) -> Dict:
"""Batch submit_service_job API helper function.

Expand All @@ -44,6 +46,8 @@ def _submit_service_job(
timeout: Set with value of timeout if specified, else default to 1 day.
share_identifier: value of shareIdentifier if specified.
tags: A dict of string to string representing Batch tags.
quota_share_name: Quota Share name for the Batch job.
preemption_config: Preemption configuration.

Returns:
A dict containing jobArn, jobName and jobId.
Expand All @@ -68,6 +72,10 @@ def _submit_service_job(
payload["shareIdentifier"] = share_identifier
if tags or training_payload_tags:
payload["tags"] = __merge_tags(tags, training_payload_tags)
if quota_share_name:
payload["quotaShareName"] = quota_share_name
if preemption_config:
payload["preemptionConfiguration"] = preemption_config
return client.submit_service_job(**payload)


Expand Down Expand Up @@ -96,21 +104,45 @@ def _describe_service_job(job_id: str) -> Dict:
'jobId': 'string',
'jobName': 'string',
'jobQueue': 'string',
'latestAttempt': {
'serviceResourceId': {
'name': 'string',
'value': 'string'
}
},
'preemptionSummary': {
'preemptedAttemptCount': 123,
'recentPreemptedAttempts': [
{
'serviceResourceId': {
'name': 'string',
'value': 'string'
},
'startedAt': 123,
'stoppedAt': 123,
'statusReason': 'string'
},
]
},
'retryStrategy': {
'attempts': 123
},
'schedulingPriority': 123,
'serviceRequestPayload': 'string',
'serviceJobType': 'EKS'|'ECS'|'ECS_FARGATE'|'SAGEMAKER_TRAINING',
'serviceJobType': 'SAGEMAKER_TRAINING',
'shareIdentifier': 'string',
'quotaShareName': 'string',
'preemptionConfiguration': {
'preemptionRetriesBeforeTermination': 123
},
'startedAt': 123,
'status': 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED',
'status': 'SUBMITTED'|'PENDING'|'RUNNABLE'|'SCHEDULED'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED',
'statusReason': 'string',
'stoppedAt': 123,
'tags': {
'string': 'string'
},
'timeout': {
'timeoutConfig': {
'attemptDurationSeconds': 123
}
}
Expand All @@ -132,6 +164,19 @@ def _terminate_service_job(job_id: str, reason: Optional[str] = "default termina
return client.terminate_service_job(jobId=job_id, reason=reason)


def _update_service_job(job_id: str, scheduling_priority: int) -> Dict:
    """Call the Batch ``update_service_job`` API for one service job.

    Args:
        job_id: Job ID or Job Arn, sent to the API as ``jobId``.
        scheduling_priority: An integer representing scheduling priority,
            sent to the API as ``schedulingPriority``.

    Returns:
        The API response, a dict containing jobArn, jobId and jobName.
    """
    # Assemble the request as keyword arguments first, then hand it to the
    # Batch client in a single call.
    request = {"jobId": job_id, "schedulingPriority": scheduling_priority}
    batch_client = get_batch_boto_client()
    return batch_client.update_service_job(**request)


def _list_service_job(
job_queue: str,
job_status: Optional[str] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def submit(
share_identifier: Optional[str] = None,
timeout: Optional[Dict] = None,
tags: Optional[Dict] = None,
quota_share_name: Optional[str] = None,
preemption_config: Optional[Dict] = None,
) -> TrainingQueuedJob:
"""Submit a queued job and return a QueuedJob object.

Expand All @@ -53,6 +55,8 @@ def submit(
share_identifier: Share identifier for Batch job.
timeout: Timeout configuration for Batch job.
tags: Tags apply to Batch job. These tags are for Batch job only.
quota_share_name: Quota Share name for the Batch job.
preemption_config: Preemption configuration.

Returns: a TrainingQueuedJob object with Batch job ARN and job name.

Expand Down Expand Up @@ -85,6 +89,8 @@ def submit(
timeout,
share_identifier,
tags,
quota_share_name,
preemption_config,
)
if "jobArn" not in resp or "jobName" not in resp:
raise MissingRequiredArgument(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
SourceCode,
TrainingImageConfig,
)
from .batch_api_helper import _terminate_service_job, _describe_service_job
from .batch_api_helper import _terminate_service_job, _describe_service_job, _update_service_job
from .exception import NoTrainingJob, MissingRequiredArgument
from ..utils import _get_training_job_name_from_training_job_arn
from .constants import JOB_STATUS_COMPLETED, JOB_STATUS_FAILED, POLL_IN_SECONDS
Expand Down Expand Up @@ -85,6 +85,17 @@ def terminate(self, reason: Optional[str] = "Default terminate reason") -> None:
"""
_terminate_service_job(self.job_arn, reason)

def update(self, scheduling_priority: int) -> Dict:
    """Update the scheduling priority of this Batch job.

    Args:
        scheduling_priority: An integer representing scheduling priority.

    Returns:
        A dict which includes jobArn, jobName and jobId.
    """
    # The job is addressed by its ARN; the helper issues the actual API call.
    response = _update_service_job(self.job_arn, scheduling_priority)
    return response

def describe(self) -> Dict:
"""Describe Batch job.

Expand Down
Loading
Loading