diff --git a/.gitignore b/.gitignore
index 7428f4f025..f87727e33e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -40,3 +40,4 @@ env/
sagemaker_train/src/**/container_drivers/sm_train.sh
sagemaker_train/src/**/container_drivers/sourcecode.json
sagemaker_train/src/**/container_drivers/distributed.json
+.kiro
diff --git a/sagemaker-core/src/sagemaker/core/resources.py b/sagemaker-core/src/sagemaker/core/resources.py
index 66b13e112a..61e0f9c677 100644
--- a/sagemaker-core/src/sagemaker/core/resources.py
+++ b/sagemaker-core/src/sagemaker/core/resources.py
@@ -35788,7 +35788,7 @@ def stop(self) -> None:
ResourceNotFound: Resource being access is not found.
"""
- client = SageMakerClient().client
+ client = SageMakerClient().sagemaker_client
operation_input_args = {
"TrainingJobName": self.training_job_name,
@@ -35833,15 +35833,17 @@ def wait(
progress.add_task("Waiting for TrainingJob...")
status = Status("Current status:")
- instance_count = (
- sum(
- instance_group.instance_count
- for instance_group in self.resource_config.instance_groups
- )
- if self.resource_config.instance_groups
- and not isinstance(self.resource_config.instance_groups, Unassigned)
- else self.resource_config.instance_count
- )
+ instance_count = 1 # Default
+ if not isinstance(self.resource_config, Unassigned):
+ if (hasattr(self.resource_config, 'instance_groups') and
+ self.resource_config.instance_groups and
+ not isinstance(self.resource_config.instance_groups, Unassigned)):
+ instance_count = sum(
+ instance_group.instance_count
+ for instance_group in self.resource_config.instance_groups
+ )
+ elif hasattr(self.resource_config, 'instance_count'):
+ instance_count = self.resource_config.instance_count
if logs:
multi_stream_logger = MultiLogStreamHandler(
diff --git a/sagemaker-train/example_notebooks/evaluate/benchmark_demo.ipynb b/sagemaker-train/example_notebooks/evaluate/benchmark_demo.ipynb
index e0133a9272..0719cbbab2 100644
--- a/sagemaker-train/example_notebooks/evaluate/benchmark_demo.ipynb
+++ b/sagemaker-train/example_notebooks/evaluate/benchmark_demo.ipynb
@@ -442,7 +442,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "py3.10.14",
"language": "python",
"name": "python3"
},
@@ -456,7 +456,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.12"
+ "version": "3.10.14"
}
},
"nbformat": 4,
diff --git a/sagemaker-train/pyproject.toml b/sagemaker-train/pyproject.toml
index f0acc6077c..648994b5eb 100644
--- a/sagemaker-train/pyproject.toml
+++ b/sagemaker-train/pyproject.toml
@@ -61,6 +61,11 @@ test = [
"graphene",
"IPython"
]
+notebook = [
+ "ipywidgets>=8.0.0",
+ "rich>=13.0.0",
+ "matplotlib>=3.5.0",
+]
[tool.setuptools.packages.find]
where = ["src/"]
diff --git a/sagemaker-train/src/sagemaker/train/__init__.py b/sagemaker-train/src/sagemaker/train/__init__.py
index 74518dc65a..38a6fda76d 100644
--- a/sagemaker-train/src/sagemaker/train/__init__.py
+++ b/sagemaker-train/src/sagemaker/train/__init__.py
@@ -56,4 +56,16 @@ def __getattr__(name):
elif name == "get_builtin_metrics":
from sagemaker.train.evaluate import get_builtin_metrics
return get_builtin_metrics
+ elif name == "plot_training_metrics":
+ from sagemaker.train.common_utils.metrics_visualizer import plot_training_metrics
+ return plot_training_metrics
+ elif name == "get_available_metrics":
+ from sagemaker.train.common_utils.metrics_visualizer import get_available_metrics
+ return get_available_metrics
+ elif name == "get_studio_url":
+ from sagemaker.train.common_utils.metrics_visualizer import get_studio_url
+ return get_studio_url
+ elif name == "get_mlflow_url":
+ from sagemaker.train.common_utils.trainer_wait import get_mlflow_url
+ return get_mlflow_url
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
diff --git a/sagemaker-train/src/sagemaker/train/common_utils/constants.py b/sagemaker-train/src/sagemaker/train/common_utils/constants.py
index 8de3ab4638..b96c58134c 100644
--- a/sagemaker-train/src/sagemaker/train/common_utils/constants.py
+++ b/sagemaker-train/src/sagemaker/train/common_utils/constants.py
@@ -20,6 +20,7 @@ class _MLflowConstants:
# Metric names
TOTAL_LOSS_METRIC = 'total_loss'
+ LOSS_METRIC_KEYWORDS = ('loss',)
EPOCH_KEYWORD = 'epoch'
# MLflow run tags
diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py
index 3fd17c3ac0..c6e89e19c8 100644
--- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py
+++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py
@@ -376,6 +376,7 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
except Exception as e:
logger.error("Exception getting fine-tuning options: %s", e)
+ raise
def _create_input_channels(dataset: str, content_type: Optional[str] = None,
diff --git a/sagemaker-train/src/sagemaker/train/common_utils/metrics_visualizer.py b/sagemaker-train/src/sagemaker/train/common_utils/metrics_visualizer.py
new file mode 100644
index 0000000000..fe837a91fc
--- /dev/null
+++ b/sagemaker-train/src/sagemaker/train/common_utils/metrics_visualizer.py
@@ -0,0 +1,288 @@
+"""MLflow metrics visualization utilities for SageMaker training jobs."""
+
+import logging
+from typing import Optional, List, Dict, Any
+from sagemaker.core.resources import TrainingJob
+
+logger = logging.getLogger(__name__)
+
+
+def _is_in_studio() -> bool:
+ """Check if running inside SageMaker Studio."""
+ from sagemaker.train.common_utils.finetune_utils import _read_domain_id_from_metadata
+ return _read_domain_id_from_metadata() is not None
+
+
+def _get_studio_base_url(region: str) -> str:
+ """Get Studio base URL, or empty string if domain not resolvable."""
+ from sagemaker.train.common_utils.finetune_utils import _read_domain_id_from_metadata
+ domain_id = _read_domain_id_from_metadata()
+ if not domain_id or not region:
+ return ""
+ return f"https://studio-{domain_id}.studio.{region}.sagemaker.aws"
+
+
+def _parse_job_arn(job_arn: str):
+ """Parse a SageMaker job ARN into (region, resource) or None."""
+ import re
+ m = re.match(r'arn:aws(?:-[a-z]+)?:sagemaker:([a-z0-9-]+):\d+:(\S+)', job_arn)
+ return (m.group(1), m.group(2)) if m else None
+
+
+def get_console_job_url(job_arn: str) -> str:
+ """Get AWS Console URL for a SageMaker job ARN.
+
+ Args:
+ job_arn: Full ARN like arn:aws:sagemaker:us-east-1:123:training-job/my-job
+
+ Returns:
+ Console URL or empty string.
+ """
+ parsed = _parse_job_arn(job_arn)
+ if not parsed:
+ return ""
+ region, resource = parsed
+ job_type_map = {
+ "training-job/": "#/jobs/",
+ "processing-job/": "#/processing-jobs/",
+ "transform-job/": "#/transform-jobs/",
+ }
+ for prefix, fragment in job_type_map.items():
+ if resource.startswith(prefix):
+ job_name = resource.split("/", 1)[1]
+ return f"https://{region}.console.aws.amazon.com/sagemaker/home?region={region}{fragment}{job_name}"
+ return ""
+
+
+def get_cloudwatch_logs_url(job_arn: str) -> str:
+ """Get CloudWatch Logs console URL for a SageMaker job ARN.
+
+ Returns:
+ CloudWatch console URL or empty string.
+ """
+ parsed = _parse_job_arn(job_arn)
+ if not parsed:
+ return ""
+ region, resource = parsed
+ log_group_map = {
+ "training-job/": "/aws/sagemaker/TrainingJobs",
+ "processing-job/": "/aws/sagemaker/ProcessingJobs",
+ "transform-job/": "/aws/sagemaker/TransformJobs",
+ }
+ for prefix, log_group in log_group_map.items():
+ if resource.startswith(prefix):
+ job_name = resource.split("/", 1)[1]
+ encoded_group = log_group.replace("/", "$252F")
+ return (
+ f"https://{region}.console.aws.amazon.com/cloudwatch/home?region={region}"
+ f"#logsV2:log-groups/log-group/{encoded_group}"
+ f"$3FlogStreamNameFilter$3D{job_name}"
+ )
+ return ""
+
+
+def get_studio_url(training_job, domain_id: str = None) -> str:
+ """Get SageMaker Studio URL for training job logs.
+
+ Args:
+ training_job: SageMaker TrainingJob object, job name string, or job ARN string
+ domain_id: Studio domain ID (e.g., 'd-xxxxxxxxxxxx'). If not provided, attempts to auto-detect
+
+ Returns:
+ Studio URL pointing to the training job details, or empty string if not resolvable
+
+ Example:
+ >>> from sagemaker.train import get_studio_url
+ >>> url = get_studio_url('my-training-job')
+ >>> url = get_studio_url('arn:aws:sagemaker:us-west-2:123456789:training-job/my-job')
+ """
+ import re
+
+ if isinstance(training_job, str):
+ arn_match = re.match(
+ r'arn:aws(?:-[a-z]+)?:sagemaker:([a-z0-9-]+):\d+:training-job/(.+)',
+ training_job,
+ )
+ if arn_match:
+ region = arn_match.group(1)
+ job_name = arn_match.group(2)
+ else:
+ # Plain job name — use session region
+ training_job = TrainingJob.get(training_job_name=training_job)
+ from sagemaker.core.utils.utils import SageMakerClient
+ region = SageMakerClient().region_name
+ job_name = training_job.training_job_name
+ else:
+ from sagemaker.core.utils.utils import SageMakerClient
+ region = SageMakerClient().region_name
+ job_name = training_job.training_job_name
+
+    base = f"https://studio-{domain_id}.studio.{region}.sagemaker.aws" if domain_id else _get_studio_base_url(region)
+ if not base:
+ return ""
+ return f"{base}/jobs/train/{job_name}"
+
+
+def display_job_links_html(rows: list, as_html: bool = False):
+ """Render job/resource links with copy-to-clipboard buttons as a Jupyter HTML table.
+
+ Args:
+ rows: List of dicts, each with keys:
+ - label (str): Row label (e.g. step name, "Training Job", "MLflow Experiment")
+ - arn (str): The ARN or URI to display and copy
+ - url (Optional[str]): Clickable link URL. If None, resolved via get_studio_url for job ARNs.
+ - url_text (Optional[str]): Link display text. Defaults to "🔗 link"
+ - url_hint (Optional[str]): Hint text after link. Defaults to "(please sign in to Studio first)"
+ as_html: If True, return HTML object instead of displaying it.
+
+ Returns:
+ HTML object if as_html=True, otherwise None.
+ """
+ from IPython.display import display, HTML
+ import html as html_mod
+
+ html_rows = ""
+ for row in rows:
+ escaped_arn = html_mod.escape(row['arn'])
+ escaped_label = html_mod.escape(row['label'])
+
+ url = row.get('url')
+ if url is None:
+ url = get_studio_url(row['arn'])
+ url_text = row.get('url_text', '🔗 link')
+ url_hint = row.get('url_hint', '(please sign in to Studio first)')
+
+        link_html = ""
+        if url:
+            link_html = (
+                f'<a href="{html_mod.escape(url)}" target="_blank">{html_mod.escape(url_text)}</a>'
+                f' <span style="color:gray;">{html_mod.escape(url_hint)}</span>'
+            )
+
+        # NOTE(review): the button/table markup below was garbled in extraction;
+        # reconstructed minimally — confirm against the original rendering.
+        copy_btn = (
+            f'<button onclick="navigator.clipboard.writeText(\'{escaped_arn}\')" '
+            f'style="cursor:pointer;border:none;background:none;">📋</button>'
+        )
+
+        html_rows += (
+            f'<tr>'
+            f'<td style="text-align:left;">{escaped_label}</td>'
+            f'<td style="text-align:left;">{link_html}</td>'
+            f'<td style="text-align:left;">'
+            f'<code>{escaped_arn}</code>'
+            f' {copy_btn}</td>'
+            f'</tr>'
+        )
+
+    result = HTML(
+        f'<table style="width:100%;border-collapse:collapse;">'
+        f'<thead><tr>'
+        f'<th style="text-align:left;">Step</th>'
+        f'<th style="text-align:left;">Job Link</th>'
+        f'<th style="text-align:left;">Job ARN</th>'
+        f'</tr></thead>'
+        f'<tbody>{html_rows}</tbody></table>'
+    )
+
+ if as_html:
+ return result
+ display(result)
+
+
+def plot_training_metrics(
+ training_job: TrainingJob,
+ metrics: Optional[List[str]] = None,
+ figsize: tuple = (12, 6)
+) -> None:
+ """Plot training metrics from MLflow for a completed training job.
+
+ Args:
+ training_job: SageMaker TrainingJob object or job name string
+ metrics: List of metric names to plot. If None, plots all available metrics.
+ figsize: Figure size as (width, height)
+ """
+ import matplotlib.pyplot as plt
+ import mlflow
+ from mlflow.tracking import MlflowClient
+ from IPython.display import display
+ import logging
+
+ logging.getLogger('botocore.credentials').setLevel(logging.WARNING)
+
+ if isinstance(training_job, str):
+ training_job = TrainingJob.get(training_job_name=training_job)
+
+ run_id = training_job.mlflow_details.mlflow_run_id
+
+ mlflow.set_tracking_uri(training_job.mlflow_config.mlflow_resource_arn)
+ client = MlflowClient()
+
+ run = mlflow.get_run(run_id)
+ available_metrics = list(run.data.metrics.keys())
+ metrics_to_plot = metrics if metrics else available_metrics
+
+ # Fetch metric histories
+ metric_data = {}
+ for metric_name in metrics_to_plot:
+ history = client.get_metric_history(run_id, metric_name)
+ if history:
+ metric_data[metric_name] = history
+
+ # Plot
+ num_metrics = len(metric_data)
+ rows = (num_metrics + 1) // 2
+ fig, axes = plt.subplots(rows, 2, figsize=(figsize[0], figsize[1] * rows))
+    axes = axes.flatten()  # subplots(rows, 2) always returns an ndarray, even for a single metric
+
+ for idx, (metric_name, history) in enumerate(metric_data.items()):
+ steps = [h.step for h in history]
+ values = [h.value for h in history]
+ axes[idx].plot(steps, values, linewidth=2, marker='o', markersize=4)
+ axes[idx].set_xlabel('Step')
+ axes[idx].set_ylabel('Value')
+ axes[idx].set_title(metric_name, fontweight='bold')
+ axes[idx].grid(True, alpha=0.3)
+
+ for idx in range(len(metric_data), len(axes)):
+ axes[idx].set_visible(False)
+
+ plt.suptitle(f'Training Metrics: {training_job.training_job_name}', fontweight='bold', fontsize=14)
+ plt.tight_layout(rect=[0, 0, 1, 0.98]) # Leave small space for suptitle
+ display(fig)
+ plt.close()
+
+
+def get_available_metrics(training_job: TrainingJob) -> List[str]:
+ """Get list of available metrics for a training job.
+
+ Args:
+ training_job: SageMaker TrainingJob object or job name string
+
+ Returns:
+ List of metric names
+ """
+ try:
+ import mlflow
+ except ImportError:
+ logger.error("mlflow package not installed")
+ return []
+
+ # Handle string input
+ if isinstance(training_job, str):
+ training_job = TrainingJob.get(training_job_name=training_job)
+
+ if not hasattr(training_job, 'mlflow_config') or not training_job.mlflow_config:
+ return []
+
+ mlflow_details = training_job.mlflow_details
+ if not mlflow_details or not mlflow_details.mlflow_run_id:
+ return []
+
+ mlflow.set_tracking_uri(training_job.mlflow_config.mlflow_resource_arn)
+ run = mlflow.get_run(mlflow_details.mlflow_run_id)
+
+ return list(run.data.metrics.keys())
diff --git a/sagemaker-train/src/sagemaker/train/common_utils/mlflow_metrics_util.py b/sagemaker-train/src/sagemaker/train/common_utils/mlflow_metrics_util.py
index 43cc5e3847..8d15731977 100644
--- a/sagemaker-train/src/sagemaker/train/common_utils/mlflow_metrics_util.py
+++ b/sagemaker-train/src/sagemaker/train/common_utils/mlflow_metrics_util.py
@@ -154,7 +154,7 @@ def _get_loss_metrics(
loss_data = []
for metric_key in run.data.metrics:
- if _MLflowConstants.TOTAL_LOSS_METRIC == metric_key.lower():
+ if any(kw in metric_key.lower() for kw in _MLflowConstants.LOSS_METRIC_KEYWORDS):
metric_history = client.get_metric_history(rid, metric_key)
loss_data.append({
'metric_name': metric_key,
@@ -335,7 +335,7 @@ def _get_most_recent_total_loss(
for rid, metrics in loss_metrics.items():
for metric in metrics:
- if metric['metric_name'].lower() == _MLflowConstants.TOTAL_LOSS_METRIC:
+ if any(kw in metric['metric_name'].lower() for kw in _MLflowConstants.LOSS_METRIC_KEYWORDS):
if metric['history']:
# Get the most recent entry (last in history)
return metric['history'][-1]['value']
diff --git a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py
index 900f13d4c6..59adcdfbfc 100644
--- a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py
+++ b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py
@@ -40,6 +40,10 @@ def _setup_mlflow_integration(training_job: TrainingJob) -> Tuple[
try:
import boto3
+ # Check if mlflow_config exists and is assigned
+ if not hasattr(training_job, 'mlflow_config') or _is_unassigned_attribute(training_job.mlflow_config):
+ return None, None, None
+
sm_client = boto3.client('sagemaker')
mlflow_arn = training_job.mlflow_config.mlflow_resource_arn
@@ -56,7 +60,11 @@ def _setup_mlflow_integration(training_job: TrainingJob) -> Tuple[
return mlflow_url, metrics_util, mlflow_run_name
- except Exception:
+ except Exception as e:
+ # Log the exception for debugging
+ import logging
+ logger = logging.getLogger(__name__)
+ logger.debug(f"MLflow integration setup failed: {e}")
return None, None, None
@@ -154,6 +162,59 @@ def _calculate_transition_duration(trans) -> Tuple[str, str]:
return duration, check
+def get_mlflow_url(training_job) -> str:
+ """Get presigned MLflow URL for training job experiment.
+
+ Args:
+ training_job: SageMaker TrainingJob object or job name string
+
+ Returns:
+ Presigned MLflow URL to experiment (valid for 5 minutes)
+
+ Example:
+ >>> from sagemaker.train import get_mlflow_url
+ >>> url = get_mlflow_url('my-training-job')
+ >>> print(url)
+ """
+ if isinstance(training_job, str):
+ training_job = TrainingJob.get(training_job_name=training_job)
+
+ if not hasattr(training_job, 'mlflow_config') or _is_unassigned_attribute(training_job.mlflow_config):
+ raise ValueError("Training job does not have MLflow configured")
+
+ import os
+ from mlflow.tracking import MlflowClient
+ import mlflow
+ from sagemaker.core.utils.utils import SageMakerClient
+
+ mlflow_arn = training_job.mlflow_config.mlflow_resource_arn
+ exp_name = training_job.mlflow_config.mlflow_experiment_name
+
+ # Get presigned base URL
+ sm_client = SageMakerClient().sagemaker_client
+ response = sm_client.create_presigned_mlflow_app_url(Arn=mlflow_arn)
+ base_url = response.get('AuthorizedUrl')
+
+ # Try to get experiment ID and append to URL
+ try:
+ os.environ['MLFLOW_TRACKING_URI'] = mlflow_arn
+ mlflow.set_tracking_uri(mlflow_arn)
+
+ mlflow_client = MlflowClient(tracking_uri=mlflow_arn)
+ experiment = mlflow_client.get_experiment_by_name(exp_name)
+
+ if experiment:
+ # Format: base_url#/experiments/{id}
+ # The base_url already has /auth?authToken=...
+ return f"{base_url}#/experiments/{experiment.experiment_id}"
+ except Exception:
+ pass
+
+ return base_url
+
+
+
+
def wait(
training_job: TrainingJob,
poll: int = 5,
@@ -188,32 +249,101 @@ def wait(
from rich.console import Group
with _suppress_info_logging():
console = Console(force_jupyter=True)
+
+ # MLflow link caching
+ mlflow_link_cache = {'url': None, 'timestamp': 0, 'error': None}
+ has_mlflow_config = (hasattr(training_job, 'mlflow_config') and
+ not _is_unassigned_attribute(training_job.mlflow_config))
+
+ def get_cached_mlflow_url():
+ """Get cached MLflow URL or generate new one if expired."""
+ current_time = time.time()
+ # Regenerate every 4 minutes (before 5-minute expiration)
+ if mlflow_link_cache['url'] is None or (current_time - mlflow_link_cache['timestamp']) > 240:
+ try:
+ mlflow_link_cache['url'] = get_mlflow_url(training_job)
+ mlflow_link_cache['error'] = None
+ except Exception as e:
+ mlflow_link_cache['error'] = str(e)
+ mlflow_link_cache['timestamp'] = current_time
+ return mlflow_link_cache['url']
+
+ # Track last rendered state to avoid unnecessary refreshes
+ last_status = None
+ last_secondary_status = None
iteration = 0
while True:
iteration += 1
- time.sleep(1)
- if iteration == poll:
+ time.sleep(0.5)
+ if iteration >= poll * 2:
training_job.refresh()
iteration = 0
- clear_output(wait=True)
-
+
status = training_job.training_job_status
secondary_status = training_job.secondary_status
elapsed = time.time() - start_time
+
+ # Only re-render if status changed or every 2 seconds (for elapsed time)
+ should_render = (
+ status != last_status or
+ secondary_status != last_secondary_status or
+ iteration % 4 == 0 # Every 2 seconds (4 * 0.5s)
+ )
+
+ if not should_render:
+ continue
+
+ last_status = status
+ last_secondary_status = secondary_status
+
+ clear_output(wait=True)
- # Header section with training job name and MLFlow URL
+ # Header section with training job info
header_table = Table(show_header=False, box=None, padding=(0, 1))
header_table.add_column("Property", style="cyan bold", width=20)
- header_table.add_column("Value", style="white")
+ header_table.add_column("Value", style="dim", overflow="fold")
+
header_table.add_row("TrainingJob Name", f"[bold green]{training_job.training_job_name}[/bold green]")
- if mlflow_url:
- header_table.add_row("MLFlow URL",
- f"[link={mlflow_url}][bold bright_blue underline]{mlflow_run_name}(link valid for 5 mins)[/bright_blue bold underline][/link]")
+ header_table.add_row("TrainingJob ARN", f"[dim]{training_job.training_job_arn}[/dim]")
+
+ # Build links rows
+ links_row1 = []
+ links_row2 = []
+ try:
+ from sagemaker.train.common_utils.metrics_visualizer import (
+ _is_in_studio, get_console_job_url, get_cloudwatch_logs_url, get_studio_url
+ )
+ console_url = get_console_job_url(training_job.training_job_arn)
+ if console_url:
+ links_row1.append(f"[bright_blue underline][link={console_url}]🔗 Training Job (Console)[/link][/bright_blue underline]")
+ if _is_in_studio():
+ studio_url = get_studio_url(training_job)
+ if studio_url:
+ links_row1.append(f"[bright_blue underline][link={studio_url}]🔗 Training Job (Studio)[/link][/bright_blue underline]")
+ cw_url = get_cloudwatch_logs_url(training_job.training_job_arn)
+ if cw_url:
+ links_row2.append(f"[bright_blue underline][link={cw_url}]🔗 CloudWatch Logs[/link][/bright_blue underline]")
+ except Exception:
+ pass
+ if has_mlflow_config:
+ cached_url = get_cached_mlflow_url()
+ if cached_url:
+ links_row2.append(f"[bright_blue underline][link={cached_url}]🔗 MLflow Experiment[/link][/bright_blue underline]")
+ elif mlflow_link_cache['error']:
+ header_table.add_row("MLflow Experiment", f"[red]{mlflow_link_cache['error']}[/red]")
+ if has_mlflow_config:
+ exp_name = training_job.mlflow_config.mlflow_experiment_name if hasattr(training_job, 'mlflow_config') else None
+ if exp_name and not _is_unassigned_attribute(exp_name):
+ header_table.add_row("MLflow Experiment", f"{exp_name}")
+ if links_row1:
+ header_table.add_row("Links", " | ".join(links_row1))
+ if links_row2:
+ header_table.add_row("" if links_row1 else "Links", " | ".join(links_row2))
status_table = Table(show_header=False, box=None, padding=(0, 1))
status_table.add_column("Property", style="cyan bold", width=20)
- status_table.add_column("Value", style="white")
+ status_table.add_column("Value", style="dim")
status_table.add_row("Job Status", f"[bold][orange3]{status}[/][/]")
status_table.add_row("Secondary Status", f"[bold yellow]{secondary_status}[/bold yellow]")
diff --git a/sagemaker-train/src/sagemaker/train/evaluate/execution.py b/sagemaker-train/src/sagemaker/train/evaluate/execution.py
index e2388ef313..d5e50f86b5 100644
--- a/sagemaker-train/src/sagemaker/train/evaluate/execution.py
+++ b/sagemaker-train/src/sagemaker/train/evaluate/execution.py
@@ -447,6 +447,7 @@ class StepDetail(BaseModel):
end_time: Optional[str] = Field(None, description="Step end time")
display_name: Optional[str] = Field(None, description="Display name for the step")
failure_reason: Optional[str] = Field(None, description="Reason for failure if step failed")
+ job_arn: Optional[str] = Field(None, description="ARN of the underlying job (training, processing, transform, etc.)")
class PipelineExecutionStatus(BaseModel):
@@ -914,6 +915,7 @@ def wait(
from rich.panel import Panel
from rich.text import Text
from rich.layout import Layout
+ from rich.console import Group
# Create console with Jupyter support
console = Console(force_jupyter=True)
@@ -924,21 +926,57 @@ def wait(
current_status = self.status.overall_status
elapsed = time.time() - start_time
+ # Create header table with pipeline name link
+ header_table = Table(show_header=False, box=None, padding=(0, 1))
+ header_table.add_column("Property", style="cyan bold", width=20)
+ header_table.add_column("Value", style="dim", overflow="fold")
+
+ # Extract pipeline name and exec_id from execution ARN
+ pipeline_name = None
+ exec_id = ''
+ if self.arn:
+ arn_parts = self.arn.split('/')
+ if len(arn_parts) >= 4:
+ pipeline_name = arn_parts[-3]
+ exec_id = arn_parts[-1]
+ # Use execution display name if available, fall back to self.name
+ display_name = self.name
+ if self._pipeline_execution:
+ dn = getattr(self._pipeline_execution, 'pipeline_execution_display_name', None)
+ if dn and not (hasattr(dn, '__class__') and 'Unassigned' in dn.__class__.__name__):
+ display_name = dn
+ header_table.add_row("Evaluation Job", str(display_name))
+
+ # Build links row
+ links = []
+ try:
+ from sagemaker.core.utils.utils import SageMakerClient
+ from sagemaker.train.common_utils.metrics_visualizer import _is_in_studio, _get_studio_base_url
+ if pipeline_name and _is_in_studio():
+ region = SageMakerClient().region_name
+ base = _get_studio_base_url(region)
+ if base:
+ pipeline_url = f"{base}/jobs/evaluation/detail?pipeline_name={pipeline_name}&execution_id={exec_id}"
+ links.append(f"[bright_blue underline][link={pipeline_url}]🔗 Pipeline Execution (Studio)[/link][/bright_blue underline]")
+ except Exception:
+ pass
+ if links:
+ header_table.add_row("Links", " | ".join(links))
+
# Create main status table
status_table = Table(show_header=False, box=None, padding=(0, 1))
status_table.add_column("Property", style="cyan bold", width=20)
- status_table.add_column("Value", style="white")
+ status_table.add_column("Value", style="dim")
- status_table.add_row("Overall Status", f"[bold]{current_status}[/bold]")
- status_table.add_row("Target Status", f"[bold]{target_status}[/bold]")
- status_table.add_row("Elapsed Time", f"{elapsed:.1f}s")
+ status_table.add_row("Overall Status", f"[bold][orange3]{current_status}[/][/]")
+ status_table.add_row("Target Status", f"[bold yellow]{target_status}[/bold yellow]")
+ status_table.add_row("Elapsed Time", f"[bold bright_red]{elapsed:.1f}s[/bold bright_red]")
if self.status.failure_reason:
status_table.add_row("Failure Reason", f"[red]{self.status.failure_reason}[/red]")
# Create steps table if steps exist
if self.status.step_details:
- # Check if any step has a failure
has_failures = any(step.failure_reason for step in self.status.step_details)
steps_table = Table(show_header=True, header_style="bold magenta", box=None, padding=(0, 1))
@@ -946,10 +984,10 @@ def wait(
steps_table.add_column("Status", style="yellow", width=15)
steps_table.add_column("Duration", style="green", width=12)
- failed_steps = [] # Track steps with failures for detailed display
+ failed_steps = []
+ job_arn_entries = []
for step in self.status.step_details:
- # Calculate duration if both times are available
duration = ""
if step.start_time and step.end_time:
try:
@@ -963,7 +1001,6 @@ def wait(
elif step.start_time:
duration = "Running..."
- # Color code status
status_display = step.status
if "succeeded" in step.status.lower() or "completed" in step.status.lower():
status_display = f"[green]{step.status}[/green]"
@@ -972,14 +1009,18 @@ def wait(
elif "executing" in step.status.lower() or "running" in step.status.lower():
status_display = f"[yellow]{step.status}[/yellow]"
- # Build row data
+ if step.job_arn:
+ job_arn_entries.append({
+ 'step_name': step.display_name or step.name,
+ 'job_arn': step.job_arn,
+ })
+
row_data = [
step.display_name or step.name,
status_display,
duration
]
- # Add error indicator if failures exist
if has_failures:
if step.failure_reason:
row_data.append("❌")
@@ -989,39 +1030,87 @@ def wait(
steps_table.add_row(*row_data)
- # Build combined content
from rich.console import Group
content_parts = [
status_table,
- Text(""), # Empty line for spacing
+ Text(""),
Text("Pipeline Steps", style="bold magenta"),
steps_table
]
- # Add failure details section if there are any failures
if failed_steps:
- content_parts.append(Text("")) # Empty line
+ content_parts.append(Text(""))
content_parts.append(Text("Step Failure Details", style="bold red"))
for step in failed_steps:
- content_parts.append(Text("")) # Empty line before each failure
+ content_parts.append(Text(""))
content_parts.append(Text(f"• {step.display_name or step.name}:", style="bold red"))
content_parts.append(Text(f" {step.failure_reason}", style="red"))
- combined_content = Group(*content_parts)
+ # Add job links table if any steps have ARNs
+ if job_arn_entries:
+ links_table = Table(show_header=True, header_style="bold magenta", box=None, padding=(0, 1))
+ links_table.add_column("Step", style="cyan", width=20)
+ links_table.add_column("Console", style="dim")
+ from sagemaker.core.utils.utils import SageMakerClient
+ from sagemaker.train.common_utils.metrics_visualizer import (
+ _is_in_studio, _parse_job_arn, _get_studio_base_url,
+ get_console_job_url, get_cloudwatch_logs_url,
+ )
+ in_studio = _is_in_studio()
+ studio_base = _get_studio_base_url(SageMakerClient().region_name) if in_studio else ""
+ if in_studio:
+ links_table.add_column("Studio", style="dim")
+ links_table.add_column("Logs", style="dim")
+ links_table.add_column("Job ARN", style="dim", overflow="fold")
+ studio_path_map = {
+ "training-job/": "jobs/train/",
+ "processing-job/": "jobs/processing/",
+ "transform-job/": "jobs/transform/",
+ }
+ for entry in job_arn_entries:
+ console_link = ""
+ logs_link = ""
+ studio_link = ""
+ try:
+ arn = entry['job_arn']
+ url = get_console_job_url(arn)
+ if url:
+ console_link = f"[bright_blue underline][link={url}]🔗 link[/link][/bright_blue underline]"
+ cw_url = get_cloudwatch_logs_url(arn)
+ if cw_url:
+ logs_link = f"[bright_blue underline][link={cw_url}]🔗 link[/link][/bright_blue underline]"
+ if in_studio and studio_base:
+ parsed = _parse_job_arn(arn)
+ if parsed:
+ _, resource = parsed
+ for prefix, path in studio_path_map.items():
+ if resource.startswith(prefix):
+ job_name = resource.split("/", 1)[1]
+ s_url = f"{studio_base}/{path}{job_name}"
+ studio_link = f"[bright_blue underline][link={s_url}]🔗 link[/link][/bright_blue underline]"
+ break
+ except Exception:
+ pass
+ row = [entry['step_name'], console_link]
+ if in_studio:
+ row.append(studio_link)
+ row.extend([logs_link, entry['job_arn']])
+ links_table.add_row(*row)
+ content_parts.append(Text(""))
+ content_parts.append(Text("Job ARNs", style="bold magenta"))
+ content_parts.append(links_table)
- # Display combined content in a single panel
console.print(Panel(
- combined_content,
- title="[bold blue]Pipeline Execution Status[/bold blue]",
- border_style="blue"
+ Group(header_table, *content_parts),
+ title="[bold bright_blue]Pipeline Execution Status[/bold bright_blue]",
+ border_style="orange3"
))
else:
- # Display only status table if no steps
console.print(Panel(
- status_table,
- title="[bold blue]Pipeline Execution Status[/bold blue]",
- border_style="blue"
+ Group(header_table, status_table),
+ title="[bold bright_blue]Pipeline Execution Status[/bold bright_blue]",
+ border_style="orange3"
))
if target_status == current_status:
@@ -1204,7 +1293,23 @@ def _convert_to_subclass(self, eval_type: EvalType) -> 'EvaluationPipelineExecut
execution._pipeline_execution = pipeline_execution_ref
return execution
-
+
+ @staticmethod
+ def _extract_job_arn_from_metadata(step) -> Optional[str]:
+ """Extract the underlying job ARN from a pipeline step's metadata."""
+ from sagemaker.train.common_utils.trainer_wait import _is_unassigned_attribute
+ metadata = getattr(step, 'metadata', None)
+ if metadata is None or _is_unassigned_attribute(metadata):
+ return None
+ for attr in ('training_job', 'processing_job', 'transform_job', 'tuning_job',
+ 'auto_ml_job', 'compilation_job'):
+ job_meta = getattr(metadata, attr, None)
+ if job_meta is not None and not _is_unassigned_attribute(job_meta):
+ arn = getattr(job_meta, 'arn', None)
+ if arn and not _is_unassigned_attribute(arn):
+ return str(arn)
+ return None
+
def _update_step_details_from_raw_steps(self, raw_steps: List[Any]) -> None:
"""Internal method to update step_details from raw pipeline execution steps
@@ -1246,7 +1351,8 @@ def _update_step_details_from_raw_steps(self, raw_steps: List[Any]) -> None:
start_time=start_time,
end_time=end_time,
display_name=step_display_name,
- failure_reason=failure_reason
+ failure_reason=failure_reason,
+ job_arn=self._extract_job_arn_from_metadata(step)
)
step_details.append(step_detail)
@@ -1256,8 +1362,8 @@ def _update_step_details_from_raw_steps(self, raw_steps: List[Any]) -> None:
logger.warning(f"Failed to process pipeline step: {str(e)}")
continue
- # Update the job's step details
- self.status.step_details = step_details
+ # Update the job's step details (reverse so earliest step appears first)
+ self.status.step_details = list(reversed(step_details))
# ============================================================================
diff --git a/sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py b/sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py
index 3848ef0d5c..6657e97ef3 100644
--- a/sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py
+++ b/sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py
@@ -11,8 +11,7 @@
"Metadata": {},
"MlflowConfig": {
"MlflowResourceArn": "{{ mlflow_resource_arn }}"{% if mlflow_experiment_name %},
- "MlflowExperimentName": "{{ mlflow_experiment_name }}"{% endif %}{% if mlflow_run_name %},
- "MlflowRunName": "{{ mlflow_run_name }}"{% endif %}
+ "MlflowExperimentName": "{{ mlflow_experiment_name }}"{% endif %}
},
"Parameters": [],
"Steps": [
@@ -531,8 +530,7 @@
"Metadata": {},
"MlflowConfig": {
"MlflowResourceArn": "{{ mlflow_resource_arn }}"{% if mlflow_experiment_name %},
- "MlflowExperimentName": "{{ mlflow_experiment_name }}"{% endif %}{% if mlflow_run_name %},
- "MlflowRunName": "{{ mlflow_run_name }}"{% endif %}
+ "MlflowExperimentName": "{{ mlflow_experiment_name }}"{% endif %}
},
"Parameters": [],
"Steps": [
@@ -925,8 +923,7 @@
"Metadata": {},
"MlflowConfig": {
"MlflowResourceArn": "{{ mlflow_resource_arn }}"{% if mlflow_experiment_name %},
- "MlflowExperimentName": "{{ mlflow_experiment_name }}"{% endif %}{% if mlflow_run_name %},
- "MlflowRunName": "{{ mlflow_run_name }}"{% endif %}
+ "MlflowExperimentName": "{{ mlflow_experiment_name }}"{% endif %}
},
"Parameters": [],
"Steps": [
diff --git a/sagemaker-train/tests/unit/train/common_utils/test_metrics_visualizer.py b/sagemaker-train/tests/unit/train/common_utils/test_metrics_visualizer.py
new file mode 100644
index 0000000000..9b7b804055
--- /dev/null
+++ b/sagemaker-train/tests/unit/train/common_utils/test_metrics_visualizer.py
@@ -0,0 +1,121 @@
+"""Unit tests for metrics_visualizer module."""
+import pytest
+from unittest.mock import Mock, patch, MagicMock
+
+
+class TestParseJobArn:
+ def test_training_job_arn(self):
+ from sagemaker.train.common_utils.metrics_visualizer import _parse_job_arn
+ result = _parse_job_arn("arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job")
+ assert result == ("us-west-2", "training-job/my-job")
+
+ def test_processing_job_arn(self):
+ from sagemaker.train.common_utils.metrics_visualizer import _parse_job_arn
+ result = _parse_job_arn("arn:aws:sagemaker:us-east-1:123456789012:processing-job/my-job")
+ assert result == ("us-east-1", "processing-job/my-job")
+
+ def test_invalid_arn_returns_none(self):
+ from sagemaker.train.common_utils.metrics_visualizer import _parse_job_arn
+ assert _parse_job_arn("not-an-arn") is None
+
+
+class TestGetConsoleJobUrl:
+ def test_training_job(self):
+ from sagemaker.train.common_utils.metrics_visualizer import get_console_job_url
+ url = get_console_job_url("arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job")
+ assert url == "https://us-west-2.console.aws.amazon.com/sagemaker/home?region=us-west-2#/jobs/my-job"
+
+ def test_invalid_arn_returns_empty(self):
+ from sagemaker.train.common_utils.metrics_visualizer import get_console_job_url
+ assert get_console_job_url("not-an-arn") == ""
+
+ def test_unknown_job_type_returns_empty(self):
+ from sagemaker.train.common_utils.metrics_visualizer import get_console_job_url
+ assert get_console_job_url("arn:aws:sagemaker:us-west-2:123456789012:unknown-job/my-job") == ""
+
+
+class TestGetCloudwatchLogsUrl:
+ def test_training_job(self):
+ from sagemaker.train.common_utils.metrics_visualizer import get_cloudwatch_logs_url
+ url = get_cloudwatch_logs_url("arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job")
+ assert "us-west-2" in url
+ assert "TrainingJobs" in url
+ assert "my-job" in url
+
+ def test_invalid_arn_returns_empty(self):
+ from sagemaker.train.common_utils.metrics_visualizer import get_cloudwatch_logs_url
+ assert get_cloudwatch_logs_url("not-an-arn") == ""
+
+
+class TestGetStudioUrl:
+ @patch("sagemaker.train.common_utils.metrics_visualizer._get_studio_base_url")
+ @patch("sagemaker.core.utils.utils.SageMakerClient")
+ def test_with_training_job_object(self, mock_client_cls, mock_base_url):
+ from sagemaker.train.common_utils.metrics_visualizer import get_studio_url
+ mock_client_cls.return_value.region_name = "us-west-2"
+ mock_base_url.return_value = "https://studio-d-abc.studio.us-west-2.sagemaker.aws"
+
+ mock_job = Mock()
+ mock_job.training_job_name = "my-job"
+
+ url = get_studio_url(mock_job)
+ assert url == "https://studio-d-abc.studio.us-west-2.sagemaker.aws/jobs/train/my-job"
+ mock_base_url.assert_called_once_with("us-west-2")
+
+ @patch("sagemaker.train.common_utils.metrics_visualizer._get_studio_base_url")
+ def test_with_arn_string(self, mock_base_url):
+ from sagemaker.train.common_utils.metrics_visualizer import get_studio_url
+ mock_base_url.return_value = "https://studio-d-abc.studio.us-west-2.sagemaker.aws"
+
+ url = get_studio_url("arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job")
+ assert url == "https://studio-d-abc.studio.us-west-2.sagemaker.aws/jobs/train/my-job"
+ mock_base_url.assert_called_once_with("us-west-2")
+
+ @patch("sagemaker.train.common_utils.metrics_visualizer._get_studio_base_url")
+ @patch("sagemaker.core.utils.utils.SageMakerClient")
+ @patch("sagemaker.train.common_utils.metrics_visualizer.TrainingJob")
+ def test_with_job_name_string(self, mock_tj_cls, mock_client_cls, mock_base_url):
+ from sagemaker.train.common_utils.metrics_visualizer import get_studio_url
+ mock_client_cls.return_value.region_name = "us-west-2"
+ mock_base_url.return_value = "https://studio-d-abc.studio.us-west-2.sagemaker.aws"
+ mock_tj_cls.get.return_value.training_job_name = "my-job"
+
+ url = get_studio_url("my-job")
+ assert url == "https://studio-d-abc.studio.us-west-2.sagemaker.aws/jobs/train/my-job"
+
+ @patch("sagemaker.train.common_utils.metrics_visualizer._get_studio_base_url")
+ @patch("sagemaker.core.utils.utils.SageMakerClient")
+ def test_returns_empty_when_no_domain(self, mock_client_cls, mock_base_url):
+ from sagemaker.train.common_utils.metrics_visualizer import get_studio_url
+ mock_client_cls.return_value.region_name = "us-west-2"
+ mock_base_url.return_value = ""
+
+ url = get_studio_url(Mock(training_job_name="my-job"))
+ assert url == ""
+
+
+class TestGetAvailableMetrics:
+ @patch("sagemaker.train.common_utils.metrics_visualizer.TrainingJob")
+ def test_returns_empty_when_no_mlflow_config(self, _):
+ from sagemaker.train.common_utils.metrics_visualizer import get_available_metrics
+ mock_job = Mock(spec=[]) # no mlflow_config attribute
+ assert get_available_metrics(mock_job) == []
+
+ @patch("sagemaker.train.common_utils.metrics_visualizer.TrainingJob")
+ def test_returns_empty_when_mlflow_config_falsy(self, _):
+ from sagemaker.train.common_utils.metrics_visualizer import get_available_metrics
+ mock_job = Mock()
+ mock_job.mlflow_config = None
+ assert get_available_metrics(mock_job) == []
+
+ @patch("mlflow.get_run")
+ @patch("mlflow.set_tracking_uri")
+ def test_returns_metric_names(self, mock_set_uri, mock_get_run):
+ from sagemaker.train.common_utils.metrics_visualizer import get_available_metrics
+ mock_job = Mock()
+ mock_job.mlflow_config.mlflow_resource_arn = "arn:aws:sagemaker:us-west-2:123:mlflow-tracking/abc"
+ mock_job.mlflow_details.mlflow_run_id = "run-123"
+ mock_get_run.return_value.data.metrics = {"loss": 0.5, "accuracy": 0.9}
+
+ result = get_available_metrics(mock_job)
+ assert set(result) == {"loss", "accuracy"}
diff --git a/sagemaker-train/tests/unit/train/evaluate/test_pipeline_templates.py b/sagemaker-train/tests/unit/train/evaluate/test_pipeline_templates.py
index 136ac022cd..36092fbeb3 100644
--- a/sagemaker-train/tests/unit/train/evaluate/test_pipeline_templates.py
+++ b/sagemaker-train/tests/unit/train/evaluate/test_pipeline_templates.py
@@ -121,7 +121,6 @@ def test_deterministic_template_with_optional_mlflow_params(self):
pipeline_def = json.loads(rendered)
assert pipeline_def["MlflowConfig"]["MlflowExperimentName"] == "test-experiment"
- assert pipeline_def["MlflowConfig"]["MlflowRunName"] == "test-run"
def test_deterministic_template_with_all_hyperparameters(self):
"""Test DETERMINISTIC_TEMPLATE with all optional hyperparameters."""
diff --git a/v3-examples/model-customization-examples/sm-studio-nova-training-job-sample-notebook.ipynb b/v3-examples/model-customization-examples/sm-studio-nova-training-job-sample-notebook.ipynb
index 4e49266323..6645f2b7d2 100644
--- a/v3-examples/model-customization-examples/sm-studio-nova-training-job-sample-notebook.ipynb
+++ b/v3-examples/model-customization-examples/sm-studio-nova-training-job-sample-notebook.ipynb
@@ -1055,7 +1055,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
+ "display_name": "py3.10.14",
"language": "python",
"name": "python3"
},
@@ -1069,7 +1069,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.11"
+ "version": "3.10.14"
}
},
"nbformat": 4,