Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions vertexai/_genai/_datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import google.auth.credentials
from vertexai._genai.types import common
from pydantic import BaseModel
from google.genai import _common


METADATA_SCHEMA_URI = (
Expand All @@ -31,18 +31,27 @@
_DEFAULT_BQ_DATASET_PREFIX = "vertex_datasets"
_DEFAULT_BQ_TABLE_PREFIX = "multimodal_dataset"

T = TypeVar("T", bound=BaseModel)
T = TypeVar("T", bound=_common.BaseModel)


def create_from_response(
    model_type: Type[T],
    response: dict[str, Any],
    config: Any | None = None,
) -> T:
    """Builds a ``model_type`` instance from a raw API response dict.

    Args:
        model_type: The ``_common.BaseModel`` subclass to instantiate.
        response: The raw response payload returned by the API.
        config: Optional request config. When provided, its
            ``response_schema``, ``response_json_schema`` and
            ``include_all_fields`` attributes (when present) are forwarded
            to ``_from_response`` so deserialization can honor them.

    Returns:
        An instance of ``model_type`` populated from ``response``.
    """
    # Forward only the schema-related knobs from the config; attributes that
    # are absent on the config object default to None via getattr.
    kwargs = (
        {
            "config": {
                "response_schema": getattr(config, "response_schema", None),
                "response_json_schema": getattr(
                    config, "response_json_schema", None
                ),
                "include_all_fields": getattr(
                    config, "include_all_fields", None
                ),
            }
        }
        if config
        else {}
    )
    return model_type._from_response(response=response, kwargs=kwargs)


def validate_multimodal_dataset_bigquery_uri(
Expand Down
20 changes: 14 additions & 6 deletions vertexai/_genai/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,9 @@ def create_from_bigquery(
operation=multimodal_dataset_operation,
timeout_seconds=config.timeout,
)
return _datasets_utils.create_from_response(types.MultimodalDataset, response)
return _datasets_utils.create_from_response(
types.MultimodalDataset, response, config
)

def create_from_pandas(
self,
Expand Down Expand Up @@ -1302,6 +1304,7 @@ def assess_tuning_resources(
return _datasets_utils.create_from_response(
types.TuningResourceUsageAssessmentResult,
response["tuningResourceUsageAssessmentResult"],
config,
)

def assess_tuning_validity(
Expand Down Expand Up @@ -1368,6 +1371,7 @@ def assess_tuning_validity(
return _datasets_utils.create_from_response(
types.TuningValidationAssessmentResult,
response["tuningValidationAssessmentResult"],
config,
)

def assess_batch_prediction_resources(
Expand Down Expand Up @@ -1430,7 +1434,7 @@ def assess_batch_prediction_resources(
)
result = response["batchPredictionResourceUsageAssessmentResult"]
return _datasets_utils.create_from_response(
types.BatchPredictionResourceUsageAssessmentResult, result
types.BatchPredictionResourceUsageAssessmentResult, result, config
)

def assess_batch_prediction_validity(
Expand Down Expand Up @@ -1493,7 +1497,7 @@ def assess_batch_prediction_validity(
)
result = response["batchPredictionValidationAssessmentResult"]
return _datasets_utils.create_from_response(
types.BatchPredictionValidationAssessmentResult, result
types.BatchPredictionValidationAssessmentResult, result, config
)


Expand Down Expand Up @@ -2231,7 +2235,9 @@ async def create_from_bigquery(
operation=multimodal_dataset_operation,
timeout_seconds=config.timeout,
)
return _datasets_utils.create_from_response(types.MultimodalDataset, response)
return _datasets_utils.create_from_response(
types.MultimodalDataset, response, config
)

async def create_from_pandas(
self,
Expand Down Expand Up @@ -2568,6 +2574,7 @@ async def assess_tuning_resources(
return _datasets_utils.create_from_response(
types.TuningResourceUsageAssessmentResult,
response["tuningResourceUsageAssessmentResult"],
config,
)

async def assess_tuning_validity(
Expand Down Expand Up @@ -2634,6 +2641,7 @@ async def assess_tuning_validity(
return _datasets_utils.create_from_response(
types.TuningValidationAssessmentResult,
response["tuningValidationAssessmentResult"],
config,
)

async def assess_batch_prediction_resources(
Expand Down Expand Up @@ -2696,7 +2704,7 @@ async def assess_batch_prediction_resources(
)
result = response["batchPredictionResourceUsageAssessmentResult"]
return _datasets_utils.create_from_response(
types.BatchPredictionResourceUsageAssessmentResult, result
types.BatchPredictionResourceUsageAssessmentResult, result, config
)

async def assess_batch_prediction_validity(
Expand Down Expand Up @@ -2759,5 +2767,5 @@ async def assess_batch_prediction_validity(
)
result = response["batchPredictionValidationAssessmentResult"]
return _datasets_utils.create_from_response(
types.BatchPredictionValidationAssessmentResult, result
types.BatchPredictionValidationAssessmentResult, result, config
)
Loading