diff --git a/vertexai/_genai/_datasets_utils.py b/vertexai/_genai/_datasets_utils.py index e063e6802a..bf2ffd7cf2 100644 --- a/vertexai/_genai/_datasets_utils.py +++ b/vertexai/_genai/_datasets_utils.py @@ -21,7 +21,7 @@ import google.auth.credentials from vertexai._genai.types import common -from pydantic import BaseModel +from google.genai import _common METADATA_SCHEMA_URI = ( @@ -31,18 +31,27 @@ _DEFAULT_BQ_DATASET_PREFIX = "vertex_datasets" _DEFAULT_BQ_TABLE_PREFIX = "multimodal_dataset" -T = TypeVar("T", bound=BaseModel) +T = TypeVar("T", bound=_common.BaseModel) -def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T: +def create_from_response( + model_type: Type[T], + response: dict[str, Any], + config: Any | None = None, +) -> T: """Creates a model from a response.""" - model_field_names = model_type.model_fields.keys() - filtered_response = {} - for key, value in response.items(): - snake_key = common.camel_to_snake(key) - if snake_key in model_field_names: - filtered_response[snake_key] = value - return model_type(**filtered_response) + kwargs = ( + { + "config": { + "response_schema": getattr(config, "response_schema", None), + "response_json_schema": getattr(config, "response_json_schema", None), + "include_all_fields": getattr(config, "include_all_fields", None), + } + } + if config + else {} + ) + return model_type._from_response(response=response, kwargs=kwargs) def validate_multimodal_dataset_bigquery_uri( diff --git a/vertexai/_genai/datasets.py b/vertexai/_genai/datasets.py index 046803edf0..87c5fdf542 100644 --- a/vertexai/_genai/datasets.py +++ b/vertexai/_genai/datasets.py @@ -963,7 +963,9 @@ def create_from_bigquery( operation=multimodal_dataset_operation, timeout_seconds=config.timeout, ) - return _datasets_utils.create_from_response(types.MultimodalDataset, response) + return _datasets_utils.create_from_response( + types.MultimodalDataset, response, config + ) def create_from_pandas( self, @@ -1302,6 +1304,7 @@ def assess_tuning_resources( return _datasets_utils.create_from_response( types.TuningResourceUsageAssessmentResult, response["tuningResourceUsageAssessmentResult"], + config, ) def assess_tuning_validity( @@ -1368,6 +1371,7 @@ def assess_tuning_validity( return _datasets_utils.create_from_response( types.TuningValidationAssessmentResult, response["tuningValidationAssessmentResult"], + config, ) def assess_batch_prediction_resources( @@ -1430,7 +1434,7 @@ def assess_batch_prediction_resources( ) result = response["batchPredictionResourceUsageAssessmentResult"] return _datasets_utils.create_from_response( - types.BatchPredictionResourceUsageAssessmentResult, result + types.BatchPredictionResourceUsageAssessmentResult, result, config ) def assess_batch_prediction_validity( @@ -1493,7 +1497,7 @@ def assess_batch_prediction_validity( ) result = response["batchPredictionValidationAssessmentResult"] return _datasets_utils.create_from_response( - types.BatchPredictionValidationAssessmentResult, result + types.BatchPredictionValidationAssessmentResult, result, config ) @@ -2231,7 +2235,9 @@ async def create_from_bigquery( operation=multimodal_dataset_operation, timeout_seconds=config.timeout, ) - return _datasets_utils.create_from_response(types.MultimodalDataset, response) + return _datasets_utils.create_from_response( + types.MultimodalDataset, response, config + ) async def create_from_pandas( self, @@ -2568,6 +2574,7 @@ async def assess_tuning_resources( return _datasets_utils.create_from_response( types.TuningResourceUsageAssessmentResult, response["tuningResourceUsageAssessmentResult"], + config, ) async def assess_tuning_validity( @@ -2634,6 +2641,7 @@ async def assess_tuning_validity( return _datasets_utils.create_from_response( types.TuningValidationAssessmentResult, response["tuningValidationAssessmentResult"], + config, ) async def assess_batch_prediction_resources( @@ -2696,7 +2704,7 @@ async def assess_batch_prediction_resources( ) result = response["batchPredictionResourceUsageAssessmentResult"] return _datasets_utils.create_from_response( - types.BatchPredictionResourceUsageAssessmentResult, result + types.BatchPredictionResourceUsageAssessmentResult, result, config ) async def assess_batch_prediction_validity( @@ -2759,5 +2767,5 @@ async def assess_batch_prediction_validity( ) result = response["batchPredictionValidationAssessmentResult"] return _datasets_utils.create_from_response( - types.BatchPredictionValidationAssessmentResult, result + types.BatchPredictionValidationAssessmentResult, result, config )