diff --git a/.gitignore b/.gitignore index 7428f4f025..f87727e33e 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,4 @@ env/ sagemaker_train/src/**/container_drivers/sm_train.sh sagemaker_train/src/**/container_drivers/sourcecode.json sagemaker_train/src/**/container_drivers/distributed.json +.kiro diff --git a/sagemaker-core/src/sagemaker/core/resources.py b/sagemaker-core/src/sagemaker/core/resources.py index 66b13e112a..61e0f9c677 100644 --- a/sagemaker-core/src/sagemaker/core/resources.py +++ b/sagemaker-core/src/sagemaker/core/resources.py @@ -35788,7 +35788,7 @@ def stop(self) -> None: ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = SageMakerClient().sagemaker_client operation_input_args = { "TrainingJobName": self.training_job_name, @@ -35833,15 +35833,17 @@ def wait( progress.add_task("Waiting for TrainingJob...") status = Status("Current status:") - instance_count = ( - sum( - instance_group.instance_count - for instance_group in self.resource_config.instance_groups - ) - if self.resource_config.instance_groups - and not isinstance(self.resource_config.instance_groups, Unassigned) - else self.resource_config.instance_count - ) + instance_count = 1 # Default + if not isinstance(self.resource_config, Unassigned): + if (hasattr(self.resource_config, 'instance_groups') and + self.resource_config.instance_groups and + not isinstance(self.resource_config.instance_groups, Unassigned)): + instance_count = sum( + instance_group.instance_count + for instance_group in self.resource_config.instance_groups + ) + elif hasattr(self.resource_config, 'instance_count'): + instance_count = self.resource_config.instance_count if logs: multi_stream_logger = MultiLogStreamHandler( diff --git a/sagemaker-train/example_notebooks/evaluate/benchmark_demo.ipynb b/sagemaker-train/example_notebooks/evaluate/benchmark_demo.ipynb index e0133a9272..0719cbbab2 100644 --- a/sagemaker-train/example_notebooks/evaluate/benchmark_demo.ipynb +++ b/sagemaker-train/example_notebooks/evaluate/benchmark_demo.ipynb @@ -442,7 +442,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "py3.10.14", "language": "python", "name": "python3" }, @@ -456,7 +456,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.12" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/sagemaker-train/pyproject.toml b/sagemaker-train/pyproject.toml index f0acc6077c..648994b5eb 100644 --- a/sagemaker-train/pyproject.toml +++ b/sagemaker-train/pyproject.toml @@ -61,6 +61,11 @@ test = [ "graphene", "IPython" ] +notebook = [ + "ipywidgets>=8.0.0", + "rich>=13.0.0", + "matplotlib>=3.5.0", +] [tool.setuptools.packages.find] where = ["src/"] diff --git a/sagemaker-train/src/sagemaker/train/__init__.py b/sagemaker-train/src/sagemaker/train/__init__.py index 74518dc65a..38a6fda76d 100644 --- a/sagemaker-train/src/sagemaker/train/__init__.py +++ b/sagemaker-train/src/sagemaker/train/__init__.py @@ -56,4 +56,16 @@ def __getattr__(name): elif name == "get_builtin_metrics": from sagemaker.train.evaluate import get_builtin_metrics return get_builtin_metrics + elif name == "plot_training_metrics": + from sagemaker.train.common_utils.metrics_visualizer import plot_training_metrics + return plot_training_metrics + elif name == "get_available_metrics": + from sagemaker.train.common_utils.metrics_visualizer import get_available_metrics + return get_available_metrics + elif name == "get_studio_url": + from sagemaker.train.common_utils.metrics_visualizer import get_studio_url + return get_studio_url + elif name == "get_mlflow_url": + from sagemaker.train.common_utils.trainer_wait import get_mlflow_url + return get_mlflow_url raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/sagemaker-train/src/sagemaker/train/common_utils/constants.py b/sagemaker-train/src/sagemaker/train/common_utils/constants.py index 8de3ab4638..b96c58134c 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/constants.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/constants.py @@ -20,6 +20,7 @@ class _MLflowConstants: # Metric names TOTAL_LOSS_METRIC = 'total_loss' + LOSS_METRIC_KEYWORDS = ('loss',) EPOCH_KEYWORD = 'epoch' # MLflow run tags diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 3fd17c3ac0..c6e89e19c8 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -376,6 +376,7 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni except Exception as e: logger.error("Exception getting fine-tuning options: %s", e) + raise def _create_input_channels(dataset: str, content_type: Optional[str] = None, diff --git a/sagemaker-train/src/sagemaker/train/common_utils/metrics_visualizer.py b/sagemaker-train/src/sagemaker/train/common_utils/metrics_visualizer.py new file mode 100644 index 0000000000..fe837a91fc --- /dev/null +++ b/sagemaker-train/src/sagemaker/train/common_utils/metrics_visualizer.py @@ -0,0 +1,288 @@ +"""MLflow metrics visualization utilities for SageMaker training jobs.""" + +import logging +from typing import Optional, List, Dict, Any +from sagemaker.core.resources import TrainingJob + +logger = logging.getLogger(__name__) + + +def _is_in_studio() -> bool: + """Check if running inside SageMaker Studio.""" + from sagemaker.train.common_utils.finetune_utils import _read_domain_id_from_metadata + return _read_domain_id_from_metadata() is not None + + +def _get_studio_base_url(region: str) -> str: + """Get Studio base URL, or empty string if domain not resolvable.""" + from sagemaker.train.common_utils.finetune_utils import _read_domain_id_from_metadata + domain_id = _read_domain_id_from_metadata() + if not domain_id or not region: + return "" + return f"https://studio-{domain_id}.studio.{region}.sagemaker.aws" + + +def _parse_job_arn(job_arn: str): + """Parse a SageMaker job ARN into (region, resource) or None.""" + import re + m = re.match(r'arn:aws(?:-[a-z]+)?:sagemaker:([a-z0-9-]+):\d+:(\S+)', job_arn) + return (m.group(1), m.group(2)) if m else None + + +def get_console_job_url(job_arn: str) -> str: + """Get AWS Console URL for a SageMaker job ARN. + + Args: + job_arn: Full ARN like arn:aws:sagemaker:us-east-1:123:training-job/my-job + + Returns: + Console URL or empty string. + """ + parsed = _parse_job_arn(job_arn) + if not parsed: + return "" + region, resource = parsed + job_type_map = { + "training-job/": "#/jobs/", + "processing-job/": "#/processing-jobs/", + "transform-job/": "#/transform-jobs/", + } + for prefix, fragment in job_type_map.items(): + if resource.startswith(prefix): + job_name = resource.split("/", 1)[1] + return f"https://{region}.console.aws.amazon.com/sagemaker/home?region={region}{fragment}{job_name}" + return "" + + +def get_cloudwatch_logs_url(job_arn: str) -> str: + """Get CloudWatch Logs console URL for a SageMaker job ARN. + + Returns: + CloudWatch console URL or empty string. + """ + parsed = _parse_job_arn(job_arn) + if not parsed: + return "" + region, resource = parsed + log_group_map = { + "training-job/": "/aws/sagemaker/TrainingJobs", + "processing-job/": "/aws/sagemaker/ProcessingJobs", + "transform-job/": "/aws/sagemaker/TransformJobs", + } + for prefix, log_group in log_group_map.items(): + if resource.startswith(prefix): + job_name = resource.split("/", 1)[1] + encoded_group = log_group.replace("/", "$252F") + return ( + f"https://{region}.console.aws.amazon.com/cloudwatch/home?region={region}" + f"#logsV2:log-groups/log-group/{encoded_group}" + f"$3FlogStreamNameFilter$3D{job_name}" + ) + return "" + + +def get_studio_url(training_job, domain_id: str = None) -> str: + """Get SageMaker Studio URL for training job logs. + + Args: + training_job: SageMaker TrainingJob object, job name string, or job ARN string + domain_id: Studio domain ID (e.g., 'd-xxxxxxxxxxxx'). If not provided, attempts to auto-detect + + Returns: + Studio URL pointing to the training job details, or empty string if not resolvable + + Example: + >>> from sagemaker.train import get_studio_url + >>> url = get_studio_url('my-training-job') + >>> url = get_studio_url('arn:aws:sagemaker:us-west-2:123456789:training-job/my-job') + """ + import re + + if isinstance(training_job, str): + arn_match = re.match( + r'arn:aws(?:-[a-z]+)?:sagemaker:([a-z0-9-]+):\d+:training-job/(.+)', + training_job, + ) + if arn_match: + region = arn_match.group(1) + job_name = arn_match.group(2) + else: + # Plain job name — use session region + training_job = TrainingJob.get(training_job_name=training_job) + from sagemaker.core.utils.utils import SageMakerClient + region = SageMakerClient().region_name + job_name = training_job.training_job_name + else: + from sagemaker.core.utils.utils import SageMakerClient + region = SageMakerClient().region_name + job_name = training_job.training_job_name + + base = _get_studio_base_url(region) + if not base: + return "" + return f"{base}/jobs/train/{job_name}" + + +def display_job_links_html(rows: list, as_html: bool = False): + """Render job/resource links with copy-to-clipboard buttons as a Jupyter HTML table. + + Args: + rows: List of dicts, each with keys: + - label (str): Row label (e.g. step name, "Training Job", "MLflow Experiment") + - arn (str): The ARN or URI to display and copy + - url (Optional[str]): Clickable link URL. If None, resolved via get_studio_url for job ARNs. + - url_text (Optional[str]): Link display text. Defaults to "🔗 link" + - url_hint (Optional[str]): Hint text after link. Defaults to "(please sign in to Studio first)" + as_html: If True, return HTML object instead of displaying it. + + Returns: + HTML object if as_html=True, otherwise None. + """ + from IPython.display import display, HTML + import html as html_mod + + html_rows = "" + for row in rows: + escaped_arn = html_mod.escape(row['arn']) + escaped_label = html_mod.escape(row['label']) + + url = row.get('url') + if url is None: + url = get_studio_url(row['arn']) + url_text = row.get('url_text', '🔗 link') + url_hint = row.get('url_hint', '(please sign in to Studio first)') + + link_html = "" + if url: + link_html = ( + f'{html_mod.escape(url_text)}' + f' {html_mod.escape(url_hint)}' + ) + + copy_btn = ( + f'' + ) + + html_rows += ( + f'' + f'{escaped_label}' + f'{link_html}' + f'' + f'{escaped_arn}' + f' {copy_btn}' + f'' + ) + + result = HTML( + f'' + f'' + f'' + f'' + f'' + f'{html_rows}
StepJob LinkJob ARN
' + ) + + if as_html: + return result + display(result) + + +def plot_training_metrics( + training_job: TrainingJob, + metrics: Optional[List[str]] = None, + figsize: tuple = (12, 6) +) -> None: + """Plot training metrics from MLflow for a completed training job. + + Args: + training_job: SageMaker TrainingJob object or job name string + metrics: List of metric names to plot. If None, plots all available metrics. + figsize: Figure size as (width, height) + """ + import matplotlib.pyplot as plt + import mlflow + from mlflow.tracking import MlflowClient + from IPython.display import display + import logging + + logging.getLogger('botocore.credentials').setLevel(logging.WARNING) + + if isinstance(training_job, str): + training_job = TrainingJob.get(training_job_name=training_job) + + run_id = training_job.mlflow_details.mlflow_run_id + + mlflow.set_tracking_uri(training_job.mlflow_config.mlflow_resource_arn) + client = MlflowClient() + + run = mlflow.get_run(run_id) + available_metrics = list(run.data.metrics.keys()) + metrics_to_plot = metrics if metrics else available_metrics + + # Fetch metric histories + metric_data = {} + for metric_name in metrics_to_plot: + history = client.get_metric_history(run_id, metric_name) + if history: + metric_data[metric_name] = history + + # Plot + num_metrics = len(metric_data) + rows = (num_metrics + 1) // 2 + fig, axes = plt.subplots(rows, 2, figsize=(figsize[0], figsize[1] * rows)) + axes = axes.flatten() if num_metrics > 1 else [axes] + + for idx, (metric_name, history) in enumerate(metric_data.items()): + steps = [h.step for h in history] + values = [h.value for h in history] + axes[idx].plot(steps, values, linewidth=2, marker='o', markersize=4) + axes[idx].set_xlabel('Step') + axes[idx].set_ylabel('Value') + axes[idx].set_title(metric_name, fontweight='bold') + axes[idx].grid(True, alpha=0.3) + + for idx in range(len(metric_data), len(axes)): + axes[idx].set_visible(False) + + plt.suptitle(f'Training Metrics: {training_job.training_job_name}', fontweight='bold', fontsize=14) + plt.tight_layout(rect=[0, 0, 1, 0.98]) # Leave small space for suptitle + display(fig) + plt.close() + + +def get_available_metrics(training_job: TrainingJob) -> List[str]: + """Get list of available metrics for a training job. + + Args: + training_job: SageMaker TrainingJob object or job name string + + Returns: + List of metric names + """ + try: + import mlflow + except ImportError: + logger.error("mlflow package not installed") + return [] + + # Handle string input + if isinstance(training_job, str): + training_job = TrainingJob.get(training_job_name=training_job) + + if not hasattr(training_job, 'mlflow_config') or not training_job.mlflow_config: + return [] + + mlflow_details = training_job.mlflow_details + if not mlflow_details or not mlflow_details.mlflow_run_id: + return [] + + mlflow.set_tracking_uri(training_job.mlflow_config.mlflow_resource_arn) + run = mlflow.get_run(mlflow_details.mlflow_run_id) + + return list(run.data.metrics.keys()) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/mlflow_metrics_util.py b/sagemaker-train/src/sagemaker/train/common_utils/mlflow_metrics_util.py index 43cc5e3847..8d15731977 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/mlflow_metrics_util.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/mlflow_metrics_util.py @@ -154,7 +154,7 @@ def _get_loss_metrics( loss_data = [] for metric_key in run.data.metrics: - if _MLflowConstants.TOTAL_LOSS_METRIC == metric_key.lower(): + if any(kw in metric_key.lower() for kw in _MLflowConstants.LOSS_METRIC_KEYWORDS): metric_history = client.get_metric_history(rid, metric_key) loss_data.append({ 'metric_name': metric_key, @@ -335,7 +335,7 @@ def _get_most_recent_total_loss( for rid, metrics in loss_metrics.items(): for metric in metrics: - if metric['metric_name'].lower() == _MLflowConstants.TOTAL_LOSS_METRIC: + if any(kw in metric['metric_name'].lower() for kw in _MLflowConstants.LOSS_METRIC_KEYWORDS): if metric['history']: # Get the most recent entry (last in history) return metric['history'][-1]['value'] diff --git a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py index 900f13d4c6..59adcdfbfc 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py @@ -40,6 +40,10 @@ def _setup_mlflow_integration(training_job: TrainingJob) -> Tuple[ try: import boto3 + # Check if mlflow_config exists and is assigned + if not hasattr(training_job, 'mlflow_config') or _is_unassigned_attribute(training_job.mlflow_config): + return None, None, None + sm_client = boto3.client('sagemaker') mlflow_arn = training_job.mlflow_config.mlflow_resource_arn @@ -56,7 +60,11 @@ def _setup_mlflow_integration(training_job: TrainingJob) -> Tuple[ return mlflow_url, metrics_util, mlflow_run_name - except Exception: + except Exception as e: + # Log the exception for debugging + import logging + logger = logging.getLogger(__name__) + logger.debug(f"MLflow integration setup failed: {e}") return None, None, None @@ -154,6 +162,59 @@ def _calculate_transition_duration(trans) -> Tuple[str, str]: return duration, check +def get_mlflow_url(training_job) -> str: + """Get presigned MLflow URL for training job experiment. + + Args: + training_job: SageMaker TrainingJob object or job name string + + Returns: + Presigned MLflow URL to experiment (valid for 5 minutes) + + Example: + >>> from sagemaker.train import get_mlflow_url + >>> url = get_mlflow_url('my-training-job') + >>> print(url) + """ + if isinstance(training_job, str): + training_job = TrainingJob.get(training_job_name=training_job) + + if not hasattr(training_job, 'mlflow_config') or _is_unassigned_attribute(training_job.mlflow_config): + raise ValueError("Training job does not have MLflow configured") + + import os + from mlflow.tracking import MlflowClient + import mlflow + from sagemaker.core.utils.utils import SageMakerClient + + mlflow_arn = training_job.mlflow_config.mlflow_resource_arn + exp_name = training_job.mlflow_config.mlflow_experiment_name + + # Get presigned base URL + sm_client = SageMakerClient().sagemaker_client + response = sm_client.create_presigned_mlflow_app_url(Arn=mlflow_arn) + base_url = response.get('AuthorizedUrl') + + # Try to get experiment ID and append to URL + try: + os.environ['MLFLOW_TRACKING_URI'] = mlflow_arn + mlflow.set_tracking_uri(mlflow_arn) + + mlflow_client = MlflowClient(tracking_uri=mlflow_arn) + experiment = mlflow_client.get_experiment_by_name(exp_name) + + if experiment: + # Format: base_url#/experiments/{id} + # The base_url already has /auth?authToken=... + return f"{base_url}#/experiments/{experiment.experiment_id}" + except Exception: + pass + + return base_url + + + + def wait( training_job: TrainingJob, poll: int = 5, @@ -188,32 +249,101 @@ def wait( from rich.console import Group with _suppress_info_logging(): console = Console(force_jupyter=True) + + # MLflow link caching + mlflow_link_cache = {'url': None, 'timestamp': 0, 'error': None} + has_mlflow_config = (hasattr(training_job, 'mlflow_config') and + not _is_unassigned_attribute(training_job.mlflow_config)) + + def get_cached_mlflow_url(): + """Get cached MLflow URL or generate new one if expired.""" + current_time = time.time() + # Regenerate every 4 minutes (before 5-minute expiration) + if mlflow_link_cache['url'] is None or (current_time - mlflow_link_cache['timestamp']) > 240: + try: + mlflow_link_cache['url'] = get_mlflow_url(training_job) + mlflow_link_cache['error'] = None + except Exception as e: + mlflow_link_cache['error'] = str(e) + mlflow_link_cache['timestamp'] = current_time + return mlflow_link_cache['url'] + + # Track last rendered state to avoid unnecessary refreshes + last_status = None + last_secondary_status = None iteration = 0 while True: iteration += 1 - time.sleep(1) - if iteration == poll: + time.sleep(0.5) + if iteration >= poll * 2: training_job.refresh() iteration = 0 - clear_output(wait=True) - + status = training_job.training_job_status secondary_status = training_job.secondary_status elapsed = time.time() - start_time + + # Only re-render if status changed or every 2 seconds (for elapsed time) + should_render = ( + status != last_status or + secondary_status != last_secondary_status or + iteration % 4 == 0 # Every 2 seconds (4 * 0.5s) + ) + + if not should_render: + continue + + last_status = status + last_secondary_status = secondary_status + + clear_output(wait=True) - # Header section with training job name and MLFlow URL + # Header section with training job info header_table = Table(show_header=False, box=None, padding=(0, 1)) header_table.add_column("Property", style="cyan bold", width=20) - header_table.add_column("Value", style="white") + header_table.add_column("Value", style="dim", overflow="fold") + header_table.add_row("TrainingJob Name", f"[bold green]{training_job.training_job_name}[/bold green]") - if mlflow_url: - header_table.add_row("MLFlow URL", - f"[link={mlflow_url}][bold bright_blue underline]{mlflow_run_name}(link valid for 5 mins)[/bright_blue bold underline][/link]") + header_table.add_row("TrainingJob ARN", f"[dim]{training_job.training_job_arn}[/dim]") + + # Build links rows + links_row1 = [] + links_row2 = [] + try: + from sagemaker.train.common_utils.metrics_visualizer import ( + _is_in_studio, get_console_job_url, get_cloudwatch_logs_url, get_studio_url + ) + console_url = get_console_job_url(training_job.training_job_arn) + if console_url: + links_row1.append(f"[bright_blue underline][link={console_url}]🔗 Training Job (Console)[/link][/bright_blue underline]") + if _is_in_studio(): + studio_url = get_studio_url(training_job) + if studio_url: + links_row1.append(f"[bright_blue underline][link={studio_url}]🔗 Training Job (Studio)[/link][/bright_blue underline]") + cw_url = get_cloudwatch_logs_url(training_job.training_job_arn) + if cw_url: + links_row2.append(f"[bright_blue underline][link={cw_url}]🔗 CloudWatch Logs[/link][/bright_blue underline]") + except Exception: + pass + if has_mlflow_config: + cached_url = get_cached_mlflow_url() + if cached_url: + links_row2.append(f"[bright_blue underline][link={cached_url}]🔗 MLflow Experiment[/link][/bright_blue underline]") + elif mlflow_link_cache['error']: + header_table.add_row("MLflow Experiment", f"[red]{mlflow_link_cache['error']}[/red]") + if has_mlflow_config: + exp_name = training_job.mlflow_config.mlflow_experiment_name if hasattr(training_job, 'mlflow_config') else None + if exp_name and not _is_unassigned_attribute(exp_name): + header_table.add_row("MLflow Experiment", f"{exp_name}") + if links_row1: + header_table.add_row("Links", " | ".join(links_row1)) + if links_row2: + header_table.add_row("" if links_row1 else "Links", " | ".join(links_row2)) status_table = Table(show_header=False, box=None, padding=(0, 1)) status_table.add_column("Property", style="cyan bold", width=20) - status_table.add_column("Value", style="white") + status_table.add_column("Value", style="dim") status_table.add_row("Job Status", f"[bold][orange3]{status}[/][/]") status_table.add_row("Secondary Status", f"[bold yellow]{secondary_status}[/bold yellow]") diff --git a/sagemaker-train/src/sagemaker/train/evaluate/execution.py b/sagemaker-train/src/sagemaker/train/evaluate/execution.py index e2388ef313..d5e50f86b5 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/execution.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/execution.py @@ -447,6 +447,7 @@ class StepDetail(BaseModel): end_time: Optional[str] = Field(None, description="Step end time") display_name: Optional[str] = Field(None, description="Display name for the step") failure_reason: Optional[str] = Field(None, description="Reason for failure if step failed") + job_arn: Optional[str] = Field(None, description="ARN of the underlying job (training, processing, transform, etc.)") class PipelineExecutionStatus(BaseModel): @@ -914,6 +915,7 @@ def wait( from rich.panel import Panel from rich.text import Text from rich.layout import Layout + from rich.console import Group # Create console with Jupyter support console = Console(force_jupyter=True) @@ -924,21 +926,57 @@ def wait( current_status = self.status.overall_status elapsed = time.time() - start_time + # Create header table with pipeline name link + header_table = Table(show_header=False, box=None, padding=(0, 1)) + header_table.add_column("Property", style="cyan bold", width=20) + header_table.add_column("Value", style="dim", overflow="fold") + + # Extract pipeline name and exec_id from execution ARN + pipeline_name = None + exec_id = '' + if self.arn: + arn_parts = self.arn.split('/') + if len(arn_parts) >= 4: + pipeline_name = arn_parts[-3] + exec_id = arn_parts[-1] + # Use execution display name if available, fall back to self.name + display_name = self.name + if self._pipeline_execution: + dn = getattr(self._pipeline_execution, 'pipeline_execution_display_name', None) + if dn and not (hasattr(dn, '__class__') and 'Unassigned' in dn.__class__.__name__): + display_name = dn + header_table.add_row("Evaluation Job", str(display_name)) + + # Build links row + links = [] + try: + from sagemaker.core.utils.utils import SageMakerClient + from sagemaker.train.common_utils.metrics_visualizer import _is_in_studio, _get_studio_base_url + if pipeline_name and _is_in_studio(): + region = SageMakerClient().region_name + base = _get_studio_base_url(region) + if base: + pipeline_url = f"{base}/jobs/evaluation/detail?pipeline_name={pipeline_name}&execution_id={exec_id}" + links.append(f"[bright_blue underline][link={pipeline_url}]🔗 Pipeline Execution (Studio)[/link][/bright_blue underline]") + except Exception: + pass + if links: + header_table.add_row("Links", " | ".join(links)) + # Create main status table status_table = Table(show_header=False, box=None, padding=(0, 1)) status_table.add_column("Property", style="cyan bold", width=20) - status_table.add_column("Value", style="white") + status_table.add_column("Value", style="dim") - status_table.add_row("Overall Status", f"[bold]{current_status}[/bold]") - status_table.add_row("Target Status", f"[bold]{target_status}[/bold]") - status_table.add_row("Elapsed Time", f"{elapsed:.1f}s") + status_table.add_row("Overall Status", f"[bold][orange3]{current_status}[/][/]") + status_table.add_row("Target Status", f"[bold yellow]{target_status}[/bold yellow]") + status_table.add_row("Elapsed Time", f"[bold bright_red]{elapsed:.1f}s[/bold bright_red]") if self.status.failure_reason: status_table.add_row("Failure Reason", f"[red]{self.status.failure_reason}[/red]") # Create steps table if steps exist if self.status.step_details: - # Check if any step has a failure has_failures = any(step.failure_reason for step in self.status.step_details) steps_table = Table(show_header=True, header_style="bold magenta", box=None, padding=(0, 1)) @@ -946,10 +984,10 @@ def wait( steps_table.add_column("Status", style="yellow", width=15) steps_table.add_column("Duration", style="green", width=12) - failed_steps = [] # Track steps with failures for detailed display + failed_steps = [] + job_arn_entries = [] for step in self.status.step_details: - # Calculate duration if both times are available duration = "" if step.start_time and step.end_time: try: @@ -963,7 +1001,6 @@ def wait( elif step.start_time: duration = "Running..." - # Color code status status_display = step.status if "succeeded" in step.status.lower() or "completed" in step.status.lower(): status_display = f"[green]{step.status}[/green]" @@ -972,14 +1009,18 @@ def wait( elif "executing" in step.status.lower() or "running" in step.status.lower(): status_display = f"[yellow]{step.status}[/yellow]" - # Build row data + if step.job_arn: + job_arn_entries.append({ + 'step_name': step.display_name or step.name, + 'job_arn': step.job_arn, + }) + row_data = [ step.display_name or step.name, status_display, duration ] - # Add error indicator if failures exist if has_failures: if step.failure_reason: row_data.append("❌") @@ -989,39 +1030,87 @@ def wait( steps_table.add_row(*row_data) - # Build combined content from rich.console import Group content_parts = [ status_table, - Text(""), # Empty line for spacing + Text(""), Text("Pipeline Steps", style="bold magenta"), steps_table ] - # Add failure details section if there are any failures if failed_steps: - content_parts.append(Text("")) # Empty line + content_parts.append(Text("")) content_parts.append(Text("Step Failure Details", style="bold red")) for step in failed_steps: - content_parts.append(Text("")) # Empty line before each failure + content_parts.append(Text("")) content_parts.append(Text(f"• {step.display_name or step.name}:", style="bold red")) content_parts.append(Text(f" {step.failure_reason}", style="red")) - combined_content = Group(*content_parts) + # Add job links table if any steps have ARNs + if job_arn_entries: + links_table = Table(show_header=True, header_style="bold magenta", box=None, padding=(0, 1)) + links_table.add_column("Step", style="cyan", width=20) + links_table.add_column("Console", style="dim") + from sagemaker.core.utils.utils import SageMakerClient + from sagemaker.train.common_utils.metrics_visualizer import ( + _is_in_studio, _parse_job_arn, _get_studio_base_url, + get_console_job_url, get_cloudwatch_logs_url, + ) + in_studio = _is_in_studio() + studio_base = _get_studio_base_url(SageMakerClient().region_name) if in_studio else "" + if in_studio: + links_table.add_column("Studio", style="dim") + links_table.add_column("Logs", style="dim") + links_table.add_column("Job ARN", style="dim", overflow="fold") + studio_path_map = { + "training-job/": "jobs/train/", + "processing-job/": "jobs/processing/", + "transform-job/": "jobs/transform/", + } + for entry in job_arn_entries: + console_link = "" + logs_link = "" + studio_link = "" + try: + arn = entry['job_arn'] + url = get_console_job_url(arn) + if url: + console_link = f"[bright_blue underline][link={url}]🔗 link[/link][/bright_blue underline]" + cw_url = get_cloudwatch_logs_url(arn) + if cw_url: + logs_link = f"[bright_blue underline][link={cw_url}]🔗 link[/link][/bright_blue underline]" + if in_studio and studio_base: + parsed = _parse_job_arn(arn) + if parsed: + _, resource = parsed + for prefix, path in studio_path_map.items(): + if resource.startswith(prefix): + job_name = resource.split("/", 1)[1] + s_url = f"{studio_base}/{path}{job_name}" + studio_link = f"[bright_blue underline][link={s_url}]🔗 link[/link][/bright_blue underline]" + break + except Exception: + pass + row = [entry['step_name'], console_link] + if in_studio: + row.append(studio_link) + row.extend([logs_link, entry['job_arn']]) + links_table.add_row(*row) + content_parts.append(Text("")) + content_parts.append(Text("Job ARNs", style="bold magenta")) + content_parts.append(links_table) - # Display combined content in a single panel console.print(Panel( - combined_content, - title="[bold blue]Pipeline Execution Status[/bold blue]", - border_style="blue" + Group(header_table, *content_parts), + title="[bold bright_blue]Pipeline Execution Status[/bold bright_blue]", + border_style="orange3" )) else: - # Display only status table if no steps console.print(Panel( - status_table, - title="[bold blue]Pipeline Execution Status[/bold blue]", - border_style="blue" + Group(header_table, status_table), + title="[bold bright_blue]Pipeline Execution Status[/bold bright_blue]", + border_style="orange3" )) if target_status == current_status: @@ -1204,7 +1293,23 @@ def _convert_to_subclass(self, eval_type: EvalType) -> 'EvaluationPipelineExecut execution._pipeline_execution = pipeline_execution_ref return execution - + + @staticmethod + def _extract_job_arn_from_metadata(step) -> Optional[str]: + """Extract the underlying job ARN from a pipeline step's metadata.""" + from sagemaker.train.common_utils.trainer_wait import _is_unassigned_attribute + metadata = getattr(step, 'metadata', None) + if metadata is None or _is_unassigned_attribute(metadata): + return None + for attr in ('training_job', 'processing_job', 'transform_job', 'tuning_job', + 'auto_ml_job', 'compilation_job'): + job_meta = getattr(metadata, attr, None) + if job_meta is not None and not _is_unassigned_attribute(job_meta): + arn = getattr(job_meta, 'arn', None) + if arn and not _is_unassigned_attribute(arn): + return str(arn) + return None + def _update_step_details_from_raw_steps(self, raw_steps: List[Any]) -> None: """Internal method to update step_details from raw pipeline execution steps @@ -1246,7 +1351,8 @@ def _update_step_details_from_raw_steps(self, raw_steps: List[Any]) -> None: start_time=start_time, end_time=end_time, display_name=step_display_name, - failure_reason=failure_reason + failure_reason=failure_reason, + job_arn=self._extract_job_arn_from_metadata(step) ) step_details.append(step_detail) @@ -1256,8 +1362,8 @@ def _update_step_details_from_raw_steps(self, raw_steps: List[Any]) -> None: logger.warning(f"Failed to process pipeline step: {str(e)}") continue - # Update the job's step details - self.status.step_details = step_details + # Update the job's step details (reverse so earliest step appears first) + self.status.step_details = list(reversed(step_details)) # ============================================================================ diff --git a/sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py b/sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py index 3848ef0d5c..6657e97ef3 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py @@ -11,8 +11,7 @@ "Metadata": {}, "MlflowConfig": { "MlflowResourceArn": "{{ mlflow_resource_arn }}"{% if mlflow_experiment_name %}, - "MlflowExperimentName": "{{ mlflow_experiment_name }}"{% endif %}{% if mlflow_run_name %}, - "MlflowRunName": "{{ mlflow_run_name }}"{% endif %} + "MlflowExperimentName": "{{ mlflow_experiment_name }}"{% endif %} }, "Parameters": [], "Steps": [ @@ -531,8 +530,7 @@ "Metadata": {}, "MlflowConfig": { "MlflowResourceArn": "{{ mlflow_resource_arn }}"{% if mlflow_experiment_name %}, - "MlflowExperimentName": "{{ mlflow_experiment_name }}"{% endif %}{% if mlflow_run_name %}, - "MlflowRunName": "{{ mlflow_run_name }}"{% endif %} + "MlflowExperimentName": "{{ mlflow_experiment_name }}"{% endif %} }, "Parameters": [], "Steps": [ @@ -925,8 +923,7 @@ "Metadata": {}, "MlflowConfig": { "MlflowResourceArn": "{{ mlflow_resource_arn }}"{% if mlflow_experiment_name %}, - "MlflowExperimentName": "{{ mlflow_experiment_name }}"{% endif %}{% if mlflow_run_name %}, - "MlflowRunName": "{{ mlflow_run_name }}"{% endif %} + "MlflowExperimentName": "{{ mlflow_experiment_name }}"{% endif %} }, "Parameters": [], "Steps": [ diff --git a/sagemaker-train/tests/unit/train/common_utils/test_metrics_visualizer.py b/sagemaker-train/tests/unit/train/common_utils/test_metrics_visualizer.py new file mode 100644 index 0000000000..9b7b804055 --- /dev/null +++ b/sagemaker-train/tests/unit/train/common_utils/test_metrics_visualizer.py @@ -0,0 +1,121 @@ +"""Unit tests for metrics_visualizer module.""" +import pytest +from unittest.mock import Mock, patch, MagicMock + + +class TestParseJobArn: + def test_training_job_arn(self): + from sagemaker.train.common_utils.metrics_visualizer import _parse_job_arn + result = _parse_job_arn("arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job") + assert result == ("us-west-2", "training-job/my-job") + + def test_processing_job_arn(self): + from sagemaker.train.common_utils.metrics_visualizer import _parse_job_arn + result = _parse_job_arn("arn:aws:sagemaker:us-east-1:123456789012:processing-job/my-job") + assert result == ("us-east-1", "processing-job/my-job") + + def test_invalid_arn_returns_none(self): + from sagemaker.train.common_utils.metrics_visualizer import _parse_job_arn + assert _parse_job_arn("not-an-arn") is None + + +class TestGetConsoleJobUrl: + def test_training_job(self): + from sagemaker.train.common_utils.metrics_visualizer import get_console_job_url + url = get_console_job_url("arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job") + assert url == "https://us-west-2.console.aws.amazon.com/sagemaker/home?region=us-west-2#/jobs/my-job" + + def test_invalid_arn_returns_empty(self): + from sagemaker.train.common_utils.metrics_visualizer import get_console_job_url + assert get_console_job_url("not-an-arn") == "" + + def test_unknown_job_type_returns_empty(self): + from sagemaker.train.common_utils.metrics_visualizer import get_console_job_url + assert get_console_job_url("arn:aws:sagemaker:us-west-2:123456789012:unknown-job/my-job") == "" + + +class TestGetCloudwatchLogsUrl: + def test_training_job(self): + from sagemaker.train.common_utils.metrics_visualizer import get_cloudwatch_logs_url + url = get_cloudwatch_logs_url("arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job") + assert "us-west-2" in url + assert "TrainingJobs" in url + assert "my-job" in url + + def test_invalid_arn_returns_empty(self): + from sagemaker.train.common_utils.metrics_visualizer import get_cloudwatch_logs_url + assert get_cloudwatch_logs_url("not-an-arn") == "" + + +class TestGetStudioUrl: + @patch("sagemaker.train.common_utils.metrics_visualizer._get_studio_base_url") + @patch("sagemaker.core.utils.utils.SageMakerClient") + def test_with_training_job_object(self, mock_client_cls, mock_base_url): + from sagemaker.train.common_utils.metrics_visualizer import get_studio_url + mock_client_cls.return_value.region_name = "us-west-2" + mock_base_url.return_value = "https://studio-d-abc.studio.us-west-2.sagemaker.aws" + + mock_job = Mock() + mock_job.training_job_name = "my-job" + + url = get_studio_url(mock_job) + assert url == "https://studio-d-abc.studio.us-west-2.sagemaker.aws/jobs/train/my-job" + mock_base_url.assert_called_once_with("us-west-2") + + @patch("sagemaker.train.common_utils.metrics_visualizer._get_studio_base_url") + def test_with_arn_string(self, mock_base_url): + from sagemaker.train.common_utils.metrics_visualizer import get_studio_url + mock_base_url.return_value = "https://studio-d-abc.studio.us-west-2.sagemaker.aws" + + url = get_studio_url("arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job") + assert url == "https://studio-d-abc.studio.us-west-2.sagemaker.aws/jobs/train/my-job" + mock_base_url.assert_called_once_with("us-west-2") + + @patch("sagemaker.train.common_utils.metrics_visualizer._get_studio_base_url") + @patch("sagemaker.core.utils.utils.SageMakerClient") + @patch("sagemaker.train.common_utils.metrics_visualizer.TrainingJob") + def test_with_job_name_string(self, mock_tj_cls, mock_client_cls, mock_base_url): + from sagemaker.train.common_utils.metrics_visualizer import get_studio_url + mock_client_cls.return_value.region_name = "us-west-2" + mock_base_url.return_value = "https://studio-d-abc.studio.us-west-2.sagemaker.aws" + mock_tj_cls.get.return_value.training_job_name = "my-job" + + url = get_studio_url("my-job") + assert url == "https://studio-d-abc.studio.us-west-2.sagemaker.aws/jobs/train/my-job" + + @patch("sagemaker.train.common_utils.metrics_visualizer._get_studio_base_url") + @patch("sagemaker.core.utils.utils.SageMakerClient") + def test_returns_empty_when_no_domain(self, mock_client_cls, mock_base_url): + from sagemaker.train.common_utils.metrics_visualizer import get_studio_url + mock_client_cls.return_value.region_name = "us-west-2" + mock_base_url.return_value = "" + + url = get_studio_url(Mock(training_job_name="my-job")) + assert url == "" + + +class TestGetAvailableMetrics: + @patch("sagemaker.train.common_utils.metrics_visualizer.TrainingJob") + def test_returns_empty_when_no_mlflow_config(self, _): + from sagemaker.train.common_utils.metrics_visualizer import get_available_metrics + mock_job = Mock(spec=[]) # no mlflow_config attribute + assert get_available_metrics(mock_job) == [] + + @patch("sagemaker.train.common_utils.metrics_visualizer.TrainingJob") + def test_returns_empty_when_mlflow_config_falsy(self, _): + from sagemaker.train.common_utils.metrics_visualizer import get_available_metrics + mock_job = Mock() + mock_job.mlflow_config = None + assert get_available_metrics(mock_job) == [] + + @patch("mlflow.get_run") + @patch("mlflow.set_tracking_uri") + def test_returns_metric_names(self, mock_set_uri, mock_get_run): + from sagemaker.train.common_utils.metrics_visualizer import get_available_metrics + mock_job = Mock() + mock_job.mlflow_config.mlflow_resource_arn = "arn:aws:sagemaker:us-west-2:123:mlflow-tracking/abc" + mock_job.mlflow_details.mlflow_run_id = "run-123" + mock_get_run.return_value.data.metrics = {"loss": 0.5, "accuracy": 0.9} + + result = get_available_metrics(mock_job) + assert set(result) == {"loss", "accuracy"} diff --git a/sagemaker-train/tests/unit/train/evaluate/test_pipeline_templates.py b/sagemaker-train/tests/unit/train/evaluate/test_pipeline_templates.py index 136ac022cd..36092fbeb3 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_pipeline_templates.py +++ b/sagemaker-train/tests/unit/train/evaluate/test_pipeline_templates.py @@ -121,7 +121,6 @@ def test_deterministic_template_with_optional_mlflow_params(self): pipeline_def = json.loads(rendered) assert pipeline_def["MlflowConfig"]["MlflowExperimentName"] == "test-experiment" - assert pipeline_def["MlflowConfig"]["MlflowRunName"] == "test-run" def test_deterministic_template_with_all_hyperparameters(self): """Test DETERMINISTIC_TEMPLATE with all optional hyperparameters.""" diff --git a/v3-examples/model-customization-examples/sm-studio-nova-training-job-sample-notebook.ipynb b/v3-examples/model-customization-examples/sm-studio-nova-training-job-sample-notebook.ipynb index 4e49266323..6645f2b7d2 100644 --- a/v3-examples/model-customization-examples/sm-studio-nova-training-job-sample-notebook.ipynb +++ b/v3-examples/model-customization-examples/sm-studio-nova-training-job-sample-notebook.ipynb @@ -1055,7 +1055,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "py3.10.14", "language": "python", "name": "python3" }, @@ -1069,7 +1069,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.11" + "version": "3.10.14" } }, "nbformat": 4,