Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@
"pyyaml>=5.3.1,<7",
]
datasets_extra_require = [
"pyarrow >= 3.0.0, < 8.0.0; python_version<'3.11'",
"pyarrow >= 3.0.0, < 8.0.0; python_version<'3.10'",
"pyarrow >= 10.0.1; python_version=='3.10'",
"pyarrow >= 10.0.1; python_version=='3.11'",
"pyarrow >= 14.0.0; python_version>='3.12'",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

import sys
from unittest import mock

from google.cloud import bigquery
from tests.unit.vertexai.genai.replays import pytest_helper
from vertexai._genai import _datasets_utils
from vertexai._genai import types
import pandas as pd
import pytest


METADATA_SCHEMA_URI = (
"gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml"
)
Expand Down Expand Up @@ -156,6 +159,52 @@ def test_create_dataset_from_pandas(client, is_replay_mode):
pd.testing.assert_frame_equal(rows.to_dataframe(), dataframe)


@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="bigframes requires python 3.10 or higher"
)
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
def test_create_dataset_from_bigframes(client, is_replay_mode):
    """Creates a multimodal dataset from a bigframes DataFrame and verifies it."""
    import bigframes.pandas

    source_df = pd.DataFrame(
        {
            "col1": ["col1"],
            "col2": ["col2"],
        }
    )
    if not is_replay_mode:
        bf_dataframe = bigframes.pandas.DataFrame(source_df)
    else:
        # Replay mode has no live BigQuery session; stand in a mock that
        # reports a fixed temp table id.
        bf_dataframe = mock.MagicMock()
        bf_dataframe.to_gbq.return_value = "temp_table_id"

    dataset = client.datasets.create_from_bigframes(
        dataframe=bf_dataframe,
        target_table_id=BIGQUERY_TABLE_NAME,
        multimodal_dataset={
            "display_name": "test-from-bigframes",
        },
    )

    assert isinstance(dataset, types.MultimodalDataset)
    assert dataset.display_name == "test-from-bigframes"
    expected_uri = f"bq://{BIGQUERY_TABLE_NAME}"
    assert dataset.metadata.input_config.bigquery_source.uri == expected_uri
    if not is_replay_mode:
        # Read the rows back from BigQuery and compare with the source frame.
        bigquery_client = bigquery.Client(
            project=client._api_client.project,
            location=client._api_client.location,
            credentials=client._api_client._credentials,
        )
        table_id = dataset.metadata.input_config.bigquery_source.uri[len("bq://"):]
        rows = bigquery_client.list_rows(table_id)
        pd.testing.assert_frame_equal(
            rows.to_dataframe(), source_df, check_index_type=False
        )


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
Expand Down Expand Up @@ -274,3 +323,50 @@ async def test_create_dataset_from_pandas_async(client, is_replay_mode):
dataset.metadata.input_config.bigquery_source.uri[5:]
)
pd.testing.assert_frame_equal(rows.to_dataframe(), dataframe)


@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="bigframes requires python 3.10 or higher"
)
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
async def test_create_dataset_from_bigframes_async(client, is_replay_mode):
    """Async variant: creates a multimodal dataset from a bigframes DataFrame."""
    import bigframes.pandas

    source_df = pd.DataFrame(
        {
            "col1": ["col1"],
            "col2": ["col2"],
        }
    )
    if not is_replay_mode:
        bf_dataframe = bigframes.pandas.DataFrame(source_df)
    else:
        # Replay mode has no live BigQuery session; stand in a mock that
        # reports a fixed temp table id.
        bf_dataframe = mock.MagicMock()
        bf_dataframe.to_gbq.return_value = "temp_table_id"

    dataset = await client.aio.datasets.create_from_bigframes(
        dataframe=bf_dataframe,
        target_table_id=BIGQUERY_TABLE_NAME,
        multimodal_dataset={
            "display_name": "test-from-bigframes",
        },
    )

    assert isinstance(dataset, types.MultimodalDataset)
    assert dataset.display_name == "test-from-bigframes"
    expected_uri = f"bq://{BIGQUERY_TABLE_NAME}"
    assert dataset.metadata.input_config.bigquery_source.uri == expected_uri
    if not is_replay_mode:
        # Read the rows back from BigQuery and compare with the source frame.
        bigquery_client = bigquery.Client(
            project=client._api_client.project,
            location=client._api_client.location,
            credentials=client._api_client._credentials,
        )
        table_id = dataset.metadata.input_config.bigquery_source.uri[len("bq://"):]
        rows = bigquery_client.list_rows(table_id)
        pd.testing.assert_frame_equal(
            rows.to_dataframe(), source_df, check_index_type=False
        )
90 changes: 88 additions & 2 deletions vertexai/_genai/_datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
#
"""Utility functions for multimodal dataset."""

import asyncio
from typing import Any, Type, TypeVar
import uuid

import google.auth.credentials
from vertexai._genai.types import common
from pydantic import BaseModel


METADATA_SCHEMA_URI = (
"gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml"
)
Expand Down Expand Up @@ -169,14 +171,48 @@ def _normalize_and_validate_table_id(
return f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}"


