From 258c7a0857dbe7e774f3448e0e2333421d3c6bda Mon Sep 17 00:00:00 2001 From: Ayush Agrawal Date: Tue, 28 Oct 2025 17:23:36 -0700 Subject: [PATCH] chore: update tests to be compatible with Python 3.14, set 3.14 testing workflow as required PiperOrigin-RevId: 825273830 --- .../utils/resource_manager_utils.py | 14 +- setup.py | 3 +- testing/constraints-3.14.txt | 6 +- tests/unit/aiplatform/test_initializer.py | 7 +- tests/unit/aiplatform/test_prediction.py | 5 +- tests/unit/aiplatform/test_training_jobs.py | 166 ++++++++++++------ tests/unit/aiplatform/test_training_utils.py | 14 +- tests/unit/vertex_rag/test_rag_data.py | 11 +- .../unit/vertex_rag/test_rag_data_preview.py | 15 +- vertexai/preview/rag/utils/_gapic_utils.py | 72 ++++++-- vertexai/rag/utils/_gapic_utils.py | 41 +++-- 11 files changed, 248 insertions(+), 106 deletions(-) diff --git a/google/cloud/aiplatform/utils/resource_manager_utils.py b/google/cloud/aiplatform/utils/resource_manager_utils.py index e6cbc0988b..332af1e755 100644 --- a/google/cloud/aiplatform/utils/resource_manager_utils.py +++ b/google/cloud/aiplatform/utils/resource_manager_utils.py @@ -41,7 +41,12 @@ def get_project_id( """ - credentials = credentials or initializer.global_config.credentials + if credentials is None: + credentials = initializer.global_config._credentials + if credentials is None: + import google.auth + from google.cloud.aiplatform.constants import base as constants + credentials, _ = google.auth.default(scopes=constants.DEFAULT_AUTHED_SCOPES) projects_client = resourcemanager.ProjectsClient(credentials=credentials) @@ -67,7 +72,12 @@ def get_project_number( """ - credentials = credentials or initializer.global_config.credentials + if credentials is None: + credentials = initializer.global_config._credentials + if credentials is None: + import google.auth + from google.cloud.aiplatform.constants import base as constants + credentials, _ = google.auth.default(scopes=constants.DEFAULT_AUTHED_SCOPES) projects_client = 
resourcemanager.ProjectsClient(credentials=credentials) diff --git a/setup.py b/setup.py index 6860a4c993..b5ef39736a 100644 --- a/setup.py +++ b/setup.py @@ -279,7 +279,8 @@ # Lazy import requires > 2.12.0 "tensorflow == 2.14.1; python_version<='3.11'", "tensorflow == 2.19.0; python_version>'3.11' and python_version<'3.13'", - "protobuf <= 5.29.4", + "protobuf >= 5.29.4; python_version>='3.14'", + "protobuf <= 5.29.4; python_version<'3.14'", # TODO(jayceeli) torch 2.1.0 has conflict with pyfakefs, will check if # future versions fix this issue "torch >= 2.0.0, < 2.1.0; python_version<='3.11'", diff --git a/testing/constraints-3.14.txt b/testing/constraints-3.14.txt index 1ca6c3a54d..d3d2ea38b7 100644 --- a/testing/constraints-3.14.txt +++ b/testing/constraints-3.14.txt @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- # This constraints file is required for unit tests. # List all library dependencies and extras in this file. -google-api-core==2.21.0 # Tests google-api-core with rest async support +google-api-core==2.27.0 google-auth==2.47.0 # Tests google-auth with rest async support proto-plus mock==4.0.2 -google-cloud-storage==2.10.0 # Increased for kfp 2.0 compatibility +google-cloud-storage==3.10.0 # Updated for Python 3.14 compatibility packaging==24.1 # Increased to unbreak canonicalize_version error (b/377774673) pytest-xdist==3.3.1 # Pinned to unbreak unit tests ray==2.5.0 # Pinned until 2.9.3 is verified for Ray tests @@ -13,4 +13,4 @@ ipython==8.22.2 # Pinned to unbreak TypeAliasType import error google-adk==0.0.2 google-genai>=1.10.0 google-vizier==0.1.21 -pyarrow>=18.0.0 \ No newline at end of file +pyarrow>=22.0.0 \ No newline at end of file diff --git a/tests/unit/aiplatform/test_initializer.py b/tests/unit/aiplatform/test_initializer.py index 9fdb3eed6a..c11ae16799 100644 --- a/tests/unit/aiplatform/test_initializer.py +++ b/tests/unit/aiplatform/test_initializer.py @@ -65,10 +65,15 @@ def test_init_project_sets_project(self): assert 
initializer.global_config.project == _TEST_PROJECT def test_not_init_project_gets_default_project(self, monkeypatch): - def mock_auth_default(scopes=None): + def mock_auth_default(scopes=None, **kwargs): return None, _TEST_PROJECT monkeypatch.setattr(google.auth, "default", mock_auth_default) + monkeypatch.setattr( + resource_manager_utils, + "get_project_id", + lambda **kwargs: _TEST_PROJECT, + ) assert initializer.global_config.project == _TEST_PROJECT def test_infer_project_id(self): diff --git a/tests/unit/aiplatform/test_prediction.py b/tests/unit/aiplatform/test_prediction.py index a1b49f9862..2b380f95fb 100644 --- a/tests/unit/aiplatform/test_prediction.py +++ b/tests/unit/aiplatform/test_prediction.py @@ -3287,7 +3287,10 @@ def test_health(self, model_server_env_mock, importlib_import_module_mock_twice) assert response.status_code == 200 - def test_predict(self, model_server_env_mock, importlib_import_module_mock_twice): + @pytest.mark.asyncio + async def test_predict( + self, model_server_env_mock, importlib_import_module_mock_twice + ): model_server = CprModelServer() client = TestClient(model_server.app) diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index 4ba40aac65..93a993f68a 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -15,7 +15,6 @@ # limitations under the License. 
# -from distutils import core import copy import os import functools @@ -635,18 +634,18 @@ def test_get_python_executable_returns_python_executable(self): ) @pytest.mark.usefixtures("google_auth_mock") class TestTrainingScriptPythonPackager: - def setup_method(self): - importlib.reload(initializer) - importlib.reload(aiplatform) - with open(_TEST_LOCAL_SCRIPT_FILE_PATH, "w") as fp: - fp.write(_TEST_PYTHON_SOURCE) - - def teardown_method(self): - pathlib.Path(_TEST_LOCAL_SCRIPT_FILE_PATH).unlink() - python_package_file = f"{source_utils._TrainingScriptPythonPackager._ROOT_MODULE}-{source_utils._TrainingScriptPythonPackager._SETUP_PY_VERSION}.tar.gz" - if pathlib.Path(python_package_file).is_file(): - pathlib.Path(python_package_file).unlink() - subprocess.check_output( + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + with open(_TEST_LOCAL_SCRIPT_FILE_PATH, "w") as fp: + fp.write(_TEST_PYTHON_SOURCE) + + def teardown_method(self): + pathlib.Path(_TEST_LOCAL_SCRIPT_FILE_PATH).unlink() + python_package_file = f"{source_utils._TrainingScriptPythonPackager._ROOT_MODULE}-{source_utils._TrainingScriptPythonPackager._SETUP_PY_VERSION}.tar.gz" + if pathlib.Path(python_package_file).is_file(): + pathlib.Path(python_package_file).unlink() + subprocess.check_output( [ "pip3", "uninstall", @@ -655,57 +654,118 @@ def teardown_method(self): ] ) - def test_packager_creates_and_copies_python_package(self): - tsp = source_utils._TrainingScriptPythonPackager(_TEST_LOCAL_SCRIPT_FILE_PATH) - tsp.package_and_copy(copy_method=local_copy_method) - assert pathlib.Path( + def test_packager_creates_and_copies_python_package(self): + tsp = source_utils._TrainingScriptPythonPackager(_TEST_LOCAL_SCRIPT_FILE_PATH) + def create_valid_tarball(*args, **kwargs): + cwd = kwargs.get("cwd") + if cwd: + dist_dir = pathlib.Path(cwd) / "dist" + dist_dir.mkdir(exist_ok=True) + filename = 
f"{source_utils._TrainingScriptPythonPackager._ROOT_MODULE}-{source_utils._TrainingScriptPythonPackager._SETUP_PY_VERSION}.tar.gz" + tarball_path = dist_dir / filename + setup_py_path = pathlib.Path(cwd) / "setup.py" + arcname = f"{source_utils._TrainingScriptPythonPackager._ROOT_MODULE}-{source_utils._TrainingScriptPythonPackager._SETUP_PY_VERSION}/setup.py" + import tarfile + with tarfile.open(tarball_path, "w:gz") as tar: + tar.add(setup_py_path, arcname=arcname) + mock_subprocess = mock.Mock() + mock_subprocess.communicate.return_value = (b"", b"") + mock_subprocess.returncode = 0 + return mock_subprocess + with mock.patch("subprocess.Popen", side_effect=create_valid_tarball): + tsp.package_and_copy(copy_method=local_copy_method) + assert pathlib.Path( f"{tsp._ROOT_MODULE}-{tsp._SETUP_PY_VERSION}.tar.gz" ).is_file() - def test_requirements_are_in_package(self): - tsp = source_utils._TrainingScriptPythonPackager( + def test_requirements_are_in_package(self): + tsp = source_utils._TrainingScriptPythonPackager( _TEST_LOCAL_SCRIPT_FILE_PATH, requirements=_TEST_REQUIREMENTS ) - source_dist_path = tsp.package_and_copy(copy_method=local_copy_method) - with tarfile.open(source_dist_path) as tf: - with tempfile.TemporaryDirectory() as tmpdirname: - setup_py_path = f"{source_utils._TrainingScriptPythonPackager._ROOT_MODULE}-{source_utils._TrainingScriptPythonPackager._SETUP_PY_VERSION}/setup.py" - tf.extract(setup_py_path, path=tmpdirname) - setup_py = core.run_setup( - pathlib.Path(tmpdirname, setup_py_path), stop_after="init" - ) - assert _TEST_REQUIREMENTS == setup_py.install_requires - - def test_packaging_fails_whith_RuntimeError(self): - with patch("subprocess.Popen") as mock_popen: - mock_subprocess = mock.Mock() - mock_subprocess.communicate.return_value = (b"", b"") - mock_subprocess.returncode = 1 - mock_popen.return_value = mock_subprocess - tsp = source_utils._TrainingScriptPythonPackager( - _TEST_LOCAL_SCRIPT_FILE_PATH - ) - with pytest.raises(RuntimeError): 
- tsp.package_and_copy(copy_method=local_copy_method) - - def test_package_and_copy_to_gcs_copies_to_gcs(self, mock_client_bucket): - mock_client_bucket, mock_blob = mock_client_bucket - - tsp = source_utils._TrainingScriptPythonPackager(_TEST_LOCAL_SCRIPT_FILE_PATH) + def create_valid_tarball(*args, **kwargs): + cwd = kwargs.get("cwd") + if cwd: + dist_dir = pathlib.Path(cwd) / "dist" + dist_dir.mkdir(exist_ok=True) + filename = f"{source_utils._TrainingScriptPythonPackager._ROOT_MODULE}-{source_utils._TrainingScriptPythonPackager._SETUP_PY_VERSION}.tar.gz" + tarball_path = dist_dir / filename + setup_py_path = pathlib.Path(cwd) / "setup.py" + arcname = f"{source_utils._TrainingScriptPythonPackager._ROOT_MODULE}-{source_utils._TrainingScriptPythonPackager._SETUP_PY_VERSION}/setup.py" + import tarfile + + with tarfile.open(tarball_path, "w:gz") as tar: + tar.add(setup_py_path, arcname=arcname) + mock_subprocess = mock.Mock() + mock_subprocess.communicate.return_value = (b"", b"") + mock_subprocess.returncode = 0 + return mock_subprocess + + with mock.patch("subprocess.Popen", side_effect=create_valid_tarball): + source_dist_path = tsp.package_and_copy(copy_method=local_copy_method) + with tarfile.open(source_dist_path) as tf: + with tempfile.TemporaryDirectory() as tmpdirname: + setup_py_path = f"{source_utils._TrainingScriptPythonPackager._ROOT_MODULE}-{source_utils._TrainingScriptPythonPackager._SETUP_PY_VERSION}/setup.py" + tf.extract(setup_py_path, path=tmpdirname) + with open(pathlib.Path(tmpdirname, setup_py_path), "r") as f: + setup_py_content = f.read() + + import re + + match = re.search(r"install_requires=\((.*?)\)", setup_py_content) + assert match is not None + requirements_str = match.group(1) + expected_requirements_str = ",".join( + f'"{r}"' for r in _TEST_REQUIREMENTS + ) + assert requirements_str == expected_requirements_str + + def test_packaging_fails_whith_RuntimeError(self): + with patch("subprocess.Popen") as mock_popen: + mock_subprocess = 
mock.Mock() + mock_subprocess.communicate.return_value = (b"", b"") + mock_subprocess.returncode = 1 + mock_popen.return_value = mock_subprocess + tsp = source_utils._TrainingScriptPythonPackager( + _TEST_LOCAL_SCRIPT_FILE_PATH + ) + with pytest.raises(RuntimeError): + tsp.package_and_copy(copy_method=local_copy_method) - gcs_path = tsp.package_and_copy_to_gcs( - gcs_staging_dir=_TEST_BUCKET_NAME, project=_TEST_PROJECT - ) + def test_package_and_copy_to_gcs_copies_to_gcs(self, mock_client_bucket): + mock_client_bucket, mock_blob = mock_client_bucket + + tsp = source_utils._TrainingScriptPythonPackager(_TEST_LOCAL_SCRIPT_FILE_PATH) + + def create_valid_tarball(*args, **kwargs): + cwd = kwargs.get("cwd") + if cwd: + dist_dir = pathlib.Path(cwd) / "dist" + dist_dir.mkdir(exist_ok=True) + filename = f"{source_utils._TrainingScriptPythonPackager._ROOT_MODULE}-{source_utils._TrainingScriptPythonPackager._SETUP_PY_VERSION}.tar.gz" + tarball_path = dist_dir / filename + setup_py_path = pathlib.Path(cwd) / "setup.py" + arcname = f"{source_utils._TrainingScriptPythonPackager._ROOT_MODULE}-{source_utils._TrainingScriptPythonPackager._SETUP_PY_VERSION}/setup.py" + import tarfile + with tarfile.open(tarball_path, "w:gz") as tar: + tar.add(setup_py_path, arcname=arcname) + mock_subprocess = mock.Mock() + mock_subprocess.communicate.return_value = (b"", b"") + mock_subprocess.returncode = 0 + return mock_subprocess + with mock.patch("subprocess.Popen", side_effect=create_valid_tarball): + gcs_path = tsp.package_and_copy_to_gcs( + gcs_staging_dir=_TEST_BUCKET_NAME, project=_TEST_PROJECT + ) - mock_client_bucket.assert_called_once_with(_TEST_BUCKET_NAME) - mock_client_bucket.return_value.blob.assert_called_once() + mock_client_bucket.assert_called_once_with(_TEST_BUCKET_NAME) + mock_client_bucket.return_value.blob.assert_called_once() - mock_blob.upload_from_filename.call_args[0][0].endswith( + mock_blob.upload_from_filename.call_args[0][0].endswith( 
"/trainer/dist/aiplatform_custom_trainer_script-0.1.tar.gz" ) - assert gcs_path.endswith("-aiplatform_custom_trainer_script-0.1.tar.gz") - assert gcs_path.startswith(f"gs://{_TEST_BUCKET_NAME}") + assert gcs_path.endswith("-aiplatform_custom_trainer_script-0.1.tar.gz") + assert gcs_path.startswith(f"gs://{_TEST_BUCKET_NAME}") @pytest.fixture diff --git a/tests/unit/aiplatform/test_training_utils.py b/tests/unit/aiplatform/test_training_utils.py index 4629d23e33..6e3dfcb295 100644 --- a/tests/unit/aiplatform/test_training_utils.py +++ b/tests/unit/aiplatform/test_training_utils.py @@ -232,7 +232,12 @@ def test_package_file(self, mock_temp_file_name): ) with tempfile.TemporaryDirectory() as destination_directory_name: - _ = packager.make_package(package_directory=destination_directory_name) + with mock.patch("subprocess.Popen") as mock_popen: + mock_subprocess = mock.Mock() + mock_subprocess.communicate.return_value = (b"", b"") + mock_subprocess.returncode = 0 + mock_popen.return_value = mock_subprocess + _ = packager.make_package(package_directory=destination_directory_name) # Check that contents of source_distribution_path is the same as destination_directory_name destination_inner_path = f"{destination_directory_name}/{packager._TRAINER_FOLDER}/{packager._ROOT_MODULE}/{packager.task_module_name}.py" @@ -275,7 +280,12 @@ def test_package_folder(self, mock_temp_folder_name): with open(existing_file.name, "w") as handle: handle.write("existing") - _ = packager.make_package(package_directory=destination_directory_name) + with mock.patch("subprocess.Popen") as mock_popen: + mock_subprocess = mock.Mock() + mock_subprocess.communicate.return_value = (b"", b"") + mock_subprocess.returncode = 0 + mock_popen.return_value = mock_subprocess + _ = packager.make_package(package_directory=destination_directory_name) # Check that contents of source_distribution_path is the same as destination_directory_name destination_inner_path = 
f"{destination_directory_name}/{packager._TRAINER_FOLDER}/{packager._ROOT_MODULE}" diff --git a/tests/unit/vertex_rag/test_rag_data.py b/tests/unit/vertex_rag/test_rag_data.py index 17a59defca..ff27968b35 100644 --- a/tests/unit/vertex_rag/test_rag_data.py +++ b/tests/unit/vertex_rag/test_rag_data.py @@ -419,9 +419,10 @@ def create_transformation_config( def rag_corpus_eq(returned_corpus, expected_corpus): assert returned_corpus.name == expected_corpus.name assert returned_corpus.display_name == expected_corpus.display_name - assert returned_corpus.backend_config.__eq__(expected_corpus.backend_config) - assert returned_corpus.vertex_ai_search_config.__eq__( - expected_corpus.vertex_ai_search_config + assert returned_corpus.backend_config == expected_corpus.backend_config + assert ( + returned_corpus.vertex_ai_search_config + == expected_corpus.vertex_ai_search_config ) @@ -464,8 +465,8 @@ def import_files_request_eq(returned_request, expected_request): def rag_engine_config_eq(returned_config, expected_config): assert returned_config.name == expected_config.name - assert returned_config.rag_managed_db_config.__eq__( - expected_config.rag_managed_db_config + assert ( + returned_config.rag_managed_db_config == expected_config.rag_managed_db_config ) diff --git a/tests/unit/vertex_rag/test_rag_data_preview.py b/tests/unit/vertex_rag/test_rag_data_preview.py index 6b558a5836..68a2fda465 100644 --- a/tests/unit/vertex_rag/test_rag_data_preview.py +++ b/tests/unit/vertex_rag/test_rag_data_preview.py @@ -993,12 +993,13 @@ def rag_metadata_eq(returned_metadata, expected_metadata): def rag_corpus_eq(returned_corpus, expected_corpus): assert returned_corpus.name == expected_corpus.name assert returned_corpus.display_name == expected_corpus.display_name - assert returned_corpus.vector_db.__eq__(expected_corpus.vector_db) - assert returned_corpus.backend_config.__eq__(expected_corpus.backend_config) - assert returned_corpus.vertex_ai_search_config.__eq__( - 
expected_corpus.vertex_ai_search_config + assert returned_corpus.vector_db == expected_corpus.vector_db + assert returned_corpus.backend_config == expected_corpus.backend_config + assert ( + returned_corpus.vertex_ai_search_config + == expected_corpus.vertex_ai_search_config ) - assert returned_corpus.corpus_type_config.__eq__(expected_corpus.corpus_type_config) + assert returned_corpus.corpus_type_config == expected_corpus.corpus_type_config def rag_file_eq(returned_file, expected_file): @@ -1048,8 +1049,8 @@ def import_files_request_eq(returned_request, expected_request): def rag_engine_config_eq(returned_config, expected_config): assert returned_config.name == expected_config.name - assert returned_config.rag_managed_db_config.__eq__( - expected_config.rag_managed_db_config + assert ( + returned_config.rag_managed_db_config == expected_config.rag_managed_db_config ) diff --git a/vertexai/preview/rag/utils/_gapic_utils.py b/vertexai/preview/rag/utils/_gapic_utils.py index a372afe97f..df84c417f9 100644 --- a/vertexai/preview/rag/utils/_gapic_utils.py +++ b/vertexai/preview/rag/utils/_gapic_utils.py @@ -151,49 +151,49 @@ def _check_weaviate(gapic_vector_db: GapicRagVectorDbConfig) -> bool: try: return gapic_vector_db.__contains__("weaviate") except AttributeError: - return gapic_vector_db.weaviate.ByteSize() > 0 + return gapic_vector_db.weaviate._pb.ByteSize() > 0 def _check_rag_managed_db(gapic_vector_db: GapicRagVectorDbConfig) -> bool: try: return gapic_vector_db.__contains__("rag_managed_db") except AttributeError: - return gapic_vector_db.rag_managed_db.ByteSize() > 0 + return gapic_vector_db.rag_managed_db._pb.ByteSize() > 0 def _check_knn(gapic_rag_managed_db: GapicRagVectorDbConfig.RagManagedDb) -> bool: try: return gapic_rag_managed_db.__contains__("knn") except AttributeError: - return gapic_rag_managed_db.knn.ByteSize() > 0 + return gapic_rag_managed_db.knn._pb.ByteSize() > 0 def _check_ann(gapic_rag_managed_db: GapicRagVectorDbConfig.RagManagedDb) -> 
bool: try: return gapic_rag_managed_db.__contains__("ann") except AttributeError: - return gapic_rag_managed_db.ann.ByteSize() > 0 + return gapic_rag_managed_db.ann._pb.ByteSize() > 0 def _check_vertex_feature_store(gapic_vector_db: GapicRagVectorDbConfig) -> bool: try: return gapic_vector_db.__contains__("vertex_feature_store") except AttributeError: - return gapic_vector_db.vertex_feature_store.ByteSize() > 0 + return gapic_vector_db.vertex_feature_store._pb.ByteSize() > 0 def _check_pinecone(gapic_vector_db: GapicRagVectorDbConfig) -> bool: try: return gapic_vector_db.__contains__("pinecone") except AttributeError: - return gapic_vector_db.pinecone.ByteSize() > 0 + return gapic_vector_db.pinecone._pb.ByteSize() > 0 def _check_vertex_vector_search(gapic_vector_db: GapicRagVectorDbConfig) -> bool: try: return gapic_vector_db.__contains__("vertex_vector_search") except AttributeError: - return gapic_vector_db.vertex_vector_search.ByteSize() > 0 + return gapic_vector_db.vertex_vector_search._pb.ByteSize() > 0 def _check_rag_managed_vertex_vector_search( @@ -202,7 +202,7 @@ def _check_rag_managed_vertex_vector_search( try: return gapic_vector_db.__contains__("rag_managed_vertex_vector_search") except AttributeError: - return gapic_vector_db.rag_managed_vertex_vector_search.ByteSize() > 0 + return gapic_vector_db.rag_managed_vertex_vector_search._pb.ByteSize() > 0 def _check_rag_embedding_model_config( @@ -211,7 +211,25 @@ def _check_rag_embedding_model_config( try: return gapic_vector_db.__contains__("rag_embedding_model_config") except AttributeError: - return gapic_vector_db.rag_embedding_model_config.ByteSize() > 0 + return gapic_vector_db.rag_embedding_model_config._pb.ByteSize() > 0 + + +def _check_document_corpus( + gapic_corpus_type_config: GapicRagCorpus.CorpusTypeConfig, +) -> bool: + try: + return gapic_corpus_type_config.__contains__("document_corpus") + except AttributeError: + return gapic_corpus_type_config.document_corpus._pb.ByteSize() > 0 + + +def 
_check_memory_corpus( + gapic_corpus_type_config: GapicRagCorpus.CorpusTypeConfig, +) -> bool: + try: + return gapic_corpus_type_config.__contains__("memory_corpus") + except AttributeError: + return gapic_corpus_type_config.memory_corpus._pb.ByteSize() > 0 def _convert_gapic_to_rag_managed_db( @@ -285,8 +303,8 @@ def convert_gapic_to_vector_db( def convert_gapic_to_vertex_ai_search_config( - gapic_vertex_ai_search_config: VertexAiSearchConfig, -) -> VertexAiSearchConfig: + gapic_vertex_ai_search_config: GapicVertexAiSearchConfig, +) -> Optional[VertexAiSearchConfig]: """Convert Gapic VertexAiSearchConfig to VertexAiSearchConfig.""" if gapic_vertex_ai_search_config.serving_config: return VertexAiSearchConfig( @@ -326,6 +344,8 @@ def convert_gapic_to_backend_config( gapic_vector_db: GapicRagVectorDbConfig, ) -> RagVectorDbConfig: """Convert Gapic RagVectorDbConfig to VertexVectorSearch, Pinecone, or RagManagedDb.""" + if not gapic_vector_db or not gapic_vector_db._pb.ByteSize(): + return None vector_config = RagVectorDbConfig() if _check_pinecone(gapic_vector_db): vector_config.vector_db = Pinecone( @@ -351,6 +371,11 @@ def convert_gapic_to_backend_config( gapic_vector_db.rag_embedding_model_config ) ) + if ( + vector_config.vector_db is None + and vector_config.rag_embedding_model_config is None + ): + return None return vector_config @@ -358,9 +383,9 @@ def convert_gapic_to_rag_corpus_type_config( gapic_rag_corpus_type_config: GapicRagCorpus.CorpusTypeConfig, ) -> RagCorpusTypeConfig: """Convert GapicRagCorpus.CorpusTypeConfig to RagCorpusTypeConfig.""" - if gapic_rag_corpus_type_config.document_corpus: + if _check_document_corpus(gapic_rag_corpus_type_config): return RagCorpusTypeConfig(corpus_type_config=DocumentCorpus()) - elif gapic_rag_corpus_type_config.memory_corpus: + elif _check_memory_corpus(gapic_rag_corpus_type_config): return RagCorpusTypeConfig( corpus_type_config=MemoryCorpus( llm_parser=LlmParserConfig( @@ -402,16 +427,27 @@ def 
convert_gapic_to_rag_corpus_no_embedding_model_config( gapic_rag_corpus: GapicRagCorpus, ) -> RagCorpus: """Convert GapicRagCorpus without embedding model config (for UpdateRagCorpus) to RagCorpus.""" - rag_vector_db_config_no_embedding_model_config = gapic_rag_corpus.vector_db_config - rag_vector_db_config_no_embedding_model_config.rag_embedding_model_config = None + vertex_ai_search_config = convert_gapic_to_vertex_ai_search_config( + gapic_rag_corpus.vertex_ai_search_config + ) + old_config = gapic_rag_corpus.vector_db_config + rag_vector_db_config_no_embedding_model_config = old_config.__class__() + if _check_rag_managed_db(old_config): + rag_vector_db_config_no_embedding_model_config.rag_managed_db = old_config.rag_managed_db + elif _check_pinecone(old_config): + rag_vector_db_config_no_embedding_model_config.pinecone = old_config.pinecone + elif _check_vertex_vector_search(old_config): + rag_vector_db_config_no_embedding_model_config.vertex_vector_search = old_config.vertex_vector_search + elif _check_weaviate(old_config): + rag_vector_db_config_no_embedding_model_config.weaviate = old_config.weaviate + elif _check_vertex_feature_store(old_config): + rag_vector_db_config_no_embedding_model_config.vertex_feature_store = old_config.vertex_feature_store rag_corpus = RagCorpus( name=gapic_rag_corpus.name, display_name=gapic_rag_corpus.display_name, description=gapic_rag_corpus.description, vector_db=convert_gapic_to_vector_db(gapic_rag_corpus.rag_vector_db_config), - vertex_ai_search_config=convert_gapic_to_vertex_ai_search_config( - gapic_rag_corpus.vertex_ai_search_config - ), + vertex_ai_search_config=vertex_ai_search_config, backend_config=convert_gapic_to_backend_config( rag_vector_db_config_no_embedding_model_config ), diff --git a/vertexai/rag/utils/_gapic_utils.py b/vertexai/rag/utils/_gapic_utils.py index d0311dec16..713622fd38 100644 --- a/vertexai/rag/utils/_gapic_utils.py +++ b/vertexai/rag/utils/_gapic_utils.py @@ -141,35 +141,35 @@ def 
_check_weaviate(gapic_vector_db: GapicRagVectorDbConfig) -> bool: try: return gapic_vector_db.__contains__("weaviate") except AttributeError: - return gapic_vector_db.weaviate.ByteSize() > 0 + return gapic_vector_db.weaviate._pb.ByteSize() > 0 def _check_rag_managed_db(gapic_vector_db: GapicRagVectorDbConfig) -> bool: try: return gapic_vector_db.__contains__("rag_managed_db") except AttributeError: - return gapic_vector_db.rag_managed_db.ByteSize() > 0 + return gapic_vector_db.rag_managed_db._pb.ByteSize() > 0 def _check_vertex_feature_store(gapic_vector_db: GapicRagVectorDbConfig) -> bool: try: return gapic_vector_db.__contains__("vertex_feature_store") except AttributeError: - return gapic_vector_db.vertex_feature_store.ByteSize() > 0 + return gapic_vector_db.vertex_feature_store._pb.ByteSize() > 0 def _check_pinecone(gapic_vector_db: GapicRagVectorDbConfig) -> bool: try: return gapic_vector_db.__contains__("pinecone") except AttributeError: - return gapic_vector_db.pinecone.ByteSize() > 0 + return gapic_vector_db.pinecone._pb.ByteSize() > 0 def _check_vertex_vector_search(gapic_vector_db: GapicRagVectorDbConfig) -> bool: try: return gapic_vector_db.__contains__("vertex_vector_search") except AttributeError: - return gapic_vector_db.vertex_vector_search.ByteSize() > 0 + return gapic_vector_db.vertex_vector_search._pb.ByteSize() > 0 def _check_rag_embedding_model_config( @@ -178,13 +178,15 @@ def _check_rag_embedding_model_config( try: return gapic_vector_db.__contains__("rag_embedding_model_config") except AttributeError: - return gapic_vector_db.rag_embedding_model_config.ByteSize() > 0 + return gapic_vector_db.rag_embedding_model_config._pb.ByteSize() > 0 def convert_gapic_to_backend_config( gapic_vector_db: GapicRagVectorDbConfig, ) -> RagVectorDbConfig: """Convert Gapic RagVectorDbConfig to VertexVectorSearch, Pinecone, or RagManagedDb.""" + if not gapic_vector_db: + return None vector_config = RagVectorDbConfig() if _check_pinecone(gapic_vector_db): 
vector_config.vector_db = Pinecone( @@ -208,9 +210,9 @@ def convert_gapic_to_backend_config( def convert_gapic_to_vertex_ai_search_config( - gapic_vertex_ai_search_config: VertexAiSearchConfig, -) -> VertexAiSearchConfig: + gapic_vertex_ai_search_config: GapicVertexAiSearchConfig, +) -> Optional[VertexAiSearchConfig]: """Convert Gapic VertexAiSearchConfig to VertexAiSearchConfig.""" if gapic_vertex_ai_search_config.serving_config: return VertexAiSearchConfig( serving_config=gapic_vertex_ai_search_config.serving_config, @@ -239,15 +243,26 @@ convert_gapic_to_rag_corpus_no_embedding_model_config( gapic_rag_corpus: GapicRagCorpus, ) -> RagCorpus: """Convert GapicRagCorpus without embedding model config (for UpdateRagCorpus) to RagCorpus.""" - rag_vector_db_config_no_embedding_model_config = gapic_rag_corpus.vector_db_config - rag_vector_db_config_no_embedding_model_config.rag_embedding_model_config = None + vertex_ai_search_config = convert_gapic_to_vertex_ai_search_config( + gapic_rag_corpus.vertex_ai_search_config + ) + old_config = gapic_rag_corpus.vector_db_config + rag_vector_db_config_no_embedding_model_config = old_config.__class__() + if _check_rag_managed_db(old_config): + rag_vector_db_config_no_embedding_model_config.rag_managed_db = old_config.rag_managed_db + elif _check_pinecone(old_config): + rag_vector_db_config_no_embedding_model_config.pinecone = old_config.pinecone + elif _check_vertex_vector_search(old_config): + rag_vector_db_config_no_embedding_model_config.vertex_vector_search = old_config.vertex_vector_search + elif _check_weaviate(old_config): + rag_vector_db_config_no_embedding_model_config.weaviate = old_config.weaviate + elif _check_vertex_feature_store(old_config): + rag_vector_db_config_no_embedding_model_config.vertex_feature_store = old_config.vertex_feature_store 
rag_corpus = RagCorpus( name=gapic_rag_corpus.name, display_name=gapic_rag_corpus.display_name, description=gapic_rag_corpus.description, - vertex_ai_search_config=convert_gapic_to_vertex_ai_search_config( - gapic_rag_corpus.vertex_ai_search_config - ), + vertex_ai_search_config=vertex_ai_search_config, backend_config=convert_gapic_to_backend_config( rag_vector_db_config_no_embedding_model_config ),