@@ -74,11 +74,20 @@ def mock_import_bigframes(is_replay_mode):

@pytest.fixture
def mock_generate_multimodal_dataset_display_name():
    with mock.patch.object(
        _datasets_utils, "generate_multimodal_dataset_display_name"
    ) as mock_generate:
        mock_generate.return_value = "test-generated-name"
        yield mock_generate


@pytest.fixture
def mock_get_batch_job_unique_name():
    with mock.patch.object(
        _datasets_utils, "get_batch_job_unique_name"
    ) as mock_unique_name:
        mock_unique_name.return_value = "12345678901234_abcde"
        yield mock_unique_name


def test_create_dataset(client):
@@ -169,43 +178,43 @@ def test_create_dataset_from_pandas(client, is_replay_mode):
)
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
def test_create_dataset_from_bigframes(client, is_replay_mode):
    import bigframes.pandas

    dataframe = pd.DataFrame(
        {
            "col1": ["col1"],
            "col2": ["col2"],
        }
    )
    if is_replay_mode:
        bf_dataframe = mock.MagicMock()
        bf_dataframe.to_gbq.return_value = "temp_table_id"
    else:
        bf_dataframe = bigframes.pandas.DataFrame(dataframe)

    dataset = client.datasets.create_from_bigframes(
        dataframe=bf_dataframe,
        target_table_id=BIGQUERY_TABLE_NAME,
        multimodal_dataset={
            "display_name": "test-from-bigframes",
        },
    )

    assert isinstance(dataset, types.MultimodalDataset)
    assert dataset.display_name == "test-from-bigframes"
    assert dataset.metadata.input_config.bigquery_source.uri == (
        f"bq://{BIGQUERY_TABLE_NAME}"
    )
    if not is_replay_mode:
        bigquery_client = bigquery.Client(
            project=client._api_client.project,
            location=client._api_client.location,
            credentials=client._api_client._credentials,
        )
        rows = bigquery_client.list_rows(
            dataset.metadata.input_config.bigquery_source.uri[5:]
        )
        pd.testing.assert_frame_equal(
            rows.to_dataframe(), dataframe, check_index_type=False
        )

137 changes: 84 additions & 53 deletions tests/unit/vertexai/genai/test_multimodal_datasets_genai.py
@@ -23,140 +23,171 @@

@pytest.fixture
def mock_import_bigframes():
    with mock.patch.object(
        _datasets_utils, "_try_import_bigframes"
    ) as mock_import_bigframes:
        mock_read_gbq_table_result = mock.MagicMock()
        mock_read_gbq_table_result.sql = "SELECT * FROM `project.dataset.table`"

        bigframes = mock.MagicMock()
        bigframes.pandas.read_gbq_table.return_value = mock_read_gbq_table_result

        mock_import_bigframes.return_value = bigframes
        yield mock_import_bigframes


@pytest.fixture
def mock_get_batch_job_unique_name():
    with mock.patch.object(
        _datasets_utils, "get_batch_job_unique_name"
    ) as mock_unique_name:
        mock_unique_name.return_value = "12345678901234_abcde"
        yield mock_unique_name


class TestMultimodalDataset:

    def test_read_config(self):
        dataset = types.MultimodalDataset(
            metadata={
                "gemini_request_read_config": {
                    "assembled_request_column_name": "test_column",
                },
            },
        )

        assert isinstance(dataset.read_config, types.GeminiRequestReadConfig)
        assert dataset.read_config.assembled_request_column_name == "test_column"

    def test_read_config_empty(self):
        dataset = types.MultimodalDataset()
        assert dataset.read_config is None

    def test_set_read_config(self):
        dataset = types.MultimodalDataset()

        dataset.set_read_config(
            read_config={
                "assembled_request_column_name": "test_column",
            },
        )

        assert isinstance(dataset, types.MultimodalDataset)
        assert (
            dataset.metadata.gemini_request_read_config.assembled_request_column_name
            == "test_column"
        )

    def test_set_read_config_preserves_other_fields(self):
        dataset = types.MultimodalDataset(
            metadata={
                "inputConfig": {
                    "bigquerySource": {"uri": "bq://test_table"},
                },
            },
        )

        dataset.set_read_config(
            read_config={
                "assembled_request_column_name": "test_column",
            },
        )

        assert isinstance(dataset, types.MultimodalDataset)
        assert (
            dataset.metadata.gemini_request_read_config.assembled_request_column_name
            == "test_column"
        )
        assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table"

    def test_bigquery_uri(self):
        dataset = types.MultimodalDataset(
            metadata={
                "inputConfig": {
                    "bigquerySource": {"uri": "bq://project.dataset.table"},
                },
            },
        )

        assert dataset.bigquery_uri == "bq://project.dataset.table"

    def test_bigquery_uri_empty(self):
        dataset = types.MultimodalDataset()
        assert dataset.bigquery_uri is None

    def test_set_bigquery_uri(self):
        dataset = types.MultimodalDataset()

        dataset.set_bigquery_uri("bq://project.dataset.table")

        assert isinstance(dataset, types.MultimodalDataset)
        assert (
            dataset.metadata.input_config.bigquery_source.uri
            == "bq://project.dataset.table"
        )

    def test_set_bigquery_uri_without_prefix(self):
        dataset = types.MultimodalDataset()

        dataset.set_bigquery_uri("project.dataset.table")

        assert isinstance(dataset, types.MultimodalDataset)
        assert (
            dataset.metadata.input_config.bigquery_source.uri
            == "bq://project.dataset.table"
        )

    def test_set_bigquery_uri_preserves_other_fields(self):
        dataset = types.MultimodalDataset(
            metadata={
                "gemini_request_read_config": {
                    "assembled_request_column_name": "test_column",
                },
            },
        )

        dataset.set_bigquery_uri("bq://test_table")

        assert isinstance(dataset, types.MultimodalDataset)
        assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table"
        assert (
            dataset.metadata.gemini_request_read_config.assembled_request_column_name
            == "test_column"
        )

    def test_to_bigframes(self, mock_import_bigframes):
        dataset = types.MultimodalDataset()
        dataset.set_bigquery_uri("bq://project.dataset.table")

        df = dataset.to_bigframes()

        assert "project.dataset.table" in df.sql
        mock_import_bigframes.return_value.pandas.read_gbq_table.assert_called_once_with(
            "project.dataset.table"
        )

    def test_get_batch_job_destination(self, mock_get_batch_job_unique_name):
        dataset = types.MultimodalDataset(
            name="projects/vertex-sdk-dev/locations/us-central1/datasets/12345",
            display_name="test_multimodal_dataset",
            metadata={
                "inputConfig": {
                    "bigquerySource": {
                        "uri": "bq://target_project.target_dataset.target_table"
                    },
                },
            },
        )
        destination = dataset.get_batch_job_destination()
        assert (
            destination.vertex_dataset.display_name
            == "test_multimodal_dataset_batch_output_12345678901234_abcde"
        )
        assert (
            destination.vertex_dataset.bigquery_destination
            == "bq://target_project.target_dataset.target_table_batch_output_12345678901234_abcde"
        )


class TestGeminiRequestReadConfig:
    def test_single_turn_template(self):
11 changes: 9 additions & 2 deletions vertexai/_genai/_datasets_utils.py
@@ -229,8 +229,15 @@ def _generate_target_table_id(dataset_id: str) -> str:


def generate_multimodal_dataset_display_name() -> str:
    """Generates a display name with a timestamp."""
    return f"MultimodalDataset {datetime.datetime.now().isoformat(sep=' ')}"


def get_batch_job_unique_name() -> str:
    """Generates a unique name suffix for a batch job destination."""
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    unique_id = uuid.uuid4().hex[0:5]
    return f"{timestamp}_{unique_id}"


def save_dataframe_to_bigquery(
22 changes: 22 additions & 0 deletions vertexai/_genai/types/common.py
@@ -14094,6 +14094,28 @@ def to_bigframes(
            raise ValueError("Multimodal dataset bigquery source uri is not set.")
        return bigframes.pandas.read_gbq_table(self.bigquery_uri.removeprefix("bq://"))

    def to_batch_job_source(self) -> "genai_types.BatchJobSource":
        """Converts the dataset to a BatchJobSource."""
        return genai_types.BatchJobSource(
            vertex_dataset_name=self.name,
        )

    def get_batch_job_destination(self) -> "genai_types.BatchJobDestination":
        """Converts the dataset to a BatchJobDestination."""
        from .. import _datasets_utils

        unique_name = _datasets_utils.get_batch_job_unique_name()
        bigquery_uri = self.bigquery_uri
        if bigquery_uri is None:
            raise ValueError("Multimodal dataset bigquery source uri is not set.")
        curr_display_name = self.display_name or "genai_batch_job"
        return genai_types.BatchJobDestination(
            vertex_dataset=genai_types.VertexMultimodalDatasetDestination(
                display_name=f"{curr_display_name}_batch_output_{unique_name}",
                bigquery_destination=f"{bigquery_uri}_batch_output_{unique_name}",
            )
        )
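
A minimal usage sketch of the new method, mirroring the unit test above; the dataset values are illustrative, and the suffix is whatever `get_batch_job_unique_name` returns at call time:

# Sketch only: builds a batch-job destination from a dataset's metadata.
dataset = types.MultimodalDataset(
    display_name="my_dataset",
    metadata={"inputConfig": {"bigquerySource": {"uri": "bq://proj.ds.table"}}},
)
destination = dataset.get_batch_job_destination()
# destination.vertex_dataset.display_name
#     == "my_dataset_batch_output_<timestamp>_<hex>"
# destination.vertex_dataset.bigquery_destination
#     == "bq://proj.ds.table_batch_output_<timestamp>_<hex>"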


class MultimodalDatasetDict(TypedDict, total=False):
"""Represents a multimodal dataset."""