async def _normalize_and_validate_table_id_async(
    *,
    table_id: str,
    project: str,
    location: str,
    credentials: google.auth.credentials.Credentials,
) -> str:
    """Async twin of `_normalize_and_validate_table_id`.

    Resolves `table_id` to a fully-qualified ``project.dataset.table`` string,
    checking that the table lives in `project` and that its containing
    BigQuery dataset is in a location compatible with `location`. Raises
    ValueError when either check fails.
    """
    bigquery = _try_import_bigquery()

    table_ref = bigquery.TableReference.from_string(table_id, default_project=project)
    qualified_id = (
        f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}"
    )
    if table_ref.project != project:
        raise ValueError(
            f"The BigQuery table `{qualified_id}`"
            " must be in the same project as the multimodal dataset."
            f" The multimodal dataset is in `{project}`, but the BigQuery table"
            f" is in `{table_ref.project}`."
        )

    dataset_ref = bigquery.DatasetReference(
        project=table_ref.project, dataset_id=table_ref.dataset_id
    )
    client = bigquery.Client(project=project, credentials=credentials)
    # get_dataset is a blocking RPC; run it off the event loop.
    bq_dataset = await asyncio.to_thread(client.get_dataset, dataset_ref=dataset_ref)
    if not _bq_dataset_location_allowed(location, bq_dataset.location):
        raise ValueError(
            "The BigQuery dataset"
            f" `{dataset_ref.project}.{dataset_ref.dataset_id}` must be in the"
            " same location as the multimodal dataset. The multimodal dataset"
            f" is in `{location}`, but the BigQuery dataset is in"
            f" `{bq_dataset.location}`."
        )
    return qualified_id


def _create_default_bigquery_dataset_if_not_exists(
*,
project: str,
location: str,
credentials: google.auth.credentials.Credentials,
) -> str:
# Loading bigquery lazily to avoid auto-loading it when importing vertexai
from google.cloud import bigquery # pylint: disable=g-import-not-at-top
bigquery = _try_import_bigquery()

bigquery_client = bigquery.Client(project=project, credentials=credentials)
location_str = location.lower().replace("-", "_")
Expand All @@ -189,5 +225,55 @@ def _create_default_bigquery_dataset_if_not_exists(
return f"{dataset_id.project}.{dataset_id.dataset_id}"


async def _create_default_bigquery_dataset_if_not_exists_async(
    *,
    project: str,
    location: str,
    credentials: google.auth.credentials.Credentials,
) -> str:
    """Async twin of `_create_default_bigquery_dataset_if_not_exists`.

    Ensures the default BigQuery dataset for `location` exists in `project`
    (creating it if needed) and returns its ``project.dataset`` id.
    """
    bigquery = _try_import_bigquery()

    client = bigquery.Client(project=project, credentials=credentials)
    normalized_location = location.lower().replace("-", "_")
    dataset_ref = bigquery.DatasetReference(
        project, f"{_DEFAULT_BQ_DATASET_PREFIX}_{normalized_location}"
    )
    default_dataset = bigquery.Dataset(dataset_ref=dataset_ref)
    default_dataset.location = location
    # create_dataset is a blocking RPC; run it off the event loop.
    await asyncio.to_thread(client.create_dataset, default_dataset, exists_ok=True)
    return f"{dataset_ref.project}.{dataset_ref.dataset_id}"


def _generate_target_table_id(dataset_id: str) -> str:
    """Returns a unique table id (random UUID suffix) inside `dataset_id`."""
    unique_suffix = uuid.uuid4()
    return f"{dataset_id}.{_DEFAULT_BQ_TABLE_PREFIX}_{unique_suffix}"


def save_dataframe_to_bigquery(
    dataframe: "bigframes.pandas.DataFrame",  # type: ignore # noqa: F821
    target_table_id: str,
    bq_client: "bigquery.Client",  # type: ignore # noqa: F821
) -> None:
    """Saves a bigframes DataFrame into the given BigQuery table.

    `to_gbq` does not support cross-region use cases, so the dataframe is
    first materialized into a temporary table and then copied into
    `target_table_id` with `copy_table` as a workaround.

    Args:
        dataframe: The bigframes DataFrame to save.
        target_table_id: Fully-qualified id of the destination table.
        bq_client: BigQuery client used for the copy and cleanup.
    """
    temp_table_id = dataframe.to_gbq()
    try:
        copy_job = bq_client.copy_table(
            sources=temp_table_id,
            destination=target_table_id,
        )
        copy_job.result()
    finally:
        # Always drop the temporary table — even when the copy fails — so
        # failed saves do not leak tables into the temp dataset.
        bq_client.delete_table(temp_table_id)


async def save_dataframe_to_bigquery_async(
    dataframe: "bigframes.pandas.DataFrame",  # type: ignore # noqa: F821
    target_table_id: str,
    bq_client: "bigquery.Client",  # type: ignore # noqa: F821
) -> None:
    """Async twin of `save_dataframe_to_bigquery`.

    `to_gbq` does not support cross-region use cases, so the dataframe is
    first materialized into a temporary table and then copied into
    `target_table_id` with `copy_table` as a workaround. Blocking client
    calls are run in a thread to keep the event loop free.

    Args:
        dataframe: The bigframes DataFrame to save.
        target_table_id: Fully-qualified id of the destination table.
        bq_client: BigQuery client used for the copy and cleanup.
    """
    temp_table_id = await asyncio.to_thread(dataframe.to_gbq)
    try:
        copy_job = await asyncio.to_thread(
            bq_client.copy_table,
            sources=temp_table_id,
            destination=target_table_id,
        )
        await asyncio.to_thread(copy_job.result)
    finally:
        # Always drop the temporary table — even when the copy fails — so
        # failed saves do not leak tables into the temp dataset.
        await asyncio.to_thread(bq_client.delete_table, temp_table_id)
Loading
Loading