diff --git a/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py b/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py index 9fa1711ac9..6280e2a671 100644 --- a/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py +++ b/tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py @@ -210,6 +210,50 @@ def test_create_dataset_from_bigframes(client, is_replay_mode): ) +@pytest.mark.skipif( + sys.version_info < (3, 10), + reason="bigframes requires python 3.10 or higher", +) +@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes") +def test_create_dataset_from_bigframes_preserves_other_metadata(client, is_replay_mode): + import bigframes.pandas + + dataframe = pd.DataFrame( + { + "col1": ["col1"], + "col2": ["col2"], + } + ) + if is_replay_mode: + bf_dataframe = mock.MagicMock() + bf_dataframe.to_gbq.return_value = "temp_table_id" + else: + bf_dataframe = bigframes.pandas.DataFrame(dataframe) + + dataset = client.datasets.create_from_bigframes( + dataframe=bf_dataframe, + target_table_id=BIGQUERY_TABLE_NAME, + multimodal_dataset={ + "display_name": "test-from-bigframes", + "metadata": { + "gemini_request_read_config": { + "assembled_request_column_name": "test_column" + } + }, + }, + ) + + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-bigframes" + assert ( + dataset.metadata.gemini_request_read_config.assembled_request_column_name + == "test_column" + ) + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + + pytestmark = pytest_helper.setup( file=__file__, globals_for_file=globals(), @@ -371,3 +415,50 @@ async def test_create_dataset_from_bigframes_async(client, is_replay_mode): pd.testing.assert_frame_equal( rows.to_dataframe(), dataframe, check_index_type=False ) + + +@pytest.mark.skipif( + sys.version_info < (3, 10), + reason="bigframes requires python 3.10 or higher", +) +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes") +async def test_create_dataset_from_bigframes_preserves_other_metadata_async( + client, is_replay_mode +): + import bigframes.pandas + + dataframe = pd.DataFrame( + { + "col1": ["col1"], + "col2": ["col2"], + } + ) + if is_replay_mode: + bf_dataframe = mock.MagicMock() + bf_dataframe.to_gbq.return_value = "temp_table_id" + else: + bf_dataframe = bigframes.pandas.DataFrame(dataframe) + + dataset = await client.aio.datasets.create_from_bigframes( + dataframe=bf_dataframe, + target_table_id=BIGQUERY_TABLE_NAME, + multimodal_dataset={ + "display_name": "test-from-bigframes", + "metadata": { + "gemini_request_read_config": { + "assembled_request_column_name": "test_column" + } + }, + }, + ) + + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-bigframes" + assert ( + dataset.metadata.gemini_request_read_config.assembled_request_column_name + == "test_column" + ) + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) diff --git a/vertexai/_genai/datasets.py b/vertexai/_genai/datasets.py index 87c5fdf542..9906d5e44a 100644 --- a/vertexai/_genai/datasets.py +++ b/vertexai/_genai/datasets.py @@ -1083,19 +1083,10 @@ def create_from_bigframes( client, ) + multimodal_dataset = multimodal_dataset.model_copy(deep=True) + multimodal_dataset.set_bigquery_uri(f"bq://{target_table_id}") return self.create_from_bigquery( - multimodal_dataset=multimodal_dataset.model_copy( - update={ - "metadata": types.SchemaTablesDatasetMetadata( - input_config=types.SchemaTablesDatasetMetadataInputConfig( - bigquery_source=types.SchemaTablesDatasetMetadataBigQuerySource( - uri=f"bq://{target_table_id}" - ) - ) - ) - } - ), - config=config, + multimodal_dataset=multimodal_dataset, config=config ) def update_multimodal_dataset( @@ -2357,19 +2348,10 @@ async def create_from_bigframes( client, ) + multimodal_dataset = multimodal_dataset.model_copy(deep=True) + multimodal_dataset.set_bigquery_uri(f"bq://{target_table_id}") return await self.create_from_bigquery( - multimodal_dataset=multimodal_dataset.model_copy( - update={ - "metadata": types.SchemaTablesDatasetMetadata( - input_config=types.SchemaTablesDatasetMetadataInputConfig( - bigquery_source=types.SchemaTablesDatasetMetadataBigQuerySource( - uri=f"bq://{target_table_id}" - ) - ) - ) - } - ), - config=config, + multimodal_dataset=multimodal_dataset, config=config ) async def update_multimodal_dataset(