diff --git a/agentplatform/__init__.py b/agentplatform/__init__.py new file mode 100644 index 0000000000..913bdcffb7 --- /dev/null +++ b/agentplatform/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""The agentplatform module.""" + +from google.cloud.aiplatform import init +from google.cloud.aiplatform import version as aiplatform_version + +__version__ = aiplatform_version.__version__ + +__all__ = [ + "init", +] diff --git a/agentplatform/batch_prediction/__init__.py b/agentplatform/batch_prediction/__init__.py new file mode 100644 index 0000000000..7d2817228b --- /dev/null +++ b/agentplatform/batch_prediction/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +"""Classes for batch prediction.""" + +# We just want to re-export certain classes +# pylint: disable=g-multiple-import,g-importing-member +from agentplatform.batch_prediction._batch_prediction import ( + BatchPredictionJob, +) + +__all__ = [ + "BatchPredictionJob", +] diff --git a/agentplatform/batch_prediction/_batch_prediction.py b/agentplatform/batch_prediction/_batch_prediction.py new file mode 100644 index 0000000000..7c9307d950 --- /dev/null +++ b/agentplatform/batch_prediction/_batch_prediction.py @@ -0,0 +1,428 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +"""Class to support Batch Prediction with GenAI models.""" +# pylint: disable=protected-access + +import logging +import re +from typing import List, Optional, Union + +from google.cloud.aiplatform import base as aiplatform_base +from google.cloud.aiplatform import initializer as aiplatform_initializer +from google.cloud.aiplatform import jobs +from google.cloud.aiplatform import models +from google.cloud.aiplatform import utils as aiplatform_utils +from google.cloud.aiplatform_v1 import types as gca_types + +from google.rpc import status_pb2 + + +_LOGGER = aiplatform_base.Logger(__name__) + +_GEMINI_MODEL_PATTERN = r"publishers/google/models/gemini" +_LLAMA_MODEL_PATTERN = r"publishers/meta/models/llama" +_CLAUDE_MODEL_PATTERN = r"publishers/anthropic/models/claude" +_GPT_MODEL_PATTERN = r"publishers/openai/models/gpt" +_QWEN_MODEL_PATTERN = r"publishers/qwen/models/qwen" +_DEEPSEEK_MODEL_PATTERN = r"publishers/deepseek-ai/models/deepseek" +_E5_MODEL_PATTERN = r"publishers/intfloat/models/multilingual" +_GEMINI_TUNED_MODEL_PATTERN = r"^projects/[0-9]+?/locations/[0-9a-z-]+?/models/[0-9]+?$" + + +class BatchPredictionJob(aiplatform_base._VertexAiResourceNounPlus): + """Represents a BatchPredictionJob that runs with GenAI models.""" + + _resource_noun = "batchPredictionJobs" + _getter_method = "get_batch_prediction_job" + _list_method = "list_batch_prediction_jobs" + _delete_method = "delete_batch_prediction_job" + _job_type = "batch-predictions" + _parse_resource_name_method = "parse_batch_prediction_job_path" + _format_resource_name_method = "batch_prediction_job_path" + + client_class = aiplatform_utils.JobClientWithOverride + + def __init__(self, batch_prediction_job_name: str): + """Retrieves a BatchPredictionJob resource that runs with a GenAI model. + + Args: + batch_prediction_job_name (str): + Required. A fully-qualified BatchPredictionJob resource name or + ID. 
Example: "projects/.../locations/.../batchPredictionJobs/456" + or "456" when project and location are initialized. + + Raises: + ValueError: If batch_prediction_job_name represents a BatchPredictionJob + resource that runs with another type of model. + """ + super().__init__(resource_name=batch_prediction_job_name) + self._gca_resource = self._get_gca_resource( + resource_name=batch_prediction_job_name + ) + if not self._is_genai_model(self.model_name): + raise ValueError( + f"BatchPredictionJob '{batch_prediction_job_name}' " + f"runs with the model '{self.model_name}', " + "which is not a GenAI model." + ) + + @property + def model_name(self) -> str: + """Returns the model name used for this batch prediction job.""" + return self._gca_resource.model + + @property + def state(self) -> gca_types.JobState: + """Returns the state of this batch prediction job.""" + return self._gca_resource.state + + @property + def has_ended(self) -> bool: + """Returns true if this batch prediction job has ended.""" + return self.state in jobs._JOB_COMPLETE_STATES + + @property + def has_succeeded(self) -> bool: + """Returns true if this batch prediction job has succeeded.""" + return self.state == gca_types.JobState.JOB_STATE_SUCCEEDED + + @property + def error(self) -> Optional[status_pb2.Status]: + """Returns detailed error info for this Job resource.""" + return self._gca_resource.error + + @property + def output_location(self) -> str: + """Returns the output location of this batch prediction job.""" + return ( + self._gca_resource.output_info.gcs_output_directory + or self._gca_resource.output_info.bigquery_output_table + ) + + @classmethod + def submit( + cls, + source_model: str, + input_dataset: Union[str, List[str]], + *, + output_uri_prefix: Optional[str] = None, + job_display_name: Optional[str] = None, + machine_type: Optional[str] = None, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + starting_replica_count: Optional[int] = None, 
+ max_replica_count: Optional[int] = None, + ) -> "BatchPredictionJob": + """Submits a batch prediction job for a GenAI model. + + Args: + source_model (str): + A GenAI model name or a tuned model name for batch prediction. + Supported formats for model name: "gemini-1.0-pro", + "models/gemini-1.0-pro", and "publishers/google/models/gemini-1.0-pro" + Supported formats for tuned model name: "789" and + "projects/123/locations/456/models/789" + input_dataset (Union[str,List[str]]): + GCS URI(-s) or BigQuery URI to your input data to run batch + prediction on. Example: "gs://path/to/input/data.jsonl" or + "bq://projectId.bqDatasetId.bqTableId" + output_uri_prefix (str): + GCS or BigQuery URI prefix for the output predictions. Example: + "gs://path/to/output/data" or "bq://projectId.bqDatasetId" + If not specified, f"{STAGING_BUCKET}/gen-ai-batch-prediction" will + be used for GCS source and + f"bq://projectId.gen_ai_batch_prediction.predictions_{TIMESTAMP}" + will be used for BigQuery source. + job_display_name (str): + The user-defined name of the BatchPredictionJob. + The name can be up to 128 characters long and can consist + of any UTF-8 characters. + machine_type (str): + The type of machine for running batch prediction job. + accelerator_type (str): + The type of accelerator for running batch prediction job. + accelerator_count (int): + The number of accelerators for running batch prediction job. + starting_replica_count (int): + The starting number of replicas for running batch prediction job. + max_replica_count (int): + The maximum number of replicas for running batch prediction job. + + Returns: + Instantiated BatchPredictionJob. + + Raises: + ValueError: If source_model is not a GenAI model. + Or if input_dataset or output_uri_prefix are not in supported formats. + Or if output_uri_prefix is not specified and staging_bucket is not + set in agentplatform.init(). 
+ """ + # Handle model name + model_name = cls._reconcile_model_name(source_model) + if not cls._is_genai_model(model_name): + raise ValueError(f"Model '{model_name}' is not a Generative AI model.") + + # Handle input URI + gcs_source = None + bigquery_source = None + first_input_uri = ( + input_dataset if isinstance(input_dataset, str) else input_dataset[0] + ) + if first_input_uri.startswith("gs://"): + gcs_source = input_dataset + elif first_input_uri.startswith("bq://"): + if not isinstance(input_dataset, str): + raise ValueError("Multiple BigQuery input datasets are not supported.") + bigquery_source = input_dataset + else: + raise ValueError( + f"Unsupported input URI: {input_dataset}. " + "Supported formats: 'gs://path/to/input/data.jsonl' and " + "'bq://projectId.bqDatasetId.bqTableId'" + ) + + # Handle output URI + gcs_destination_prefix = None + bigquery_destination_prefix = None + if output_uri_prefix: + if output_uri_prefix.startswith("gs://"): + gcs_destination_prefix = output_uri_prefix + elif output_uri_prefix.startswith("bq://"): + # Temporarily handle this in SDK, will remove once b/338423462 is fixed. + bigquery_destination_prefix = cls._complete_bq_uri(output_uri_prefix) + else: + raise ValueError( + f"Unsupported output URI: {output_uri_prefix}. " + "Supported formats: 'gs://path/to/output/data' and " + "'bq://projectId.bqDatasetId'" + ) + else: + if first_input_uri.startswith("gs://"): + if not aiplatform_initializer.global_config.staging_bucket: + raise ValueError( + "Please either specify output_uri_prefix or " + "set staging_bucket in agentplatform.init()." 
+ ) + gcs_destination_prefix = ( + aiplatform_initializer.global_config.staging_bucket.rstrip("/") + + "/gen-ai-batch-prediction" + ) + else: + bigquery_destination_prefix = cls._complete_bq_uri() + + # Reuse aiplatform class to submit the job (override _LOGGER) + logging.getLogger("google.cloud.aiplatform.jobs").disabled = True + try: + aiplatform_job = jobs.BatchPredictionJob.submit( + model_name=model_name, + job_display_name=job_display_name, + gcs_source=gcs_source, + bigquery_source=bigquery_source, + gcs_destination_prefix=gcs_destination_prefix, + bigquery_destination_prefix=bigquery_destination_prefix, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + starting_replica_count=starting_replica_count, + max_replica_count=max_replica_count, + ) + job = cls._empty_constructor() + job._gca_resource = aiplatform_job._gca_resource + + _LOGGER.log_create_complete( + cls, job._gca_resource, "job", module_name="batch_prediction" + ) + _LOGGER.info("View Batch Prediction Job:\n%s" % job._dashboard_uri()) + + return job + finally: + logging.getLogger("google.cloud.aiplatform.jobs").disabled = False + + def refresh(self) -> "BatchPredictionJob": + """Refreshes the batch prediction job from the service.""" + self._sync_gca_resource() + return self + + def cancel(self): + """Cancels this BatchPredictionJob. + + Success of cancellation is not guaranteed. Use `job.refresh()` and + `job.state` to verify if cancellation was successful. + """ + _LOGGER.log_action_start_against_resource("Cancelling", "run", self) + self.api_client.cancel_batch_prediction_job(name=self.resource_name) + + def delete(self): + """Deletes this BatchPredictionJob resource. + + WARNING: This deletion is permanent. 
+ """ + self._delete() + + @classmethod + def list(cls, filter=None) -> List["BatchPredictionJob"]: + """Lists all BatchPredictionJob instances that run with GenAI models.""" + return cls._list( + cls_filter=lambda gca_resource: cls._is_genai_model(gca_resource.model), + filter=filter, + ) + + def _dashboard_uri(self) -> Optional[str]: + """Returns the Google Cloud console URL where job can be viewed.""" + fields = self._parse_resource_name(self.resource_name) + location = fields.pop("location") + project = fields.pop("project") + job = list(fields.values())[0] + return ( + "https://console.cloud.google.com/ai/platform/locations/" + f"{location}/{self._job_type}/{job}?project={project}" + ) + + @classmethod + def _reconcile_model_name(cls, model_name: str) -> str: + """Reconciles model name to a publisher model resource name or a tuned model resource name.""" + if not model_name: + raise ValueError("model_name must not be empty") + + if "/" not in model_name: + # model name (e.g., gemini-1.0-pro) + if model_name.startswith("gemini"): + return "publishers/google/models/" + model_name + else: + raise ValueError( + "Abbreviated model names are only supported for Gemini models. " + "Please provide the full publisher model name." 
+ ) + elif model_name.startswith("models/"): + # publisher model name (e.g., models/gemini-1.0-pro) + return "publishers/google/" + model_name + elif ( + re.match( + r"^publishers/(?P<publisher>[^/]+)/models/(?P<model>[^@]+)@(?P<version>[^@]+)$", + model_name, + ) + or model_name.startswith("publishers/google/models/") + or model_name.startswith("publishers/meta/models/") + or model_name.startswith("publishers/anthropic/models/") + or model_name.startswith("publishers/openai/models/") + or model_name.startswith("publishers/qwen/models/") + or model_name.startswith("publishers/deepseek-ai/models/") + or model_name.startswith("publishers/intfloat/models/") + or re.search(_GEMINI_TUNED_MODEL_PATTERN, model_name) + ): + return model_name + else: + raise ValueError(f"Invalid format for model name: {model_name}.") + + @classmethod + def _is_genai_model(cls, model_name: str) -> bool: + """Validates if a given model_name represents a GenAI model.""" + if re.search(_GEMINI_MODEL_PATTERN, model_name): + # Model is a Gemini model. + return True + + if re.search(_GEMINI_TUNED_MODEL_PATTERN, model_name): + model = models.Model(model_name) + if ( + model.gca_resource.model_source_info.source_type + == gca_types.model.ModelSourceInfo.ModelSourceType.GENIE + ): + # Model is a tuned Gemini model. + return True + + if re.search(_LLAMA_MODEL_PATTERN, model_name): + # Model is a Llama3 model. + return True + + if re.search(_CLAUDE_MODEL_PATTERN, model_name): + # Model is a claude model. + return True + + if re.search(_GPT_MODEL_PATTERN, model_name): + # Model is a GPT model. + return True + + if re.search(_QWEN_MODEL_PATTERN, model_name): + # Model is a Qwen model. + return True + + if re.search(_DEEPSEEK_MODEL_PATTERN, model_name): + # Model is a DeepSeek model. + return True + + if re.search(_E5_MODEL_PATTERN, model_name): + # Model is an E5 model. + return True + + if re.match( + r"^publishers/(?P<publisher>[^/]+)/models/(?P<model>[^@]+)@(?P<version>[^@]+)$", + model_name, + ): + # Model is a self-hosted model. 
+ return True + + return False + + @classmethod + def num_pending_jobs(cls) -> int: + """Returns the number of pending batch prediction jobs. + + The pending jobs are those defined in _JOB_PENDING_STATES from + google/cloud/aiplatform/jobs.py + e.g. JOB_STATE_QUEUED, JOB_STATE_PENDING, JOB_STATE_RUNNING, + JOB_STATE_CANCELLING, JOB_STATE_UPDATING. + It will be used to manage the number of concurrent batch that is limited + according to + https://cloud.google.com/vertex-ai/generative-ai/docs/quotas#concurrent-batch-requests + """ + return len( + cls._list( + cls_filter=lambda gca_resource: cls._is_genai_model(gca_resource.model), + filter=" OR ".join( + f'state="{pending_state.name}"' + for pending_state in jobs._JOB_PENDING_STATES + ), + ) + ) + + @classmethod + def _complete_bq_uri(cls, uri: Optional[str] = None): + """Completes a BigQuery uri to a BigQuery table uri.""" + uri_parts = uri.split(".") if uri else [] + uri_len = len(uri_parts) + if len(uri_parts) > 3: + raise ValueError( + f"Invalid URI: {uri}. 
" + "Supported formats: 'bq://projectId.bqDatasetId.bqTableId'" + ) + + schema_and_project = ( + uri_parts[0] + if uri_len >= 1 + else f"bq://{aiplatform_initializer.global_config.project}" + ) + if not schema_and_project.startswith("bq://"): + raise ValueError("URI must start with 'bq://'") + + dataset = uri_parts[1] if uri_len >= 2 else "gen_ai_batch_prediction" + + table = ( + uri_parts[2] + if uri_len >= 3 + else f"predictions_{aiplatform_utils.timestamped_unique_name()}" + ) + + return f"{schema_and_project}.{dataset}.{table}" diff --git a/setup.py b/setup.py index a1108e7181..b47b2e4869 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,9 @@ packages = [ package for package in setuptools.PEP420PackageFinder.find() - if package.startswith("google") or package.startswith("vertexai") + if package.startswith("google") + or package.startswith("vertexai") + or package.startswith("agentplatform") ] # Add vertex_ray relative packages diff --git a/tests/unit/agentplatform/__init__.py b/tests/unit/agentplatform/__init__.py new file mode 100644 index 0000000000..046333aa26 --- /dev/null +++ b/tests/unit/agentplatform/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# diff --git a/tests/unit/agentplatform/conftest.py b/tests/unit/agentplatform/conftest.py new file mode 100644 index 0000000000..1480b0a7b6 --- /dev/null +++ b/tests/unit/agentplatform/conftest.py @@ -0,0 +1,158 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import shutil +import tempfile +from typing import Any +from unittest import mock + +from google import auth +from google.auth import credentials as auth_credentials +from google.cloud import storage +from google.cloud.aiplatform import base as aiplatform_base +import pytest + + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_BUCKET_NAME = "gs://test-bucket" + + +@pytest.fixture(scope="module") +def google_auth_mock(): + with mock.patch.object(auth, "default") as auth_mock: + auth_mock.return_value = ( + auth_credentials.AnonymousCredentials(), + _TEST_PROJECT, + ) + yield auth_mock + + +@pytest.fixture +def generate_display_name_mock(): + with mock.patch.object( + aiplatform_base.VertexAiResourceNoun, "_generate_display_name" + ) as generate_display_name_mock: + generate_display_name_mock.return_value = "test-display-name" + yield generate_display_name_mock + + +@pytest.fixture +def mock_storage_blob(): + """Mocks the storage Blob API. 
+ + Replaces the Blob factory method by a simpler method that records the + destination_file_uri and, instead of uploading the file to gcs, copying it + to the fake local file system. + """ + + class MockStorageBlob: + """Mocks storage.Blob.""" + + def __init__(self, destination_file_uri: str, client: Any): + del client + self.destination_file_uri = destination_file_uri + + @classmethod + def from_string(cls, destination_file_uri: str, client: Any): + if destination_file_uri.startswith("gs://"): + # Do not copy files to gs:// since it's not a valid path in the fake + # filesystem. + destination_file_uri = destination_file_uri.split("/")[-1] + return cls(destination_file_uri, client) + + @classmethod + def from_uri(cls, destination_file_uri: str, client: Any): + return cls.from_string(destination_file_uri, client) + + def upload_from_filename(self, filename: str): + shutil.copy(filename, self.destination_file_uri) + + def download_to_filename(self, filename: str): + """To be replaced by an implementation of testing needs.""" + raise NotImplementedError + + with mock.patch.object(storage, "Blob", new=MockStorageBlob) as storage_blob: + yield storage_blob + + +@pytest.fixture +def mock_storage_blob_tmp_dir(tmp_path): + """Mocks the storage Blob API. + + Replaces the Blob factory method by a simpler method that records the + destination_file_uri and, instead of uploading the file to gcs, copying it + to a temporary path in the local file system. + """ + + class MockStorageBlob: + """Mocks storage.Blob.""" + + def __init__(self, destination_file_uri: str, client: Any): + del client + self.destination_file_uri = destination_file_uri + + @classmethod + def from_string(cls, destination_file_uri: str, client: Any): + if destination_file_uri.startswith("gs://"): + # Do not copy files to gs:// since it's not a valid path in the fake + # filesystem. 
+ destination_file_uri = os.fspath( + tmp_path / destination_file_uri.split("/")[-1] + ) + return cls(destination_file_uri, client) + + @classmethod + def from_uri(cls, destination_file_uri: str, client: Any): + return cls.from_string(destination_file_uri, client) + + def upload_from_filename(self, filename: str): + shutil.copy(filename, self.destination_file_uri) + + def download_to_filename(self, filename: str): + """To be replaced by an implementation of testing needs.""" + raise NotImplementedError + + with mock.patch.object(storage, "Blob", new=MockStorageBlob) as storage_blob: + yield storage_blob + + +@pytest.fixture +def mock_gcs_upload(): + def fake_upload_to_gcs(local_filename: str, gcs_destination: str): + if gcs_destination.startswith("gs://") or gcs_destination.startswith("gcs/"): + raise ValueError("Please don't use the real gcs path with mock_gcs_upload.") + # instead of upload, just copy the file. + shutil.copyfile(local_filename, gcs_destination) + + with mock.patch( + "google.cloud.aiplatform.aiplatform.utils.gcs_utils.upload_to_gcs", + new=fake_upload_to_gcs, + ) as gcs_upload: + yield gcs_upload + + +@pytest.fixture +def mock_temp_dir(): + with mock.patch.object(tempfile, "TemporaryDirectory") as temp_dir_mock: + yield temp_dir_mock + + +@pytest.fixture +def mock_named_temp_file(): + with mock.patch.object(tempfile, "NamedTemporaryFile") as named_temp_file_mock: + yield named_temp_file_mock diff --git a/tests/unit/vertexai/test_batch_prediction.py b/tests/unit/agentplatform/test_batch_prediction.py similarity index 98% rename from tests/unit/vertexai/test_batch_prediction.py rename to tests/unit/agentplatform/test_batch_prediction.py index 0394683410..14ce86e4d4 100644 --- a/tests/unit/vertexai/test_batch_prediction.py +++ b/tests/unit/agentplatform/test_batch_prediction.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you 
may not use this file except in compliance with the License. @@ -22,7 +20,7 @@ from unittest import mock from google.cloud import aiplatform -import vertexai +import agentplatform from google.cloud.aiplatform import base as aiplatform_base from google.cloud.aiplatform import initializer as aiplatform_initializer from google.cloud.aiplatform.compat.services import ( @@ -35,8 +33,7 @@ job_state as gca_job_state_compat, model as gca_model, ) -from vertexai.preview import batch_prediction -from vertexai.generative_models import GenerativeModel +from agentplatform import batch_prediction _TEST_PROJECT = "test-project" @@ -93,7 +90,6 @@ ) -# TODO(b/339230025) Mock the whole service instead of methods. @pytest.fixture def generate_display_name_mock(): with mock.patch.object( @@ -355,8 +351,8 @@ class TestBatchPredictionJob: def setup_method(self): importlib.reload(aiplatform_initializer) importlib.reload(aiplatform) - importlib.reload(vertexai) - vertexai.init( + importlib.reload(agentplatform) + agentplatform.init( project=_TEST_PROJECT, location=_TEST_LOCATION, ) @@ -564,10 +560,9 @@ def test_submit_batch_prediction_job_with_bq_input( def test_submit_batch_prediction_job_with_gcs_input_without_output_uri_prefix( self, create_batch_prediction_job_mock ): - vertexai.init(staging_bucket=_TEST_BUCKET) - model = GenerativeModel(_TEST_GEMINI_MODEL_NAME) + agentplatform.init(staging_bucket=_TEST_BUCKET) job = batch_prediction.BatchPredictionJob.submit( - source_model=model, + source_model=_TEST_GEMINI_MODEL_NAME, input_dataset=[_TEST_GCS_INPUT_URI, _TEST_GCS_INPUT_URI_2], ) @@ -598,9 +593,8 @@ def test_submit_batch_prediction_job_with_gcs_input_without_output_uri_prefix( def test_submit_batch_prediction_job_with_bq_input_without_output_uri_prefix( self, create_batch_prediction_job_mock ): - model = GenerativeModel(_TEST_GEMINI_MODEL_NAME) job = batch_prediction.BatchPredictionJob.submit( - source_model=model, + source_model=_TEST_GEMINI_MODEL_NAME, 
input_dataset=_TEST_BQ_INPUT_URI, ) @@ -970,7 +964,7 @@ def test_submit_batch_prediction_job_without_output_uri_prefix_and_bucket(self): ValueError, match=( "Please either specify output_uri_prefix or " - "set staging_bucket in vertexai.init()." + "set staging_bucket in agentplatform.init()." ), ): batch_prediction.BatchPredictionJob.submit( diff --git a/tests/unit/vertexai/genai/test_evals.py b/tests/unit/vertexai/genai/test_evals.py index c384cfd40a..1d400584f7 100644 --- a/tests/unit/vertexai/genai/test_evals.py +++ b/tests/unit/vertexai/genai/test_evals.py @@ -21,7 +21,6 @@ import statistics import sys import tempfile -import unittest from unittest import mock import google.auth.credentials @@ -8423,11 +8422,10 @@ def read_file_contents_side_effect(src: str) -> str: ) -class TestEvalsGenerateConversationScenarios(unittest.TestCase): +class TestEvalsGenerateConversationScenarios: """Unit tests for the Evals generate_conversation_scenarios method.""" - def setUp(self): - self.addCleanup(mock.patch.stopall) + def setup_method(self, method): self.mock_client = mock.MagicMock(spec=client.Client) self.mock_client.vertexai = True self.mock_api_client = mock.MagicMock() diff --git a/tests/unit/vertexai/test_vertexai_batch_prediction.py b/tests/unit/vertexai/test_vertexai_batch_prediction.py new file mode 100644 index 0000000000..ba010d9caa --- /dev/null +++ b/tests/unit/vertexai/test_vertexai_batch_prediction.py @@ -0,0 +1,1017 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Unit tests for generative model batch prediction.""" +# pylint: disable=protected-access + +import importlib +import pytest +from unittest import mock + +from google.cloud import aiplatform +import vertexai +from google.cloud.aiplatform import base as aiplatform_base +from google.cloud.aiplatform import initializer as aiplatform_initializer +from google.cloud.aiplatform.compat.services import ( + job_service_client, + model_service_client, +) +from google.cloud.aiplatform.compat.types import ( + batch_prediction_job as gca_batch_prediction_job_compat, + io as gca_io_compat, + job_state as gca_job_state_compat, + model as gca_model, +) +from vertexai.preview import batch_prediction +from vertexai.generative_models import GenerativeModel + + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_BUCKET = "gs://test-bucket" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_DISPLAY_NAME = "test-display-name" + +_TEST_GEMINI_MODEL_NAME = "gemini-1.0-pro" +_TEST_GEMINI_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_GEMINI_MODEL_NAME}" +_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME = "projects/123/locations/us-central1/models/456" +_TEST_PALM_MODEL_NAME = "text-bison" +_TEST_PALM_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_PALM_MODEL_NAME}" +_TEST_LLAMA_MODEL_NAME = "llama3-405b-instruct-maas" +_TEST_LLAMA_MODEL_RESOURCE_NAME = f"publishers/meta/models/{_TEST_LLAMA_MODEL_NAME}" +_TEST_CLAUDE_MODEL_NAME = "claude-3-opus" +_TEST_CLAUDE_MODEL_RESOURCE_NAME = ( + f"publishers/anthropic/models/{_TEST_CLAUDE_MODEL_NAME}" +) +_TEST_GPT_MODEL_NAME = "gpt-oss-120b-maas" +_TEST_GPT_MODEL_RESOURCE_NAME = f"publishers/openai/models/{_TEST_GPT_MODEL_NAME}" +_TEST_QWEN_MODEL_NAME = "qwen3-235b-a22b-instruct-2507-maas" +_TEST_QWEN_MODEL_RESOURCE_NAME = f"publishers/qwen/models/{_TEST_QWEN_MODEL_NAME}" 
+_TEST_DEEPSEEK_MODEL_NAME = "deepseek-r1-0528-maas" +_TEST_DEEPSEEK_MODEL_RESOURCE_NAME = ( + f"publishers/deepseek-ai/models/{_TEST_DEEPSEEK_MODEL_NAME}" +) +_TEST_E5_MODEL_NAME = "multilingual-e5-small-maas" +_TEST_E5_MODEL_RESOURCE_NAME = f"publishers/intfloat/models/{_TEST_E5_MODEL_NAME}" +_TEST_SELF_HOSTED_GEMMA_MODEL_RESOURCE_NAME = ( + "publishers/google/models/gemma@gemma-2b-it" +) + +_TEST_GCS_INPUT_URI = "gs://test-bucket/test-input.jsonl" +_TEST_GCS_INPUT_URI_2 = "gs://test-bucket/test-input-2.jsonl" +_TEST_GCS_OUTPUT_PREFIX = "gs://test-bucket/test-output" +_TEST_BQ_INPUT_URI = "bq://test-project.test-dataset.test-input" +_TEST_BQ_OUTPUT_PREFIX = "bq://test-project.test-dataset.test-output" +_TEST_INVALID_URI = "invalid-uri" + + +_TEST_BATCH_PREDICTION_JOB_ID = "123456789" +_TEST_BATCH_PREDICTION_JOB_NAME = ( + f"{_TEST_PARENT}/batchPredictionJobs/{_TEST_BATCH_PREDICTION_JOB_ID}" +) +_TEST_JOB_STATE_RUNNING = gca_job_state_compat.JobState(3) +_TEST_JOB_STATE_SUCCESS = gca_job_state_compat.JobState(4) + +_TEST_GAPIC_BATCH_PREDICTION_JOB = gca_batch_prediction_job_compat.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_GEMINI_MODEL_RESOURCE_NAME, + state=_TEST_JOB_STATE_RUNNING, +) + + +# TODO(b/339230025) Mock the whole service instead of methods. 
+@pytest.fixture +def generate_display_name_mock(): + with mock.patch.object( + aiplatform_base.VertexAiResourceNoun, "_generate_display_name" + ) as generate_display_name_mock: + generate_display_name_mock.return_value = _TEST_DISPLAY_NAME + yield generate_display_name_mock + + +@pytest.fixture +def complete_bq_uri_mock(): + with mock.patch.object( + batch_prediction.BatchPredictionJob, "_complete_bq_uri" + ) as complete_bq_uri_mock: + complete_bq_uri_mock.return_value = _TEST_BQ_OUTPUT_PREFIX + yield complete_bq_uri_mock + + +@pytest.fixture +def get_batch_prediction_job_with_bq_output_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_job_mock: + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_GEMINI_MODEL_RESOURCE_NAME, + state=_TEST_JOB_STATE_SUCCESS, + output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo( + bigquery_output_table=_TEST_BQ_OUTPUT_PREFIX + ), + ) + yield get_job_mock + + +@pytest.fixture +def get_batch_prediction_job_with_gcs_output_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_job_mock: + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_GEMINI_MODEL_RESOURCE_NAME, + state=_TEST_JOB_STATE_SUCCESS, + output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo( + gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX + ), + ) + yield get_job_mock + + +@pytest.fixture +def get_batch_prediction_job_with_llama_model_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_job_mock: + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + 
display_name=_TEST_DISPLAY_NAME, + model=_TEST_LLAMA_MODEL_RESOURCE_NAME, + state=_TEST_JOB_STATE_SUCCESS, + output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo( + gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX + ), + ) + yield get_job_mock + + +@pytest.fixture +def get_batch_prediction_job_with_claude_model_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_job_mock: + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_CLAUDE_MODEL_RESOURCE_NAME, + state=_TEST_JOB_STATE_SUCCESS, + output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo( + gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX + ), + ) + yield get_job_mock + + +@pytest.fixture +def get_batch_prediction_job_with_gpt_model_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_job_mock: + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_GPT_MODEL_RESOURCE_NAME, + state=_TEST_JOB_STATE_SUCCESS, + output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo( + gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX + ), + ) + yield get_job_mock + + +@pytest.fixture +def get_batch_prediction_job_with_qwen_model_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_job_mock: + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_QWEN_MODEL_RESOURCE_NAME, + state=_TEST_JOB_STATE_SUCCESS, + output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo( + gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX + ), + ) + yield get_job_mock + + +@pytest.fixture 
+def get_batch_prediction_job_with_deepseek_model_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_job_mock: + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_DEEPSEEK_MODEL_RESOURCE_NAME, + state=_TEST_JOB_STATE_SUCCESS, + output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo( + gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX + ), + ) + yield get_job_mock + + +@pytest.fixture +def get_batch_prediction_job_with_e5_model_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_job_mock: + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_E5_MODEL_RESOURCE_NAME, + state=_TEST_JOB_STATE_SUCCESS, + output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo( + gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX + ), + ) + yield get_job_mock + + +@pytest.fixture +def get_batch_prediction_job_with_tuned_gemini_model_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_job_mock: + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME, + state=_TEST_JOB_STATE_SUCCESS, + output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo( + gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX + ), + ) + yield get_job_mock + + +@pytest.fixture +def get_gemini_model_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + name=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME, + 
model_source_info=gca_model.ModelSourceInfo( + source_type=gca_model.ModelSourceInfo.ModelSourceType.GENIE + ), + ) + yield get_model_mock + + +@pytest.fixture +def get_non_gemini_model_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + name=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME, + ) + yield get_model_mock + + +@pytest.fixture +def get_batch_prediction_job_invalid_model_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "get_batch_prediction_job" + ) as get_job_mock: + get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_PALM_MODEL_RESOURCE_NAME, + state=_TEST_JOB_STATE_SUCCESS, + ) + yield get_job_mock + + +@pytest.fixture +def create_batch_prediction_job_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "create_batch_prediction_job" + ) as create_job_mock: + create_job_mock.return_value = _TEST_GAPIC_BATCH_PREDICTION_JOB + yield create_job_mock + + +@pytest.fixture +def cancel_batch_prediction_job_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "cancel_batch_prediction_job" + ) as cancel_job_mock: + yield cancel_job_mock + + +@pytest.fixture +def delete_batch_prediction_job_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "delete_batch_prediction_job" + ) as delete_job_mock: + yield delete_job_mock + + +@pytest.fixture +def list_batch_prediction_jobs_mock(): + with mock.patch.object( + job_service_client.JobServiceClient, "list_batch_prediction_jobs" + ) as list_jobs_mock: + list_jobs_mock.return_value = [ + _TEST_GAPIC_BATCH_PREDICTION_JOB, + gca_batch_prediction_job_compat.BatchPredictionJob( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + model=_TEST_PALM_MODEL_RESOURCE_NAME, + state=_TEST_JOB_STATE_SUCCESS, + ), + 
] + yield list_jobs_mock + + +@pytest.mark.usefixtures( + "google_auth_mock", "generate_display_name_mock", "complete_bq_uri_mock" +) +class TestVertexAIBatchPredictionJob: + """Unit tests for BatchPredictionJob.""" + + def setup_method(self): + importlib.reload(aiplatform_initializer) + importlib.reload(aiplatform) + importlib.reload(vertexai) + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + def teardown_method(self): + aiplatform_initializer.global_pool.shutdown(wait=True) + + def test_init_batch_prediction_job( + self, get_batch_prediction_job_with_gcs_output_mock + ): + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) + + get_batch_prediction_job_with_gcs_output_mock.assert_called_once_with( + name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY + ) + + def test_init_batch_prediction_job_with_llama_model( + self, + get_batch_prediction_job_with_llama_model_mock, + ): + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) + + get_batch_prediction_job_with_llama_model_mock.assert_called_once_with( + name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY + ) + + def test_init_batch_prediction_job_with_claude_model( + self, + get_batch_prediction_job_with_claude_model_mock, + ): + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) + + get_batch_prediction_job_with_claude_model_mock.assert_called_once_with( + name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY + ) + + def test_init_batch_prediction_job_with_gpt_model( + self, + get_batch_prediction_job_with_gpt_model_mock, + ): + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) + + get_batch_prediction_job_with_gpt_model_mock.assert_called_once_with( + name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY + ) + + def test_init_batch_prediction_job_with_qwen_model( + self, + get_batch_prediction_job_with_qwen_model_mock, + ): + 
batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) + + get_batch_prediction_job_with_qwen_model_mock.assert_called_once_with( + name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY + ) + + def test_init_batch_prediction_job_with_deepseek_model( + self, + get_batch_prediction_job_with_deepseek_model_mock, + ): + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) + + get_batch_prediction_job_with_deepseek_model_mock.assert_called_once_with( + name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY + ) + + def test_init_batch_prediction_job_with_e5_model( + self, + get_batch_prediction_job_with_e5_model_mock, + ): + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) + + get_batch_prediction_job_with_e5_model_mock.assert_called_once_with( + name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY + ) + + def test_init_batch_prediction_job_with_tuned_gemini_model( + self, + get_batch_prediction_job_with_tuned_gemini_model_mock, + get_gemini_model_mock, + ): + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) + + get_batch_prediction_job_with_tuned_gemini_model_mock.assert_called_once_with( + name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY + ) + get_gemini_model_mock.assert_called_once_with( + name=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME, + retry=aiplatform_base._DEFAULT_RETRY, + ) + + @pytest.mark.usefixtures("get_batch_prediction_job_invalid_model_mock") + def test_init_batch_prediction_job_invalid_model(self): + with pytest.raises( + ValueError, + match=( + f"BatchPredictionJob '{_TEST_BATCH_PREDICTION_JOB_ID}' " + f"runs with the model '{_TEST_PALM_MODEL_RESOURCE_NAME}', " + "which is not a GenAI model." 
+ ), + ): + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) + + @pytest.mark.usefixtures( + "get_batch_prediction_job_with_tuned_gemini_model_mock", + "get_non_gemini_model_mock", + ) + def test_init_batch_prediction_job_with_invalid_tuned_model( + self, + ): + with pytest.raises( + ValueError, + match=( + f"BatchPredictionJob '{_TEST_BATCH_PREDICTION_JOB_ID}' " + f"runs with the model '{_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME}', " + "which is not a GenAI model." + ), + ): + batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) + + @pytest.mark.usefixtures("get_batch_prediction_job_with_gcs_output_mock") + def test_submit_batch_prediction_job_with_gcs_input( + self, create_batch_prediction_job_mock + ): + job = batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_GEMINI_MODEL_NAME, + input_dataset=_TEST_GCS_INPUT_URI, + output_uri_prefix=_TEST_GCS_OUTPUT_PREFIX, + ) + + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB + assert job.state == _TEST_JOB_STATE_RUNNING + assert not job.has_ended + assert not job.has_succeeded + + job.refresh() + assert job.state == _TEST_JOB_STATE_SUCCESS + assert job.has_ended + assert job.has_succeeded + assert job.output_location == _TEST_GCS_OUTPUT_PREFIX + + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( + display_name=_TEST_DISPLAY_NAME, + model=_TEST_GEMINI_MODEL_RESOURCE_NAME, + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io_compat.GcsSource(uris=[_TEST_GCS_INPUT_URI]), + ), + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io_compat.GcsDestination( + output_uri_prefix=_TEST_GCS_OUTPUT_PREFIX + ), + predictions_format="jsonl", + ), + ) + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + timeout=None, + ) + + 
@pytest.mark.usefixtures("get_batch_prediction_job_with_bq_output_mock") + def test_submit_batch_prediction_job_with_bq_input( + self, create_batch_prediction_job_mock + ): + job = batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_GEMINI_MODEL_NAME, + input_dataset=_TEST_BQ_INPUT_URI, + output_uri_prefix=_TEST_BQ_OUTPUT_PREFIX, + ) + + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB + assert job.state == _TEST_JOB_STATE_RUNNING + assert not job.has_ended + assert not job.has_succeeded + + job.refresh() + assert job.state == _TEST_JOB_STATE_SUCCESS + assert job.has_ended + assert job.has_succeeded + assert job.output_location == _TEST_BQ_OUTPUT_PREFIX + + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( + display_name=_TEST_DISPLAY_NAME, + model=_TEST_GEMINI_MODEL_RESOURCE_NAME, + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( + instances_format="bigquery", + bigquery_source=gca_io_compat.BigQuerySource( + input_uri=_TEST_BQ_INPUT_URI + ), + ), + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( + bigquery_destination=gca_io_compat.BigQueryDestination( + output_uri=_TEST_BQ_OUTPUT_PREFIX + ), + predictions_format="bigquery", + ), + ) + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + timeout=None, + ) + + def test_submit_batch_prediction_job_with_gcs_input_without_output_uri_prefix( + self, create_batch_prediction_job_mock + ): + vertexai.init(staging_bucket=_TEST_BUCKET) + model = GenerativeModel(_TEST_GEMINI_MODEL_NAME) + job = batch_prediction.BatchPredictionJob.submit( + source_model=model, + input_dataset=[_TEST_GCS_INPUT_URI, _TEST_GCS_INPUT_URI_2], + ) + + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB + + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( + display_name=_TEST_DISPLAY_NAME, 
+ model=_TEST_GEMINI_MODEL_RESOURCE_NAME, + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( + instances_format="jsonl", + gcs_source=gca_io_compat.GcsSource( + uris=[_TEST_GCS_INPUT_URI, _TEST_GCS_INPUT_URI_2] + ), + ), + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( + gcs_destination=gca_io_compat.GcsDestination( + output_uri_prefix=f"{_TEST_BUCKET}/gen-ai-batch-prediction" + ), + predictions_format="jsonl", + ), + ) + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + timeout=None, + ) + + def test_submit_batch_prediction_job_with_bq_input_without_output_uri_prefix( + self, create_batch_prediction_job_mock + ): + model = GenerativeModel(_TEST_GEMINI_MODEL_NAME) + job = batch_prediction.BatchPredictionJob.submit( + source_model=model, + input_dataset=_TEST_BQ_INPUT_URI, + ) + + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB + + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( + display_name=_TEST_DISPLAY_NAME, + model=_TEST_GEMINI_MODEL_RESOURCE_NAME, + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( + instances_format="bigquery", + bigquery_source=gca_io_compat.BigQuerySource( + input_uri=_TEST_BQ_INPUT_URI + ), + ), + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( + bigquery_destination=gca_io_compat.BigQueryDestination( + output_uri=_TEST_BQ_OUTPUT_PREFIX + ), + predictions_format="bigquery", + ), + ) + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + timeout=None, + ) + + def test_submit_batch_prediction_job_with_llama_model( + self, + create_batch_prediction_job_mock, + ): + job = batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_LLAMA_MODEL_RESOURCE_NAME, + 
input_dataset=_TEST_BQ_INPUT_URI, + ) + + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB + + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( + display_name=_TEST_DISPLAY_NAME, + model=_TEST_LLAMA_MODEL_RESOURCE_NAME, + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( + instances_format="bigquery", + bigquery_source=gca_io_compat.BigQuerySource( + input_uri=_TEST_BQ_INPUT_URI + ), + ), + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( + bigquery_destination=gca_io_compat.BigQueryDestination( + output_uri=_TEST_BQ_OUTPUT_PREFIX + ), + predictions_format="bigquery", + ), + ) + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + timeout=None, + ) + + def test_submit_batch_prediction_job_with_claude_model( + self, + create_batch_prediction_job_mock, + ): + job = batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_CLAUDE_MODEL_RESOURCE_NAME, + input_dataset=_TEST_BQ_INPUT_URI, + ) + + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB + + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( + display_name=_TEST_DISPLAY_NAME, + model=_TEST_CLAUDE_MODEL_RESOURCE_NAME, + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( + instances_format="bigquery", + bigquery_source=gca_io_compat.BigQuerySource( + input_uri=_TEST_BQ_INPUT_URI + ), + ), + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( + bigquery_destination=gca_io_compat.BigQueryDestination( + output_uri=_TEST_BQ_OUTPUT_PREFIX + ), + predictions_format="bigquery", + ), + ) + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + timeout=None, + ) + + def test_submit_batch_prediction_job_with_gpt_model( + self, + 
create_batch_prediction_job_mock, + ): + job = batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_GPT_MODEL_RESOURCE_NAME, + input_dataset=_TEST_BQ_INPUT_URI, + ) + + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB + + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( + display_name=_TEST_DISPLAY_NAME, + model=_TEST_GPT_MODEL_RESOURCE_NAME, + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( + instances_format="bigquery", + bigquery_source=gca_io_compat.BigQuerySource( + input_uri=_TEST_BQ_INPUT_URI + ), + ), + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( + bigquery_destination=gca_io_compat.BigQueryDestination( + output_uri=_TEST_BQ_OUTPUT_PREFIX + ), + predictions_format="bigquery", + ), + ) + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + timeout=None, + ) + + def test_submit_batch_prediction_job_with_qwen_model( + self, + create_batch_prediction_job_mock, + ): + job = batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_QWEN_MODEL_RESOURCE_NAME, + input_dataset=_TEST_BQ_INPUT_URI, + ) + + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB + + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( + display_name=_TEST_DISPLAY_NAME, + model=_TEST_QWEN_MODEL_RESOURCE_NAME, + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( + instances_format="bigquery", + bigquery_source=gca_io_compat.BigQuerySource( + input_uri=_TEST_BQ_INPUT_URI + ), + ), + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( + bigquery_destination=gca_io_compat.BigQueryDestination( + output_uri=_TEST_BQ_OUTPUT_PREFIX + ), + predictions_format="bigquery", + ), + ) + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + 
batch_prediction_job=expected_gapic_batch_prediction_job, + timeout=None, + ) + + def test_submit_batch_prediction_job_with_deepseek_model( + self, + create_batch_prediction_job_mock, + ): + job = batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_DEEPSEEK_MODEL_RESOURCE_NAME, + input_dataset=_TEST_BQ_INPUT_URI, + ) + + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB + + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( + display_name=_TEST_DISPLAY_NAME, + model=_TEST_DEEPSEEK_MODEL_RESOURCE_NAME, + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( + instances_format="bigquery", + bigquery_source=gca_io_compat.BigQuerySource( + input_uri=_TEST_BQ_INPUT_URI + ), + ), + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( + bigquery_destination=gca_io_compat.BigQueryDestination( + output_uri=_TEST_BQ_OUTPUT_PREFIX + ), + predictions_format="bigquery", + ), + ) + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + timeout=None, + ) + + def test_submit_batch_prediction_job_with_e5_model( + self, + create_batch_prediction_job_mock, + ): + job = batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_E5_MODEL_RESOURCE_NAME, + input_dataset=_TEST_BQ_INPUT_URI, + ) + + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB + + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( + display_name=_TEST_DISPLAY_NAME, + model=_TEST_E5_MODEL_RESOURCE_NAME, + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( + instances_format="bigquery", + bigquery_source=gca_io_compat.BigQuerySource( + input_uri=_TEST_BQ_INPUT_URI + ), + ), + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( + bigquery_destination=gca_io_compat.BigQueryDestination( + output_uri=_TEST_BQ_OUTPUT_PREFIX 
+ ), + predictions_format="bigquery", + ), + ) + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + timeout=None, + ) + + @pytest.mark.usefixtures("create_batch_prediction_job_mock") + def test_submit_batch_prediction_job_with_tuned_model( + self, + get_gemini_model_mock, + ): + job = batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME, + input_dataset=_TEST_BQ_INPUT_URI, + ) + + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB + get_gemini_model_mock.assert_called_once_with( + name=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME, + retry=aiplatform_base._DEFAULT_RETRY, + ) + + def test_submit_batch_prediction_job_with_self_hosted_gemma_model( + self, + create_batch_prediction_job_mock, + ): + job = batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_SELF_HOSTED_GEMMA_MODEL_RESOURCE_NAME, + input_dataset=_TEST_BQ_INPUT_URI, + ) + + assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB + + expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob( + display_name=_TEST_DISPLAY_NAME, + model=_TEST_SELF_HOSTED_GEMMA_MODEL_RESOURCE_NAME, + input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig( + instances_format="bigquery", + bigquery_source=gca_io_compat.BigQuerySource( + input_uri=_TEST_BQ_INPUT_URI + ), + ), + output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig( + bigquery_destination=gca_io_compat.BigQueryDestination( + output_uri=_TEST_BQ_OUTPUT_PREFIX + ), + predictions_format="bigquery", + ), + ) + create_batch_prediction_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + batch_prediction_job=expected_gapic_batch_prediction_job, + timeout=None, + ) + + def test_submit_batch_prediction_job_with_invalid_source_model(self): + with pytest.raises( + ValueError, + match=( + "Abbreviated model names are only supported for Gemini 
models. " + "Please provide the full publisher model name." + ), + ): + batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_PALM_MODEL_NAME, + input_dataset=_TEST_GCS_INPUT_URI, + ) + + def test_submit_batch_prediction_job_with_invalid_abbreviated_model_name(self): + with pytest.raises( + ValueError, + match=( + "Abbreviated model names are only supported for Gemini models. " + "Please provide the full publisher model name." + ), + ): + batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_LLAMA_MODEL_NAME, + input_dataset=_TEST_GCS_INPUT_URI, + ) + + @pytest.mark.usefixtures("get_non_gemini_model_mock") + def test_submit_batch_prediction_job_with_non_gemini_tuned_model(self): + with pytest.raises( + ValueError, + match=( + f"Model '{_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME}' " + "is not a Generative AI model." + ), + ): + batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_TUNED_GEMINI_MODEL_RESOURCE_NAME, + input_dataset=_TEST_GCS_INPUT_URI, + ) + + def test_submit_batch_prediction_job_with_invalid_model_name(self): + invalid_model_name = "invalid/model/name" + with pytest.raises( + ValueError, + match=(f"Invalid format for model name: {invalid_model_name}."), + ): + batch_prediction.BatchPredictionJob.submit( + source_model=invalid_model_name, + input_dataset=_TEST_GCS_INPUT_URI, + ) + + def test_submit_batch_prediction_job_with_invalid_input_dataset(self): + with pytest.raises( + ValueError, + match=( + f"Unsupported input URI: {_TEST_INVALID_URI}. 
" + "Supported formats: 'gs://path/to/input/data.jsonl' and " + "'bq://projectId.bqDatasetId.bqTableId'" + ), + ): + batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_GEMINI_MODEL_NAME, + input_dataset=_TEST_INVALID_URI, + ) + + invalid_bq_uris = ["bq://projectId.dataset1", "bq://projectId.dataset2"] + with pytest.raises( + ValueError, + match=("Multiple BigQuery input datasets are not supported."), + ): + batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_GEMINI_MODEL_NAME, + input_dataset=invalid_bq_uris, + ) + + def test_submit_batch_prediction_job_with_invalid_output_uri_prefix(self): + with pytest.raises( + ValueError, + match=( + f"Unsupported output URI: {_TEST_INVALID_URI}. " + "Supported formats: 'gs://path/to/output/data' and " + "'bq://projectId.bqDatasetId'" + ), + ): + batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_GEMINI_MODEL_NAME, + input_dataset=_TEST_GCS_INPUT_URI, + output_uri_prefix=_TEST_INVALID_URI, + ) + + def test_submit_batch_prediction_job_without_output_uri_prefix_and_bucket(self): + with pytest.raises( + ValueError, + match=( + "Please either specify output_uri_prefix or " + "set staging_bucket in vertexai.init()." 
+ ), + ): + batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_GEMINI_MODEL_NAME, + input_dataset=_TEST_GCS_INPUT_URI, + ) + + @pytest.mark.usefixtures("create_batch_prediction_job_mock") + def test_cancel_batch_prediction_job(self, cancel_batch_prediction_job_mock): + job = batch_prediction.BatchPredictionJob.submit( + source_model=_TEST_GEMINI_MODEL_NAME, + input_dataset=_TEST_GCS_INPUT_URI, + output_uri_prefix=_TEST_GCS_OUTPUT_PREFIX, + ) + job.cancel() + + cancel_batch_prediction_job_mock.assert_called_once_with( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + ) + + @pytest.mark.usefixtures("get_batch_prediction_job_with_gcs_output_mock") + def test_delete_batch_prediction_job(self, delete_batch_prediction_job_mock): + job = batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) + job.delete() + + delete_batch_prediction_job_mock.assert_called_once_with( + name=_TEST_BATCH_PREDICTION_JOB_NAME, + ) + + def tes_list_batch_prediction_jobs(self, list_batch_prediction_jobs_mock): + jobs = batch_prediction.BatchPredictionJob.list() + + assert len(jobs) == 1 + assert jobs[0].gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB + + list_batch_prediction_jobs_mock.assert_called_once_with( + request={"parent": _TEST_PARENT} + ) + + def test_num_pending_jobs(self, list_batch_prediction_jobs_mock): + num_pending_jobs = batch_prediction.BatchPredictionJob.num_pending_jobs() + + assert num_pending_jobs == 1 + list_batch_prediction_jobs_mock.assert_called_once()