Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/ml/azure-ai-ml/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

- Fixed default deployment template check to verify `asset_id` is not None before logging template information.
- Skip _list_secrets for identity-based datastores to prevent noisy telemetry traces.
- Deployment templates `allowed_instance_types` now accepts a list instead of string.

### Other Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from marshmallow import fields, post_load

from azure.ai.ml._schema.assets.environment import AnonymousEnvironmentSchema, EnvironmentSchema
from azure.ai.ml.constants._common import AzureMLResourceType
from azure.ai.ml._schema.core.fields import (
ArmVersionedStr,
NestedField,
Expand All @@ -21,6 +20,7 @@
VersionField,
)
from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml.constants._common import AzureMLResourceType

from .probe_settings_schema import ProbeSettingsSchema
from .request_settings_schema import RequestSettingsSchema
Expand All @@ -44,7 +44,7 @@ class DeploymentTemplateSchema(PathAwareSchema):
readiness_probe = NestedField(ProbeSettingsSchema)
instance_count = fields.Int()
model_mount_path = fields.Str()
allowed_instance_types = fields.Str()
allowed_instance_types = fields.List(fields.Str())
default_instance_type = fields.Str()
environment = UnionField(
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@

from os import PathLike
from pathlib import Path
from typing import Any, Dict, Optional, Union, IO, AnyStr
from typing import IO, Any, AnyStr, Dict, List, Optional, Union

from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml.entities._mixins import RestTranslatableMixin
from azure.ai.ml.entities._assets import Environment

from azure.ai.ml.entities._deployment.deployment_template_settings import OnlineRequestSettings, ProbeSettings
from azure.ai.ml.entities._mixins import RestTranslatableMixin
from azure.ai.ml.entities._resource import Resource


Expand Down Expand Up @@ -69,7 +68,7 @@ def __init__( # pylint: disable=too-many-locals
code_configuration: Optional[Dict[str, Any]] = None,
environment_variables: Optional[Dict[str, str]] = None,
app_insights_enabled: Optional[bool] = None,
allowed_instance_types: Optional[str] = None,
allowed_instance_types: Optional[List[str]] = None,
default_instance_type: Optional[str] = None, # Handle default instance type
scoring_port: Optional[int] = None,
scoring_path: Optional[str] = None,
Expand Down Expand Up @@ -99,6 +98,10 @@ def __init__( # pylint: disable=too-many-locals
self.code_configuration = code_configuration
self.environment_variables = environment_variables
self.app_insights_enabled = app_insights_enabled
if allowed_instance_types is not None and not isinstance(allowed_instance_types, list):
raise TypeError(
"allowed_instance_types must be a list of strings, e.g. ['Standard_DS3_v2', 'Standard_DS4_v2']."
)
self.allowed_instance_types = allowed_instance_types
self.default_instance_type = default_instance_type
self.scoring_port = scoring_port
Expand Down Expand Up @@ -372,8 +375,8 @@ def get_value(source, key, default=None):
type_field = get_value(properties, "type") or get_value(obj, "type")

# Handle string representations from properties - they come as JSON strings
import json
import ast
import json

# Parse tags if it's a string
if isinstance(tags, str):
Expand Down Expand Up @@ -565,16 +568,9 @@ def _to_rest_object(self) -> dict:
if hasattr(self, "app_insights_enabled") and self.app_insights_enabled is not None:
result["appInsightsEnabled"] = self.app_insights_enabled # type: ignore

# Handle allowed instance types - convert string to array format for API
# Handle allowed instance types
if hasattr(self, "allowed_instance_types") and self.allowed_instance_types:
if isinstance(self.allowed_instance_types, str):
# Convert space-separated string to array
instance_types_array = self.allowed_instance_types.split()
elif isinstance(self.allowed_instance_types, list):
instance_types_array = self.allowed_instance_types
else:
instance_types_array = [str(self.allowed_instance_types)]
result["allowedInstanceTypes"] = instance_types_array # type: ignore[assignment]
result["allowedInstanceTypes"] = self.allowed_instance_types # type: ignore[assignment]

return result

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@

from typing import Any, Dict, Iterable, Optional, cast

from azure.ai.ml._scope_dependent_operations import OperationScope, OperationConfig, _ScopeDependentOperations
from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations
from azure.ai.ml._telemetry import ActivityType, monitor_with_telemetry_mixin
from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml._utils._logger_utils import OpsLogger
from azure.ai.ml.entities import DeploymentTemplate
from azure.core.tracing.decorator import distributed_trace
from azure.core.exceptions import ResourceNotFoundError
from azure.core.tracing.decorator import distributed_trace

ops_logger = OpsLogger(__name__)
module_logger = ops_logger.module_logger
Expand Down Expand Up @@ -51,10 +51,10 @@ def _get_registry_endpoint(self) -> str:
"""
try:
# Import here to avoid circular dependencies
from azure.ai.ml.operations import RegistryOperations
from azure.ai.ml._restclient.v2022_10_01_preview import (
AzureMachineLearningWorkspaces as ServiceClient102022,
)
from azure.ai.ml.operations import RegistryOperations

# Try to get credential from service client or operation config
credential = None
Expand Down Expand Up @@ -140,9 +140,6 @@ def get_field_value(data: dict, primary_name: str, alt_name: str = None, default

# Handle field name variations for constructor parameters
allowed_instance_types = get_field_value(data, "allowed_instance_types", "allowedInstanceTypes")
if isinstance(allowed_instance_types, str):
# Convert space-separated string to list
allowed_instance_types = allowed_instance_types.split()

default_instance_type = get_field_value(data, "default_instance_type", "defaultInstanceType")
deployment_template_type = get_field_value(data, "deployment_template_type", "deploymentTemplateType")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import pytest
from unittest.mock import Mock, patch

import pytest

from azure.ai.ml.entities._deployment.deployment_template import DeploymentTemplate


Expand Down Expand Up @@ -31,7 +33,7 @@ def test_deployment_template_full_init(self):
instance_type="Standard_DS3_v2",
type="deployment_template",
deployment_template_type="model_deployment",
allowed_instance_types="Standard_DS2_v2,Standard_DS3_v2",
allowed_instance_types=["Standard_DS2_v2", "Standard_DS3_v2"],
)

assert template.name == "test-template"
Expand All @@ -44,7 +46,7 @@ def test_deployment_template_full_init(self):
assert template.instance_type == "Standard_DS3_v2"
assert template.type == "deployment_template"
assert template.deployment_template_type == "model_deployment"
assert template.allowed_instance_types == "Standard_DS2_v2,Standard_DS3_v2"
assert template.allowed_instance_types == ["Standard_DS2_v2", "Standard_DS3_v2"]

def test_deployment_template_type_fields(self):
"""Test handling of 'type' and 'deployment_template_type' fields."""
Expand Down Expand Up @@ -230,6 +232,15 @@ def test_deployment_template_empty_values(self):
assert template.properties == {}
assert template.environment_variables == {}

def test_deployment_template_allowed_instance_types_rejects_string(self):
"""Test that allowed_instance_types raises TypeError when given a string."""
with pytest.raises(TypeError, match="allowed_instance_types must be a list of strings"):
DeploymentTemplate(
name="test-template",
version="1.0",
allowed_instance_types="Standard_DS2_v2,Standard_DS3_v2",
)

def test_deployment_template_from_rest_object_none(self):
"""Test _from_rest_object with None input."""
result = DeploymentTemplate._from_rest_object(None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -680,21 +680,9 @@ def test_convert_dict_to_deployment_template_string_to_int_conversion(self, depl
assert result.instance_count == 3
assert result.scoring_port == 8080

def test_convert_dict_to_deployment_template_space_separated_instance_types(self, deployment_template_ops):
"""Test _convert_dict_to_deployment_template with space-separated allowed_instance_types."""
dict_data = {
"name": "test-template",
"version": "1.0",
"environment": "azureml:test-env:1",
"allowed_instance_types": "Standard_DS2_v2 Standard_DS3_v2 Standard_DS4_v2",
}

result = deployment_template_ops._convert_dict_to_deployment_template(dict_data)

assert result.allowed_instance_types == ["Standard_DS2_v2", "Standard_DS3_v2", "Standard_DS4_v2"]

def test_convert_dict_to_deployment_template_all_fields(self, deployment_template_ops):
"""Test _convert_dict_to_deployment_template with all possible fields."""

dict_data = {
"name": "full-template",
"version": "2.0",
Expand Down
Loading