From 00dd502a8aef68ea1fba7e14abce1aae3ae98b56 Mon Sep 17 00:00:00 2001 From: Jason Dai Date: Thu, 30 Apr 2026 17:20:13 -0700 Subject: [PATCH] feat: GenAI Client(evals) - Support N+1 Agent Engine inference via agent_data in run_inference() PiperOrigin-RevId: 908461295 --- .../genai/replays/test_run_inference.py | 159 +++++++ vertexai/_genai/_evals_common.py | 424 ++++++++++++++---- vertexai/_genai/evals.py | 28 +- 3 files changed, 523 insertions(+), 88 deletions(-) create mode 100644 tests/unit/vertexai/genai/replays/test_run_inference.py diff --git a/tests/unit/vertexai/genai/replays/test_run_inference.py b/tests/unit/vertexai/genai/replays/test_run_inference.py new file mode 100644 index 0000000000..a820d7ed9e --- /dev/null +++ b/tests/unit/vertexai/genai/replays/test_run_inference.py @@ -0,0 +1,159 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +import pytest + +from tests.unit.vertexai.genai.replays import pytest_helper +from vertexai._genai import types +from google.genai import types as genai_types + +pytest.importorskip( + "google.adk", reason="google-adk not installed, skipping ADK agent tests" +) +from google.adk.agents import ( # noqa: E402 + LlmAgent, +) # pylint: disable=g-import-not-at-top,g-bad-import-order + + +def test_inference_with_eval_cases_multi_turn_agent_data(client): + """Tests run_inference with multi-turn agent_data in eval_cases. + + Verifies that run_inference() accepts an EvaluationDataset with + eval_cases containing agent_data (no eval_dataset_df). The agent_data + has 2 turns: turn 0 is a completed user+agent exchange (history), + turn 1 is a new user query. The agent should see the history and + respond to the final query in context. + """ + agent = LlmAgent( + name="test_agent", + model="gemini-2.5-flash", + instruction="You are a helpful assistant. Answer questions concisely.", + ) + + eval_case = types.EvalCase( + agent_data=types.evals.AgentData( + turns=[ + types.evals.ConversationTurn( + turn_index=0, + events=[ + types.evals.AgentEvent( + author="user", + content=genai_types.Content( + role="user", + parts=[genai_types.Part(text="My name is Alice.")], + ), + ), + types.evals.AgentEvent( + author="test_agent", + content=genai_types.Content( + role="model", + parts=[ + genai_types.Part( + text="Hello Alice! How can I help you?" 
+ ) + ], + ), + ), + ], + ), + types.evals.ConversationTurn( + turn_index=1, + events=[ + types.evals.AgentEvent( + author="user", + content=genai_types.Content( + role="user", + parts=[genai_types.Part(text="What is my name?")], + ), + ), + ], + ), + ], + ), + ) + eval_dataset = types.EvaluationDataset(eval_cases=[eval_case]) + + inference_result = client.evals.run_inference( + agent=agent, + src=eval_dataset, + ) + assert isinstance(inference_result, types.EvaluationDataset) + assert inference_result.eval_dataset_df is not None + assert "agent_data" in inference_result.eval_dataset_df.columns + + +def test_inference_with_eval_cases_agent_engine_agent_data(client): + """Tests N+1 inference with agent_data via remote Agent Engine.""" + agent_engine = client.agent_engines.get( + name="projects/977012026409/locations/us-central1" + "/reasoningEngines/7188347537655332864" + ) + + eval_case = types.EvalCase( + agent_data=types.evals.AgentData( + turns=[ + types.evals.ConversationTurn( + turn_index=0, + events=[ + types.evals.AgentEvent( + author="user", + content=genai_types.Content( + role="user", + parts=[genai_types.Part(text="My name is Bob.")], + ), + ), + types.evals.AgentEvent( + author="model", + content=genai_types.Content( + role="model", + parts=[ + genai_types.Part(text="Hi Bob! Nice to meet you.") + ], + ), + ), + ], + ), + types.evals.ConversationTurn( + turn_index=1, + events=[ + types.evals.AgentEvent( + author="user", + content=genai_types.Content( + role="user", + parts=[genai_types.Part(text="What is my name?")], + ), + ), + ], + ), + ], + ), + ) + eval_dataset = types.EvaluationDataset(eval_cases=[eval_case]) + + inference_result = client.evals.run_inference( + agent=agent_engine, + src=eval_dataset, + ) + assert isinstance(inference_result, types.EvaluationDataset) + assert inference_result.eval_dataset_df is not None + assert "agent_data" in inference_result.eval_dataset_df.columns + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="evals.run_inference", +) diff --git a/vertexai/_genai/_evals_common.py b/vertexai/_genai/_evals_common.py index 2b37abf04e..04b6e20d8f 100644 --- a/vertexai/_genai/_evals_common.py +++ b/vertexai/_genai/_evals_common.py @@ -253,6 +253,119 @@ def _extract_contents_for_inference( return request_dict_or_raw_text +def _eval_cases_to_dataframe( + eval_cases: list[types.EvalCase], +) -> pd.DataFrame: + """Converts a list of EvalCase objects to a pandas DataFrame. + + Each EvalCase is converted to a row in the DataFrame. Structured fields + like ``agent_data`` are preserved as-is (not flattened) so that downstream + agent execution paths can consume them directly. + + Args: + eval_cases: The list of EvalCase objects to convert. + + Returns: + A DataFrame with one row per EvalCase. 
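+
+    Example (illustrative): an ``EvalCase`` with ``prompt``, ``reference``,
+    and ``agent_data`` populated becomes one row whose ``prompt`` and
+    ``reference`` columns are plain strings and whose ``agent_data`` column
+    holds the original ``AgentData`` object unchanged.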
+ """ + rows = [] + for case in eval_cases: + row: dict[str, Any] = {} + if case.prompt: + row[_evals_constant.PROMPT] = _evals_data_converters._get_content_text( + case.prompt + ) + + if case.responses and len(case.responses) > 0 and case.responses[0].response: + row[_evals_constant.RESPONSE] = _evals_data_converters._get_content_text( + case.responses[0].response + ) + + if case.reference and case.reference.response: + row[_evals_constant.REFERENCE] = _evals_data_converters._get_content_text( + case.reference.response + ) + + if case.agent_data: + row[AGENT_DATA] = case.agent_data + + if case.intermediate_events: + row[_evals_constant.INTERMEDIATE_EVENTS] = [ + {CONTENT: event.content} + for event in case.intermediate_events + if event.content + ] + + if case.conversation_history: + history_parts = [] + for msg in case.conversation_history: + if msg.content: + role = msg.content.role or "user" + text = _evals_data_converters._get_content_text(msg.content) + history_parts.append(f"{role}: {text}") + if history_parts: + row[_evals_constant.CONVERSATION_HISTORY] = "\n".join(history_parts) + + if case.user_scenario: + if case.user_scenario.starting_prompt: + row[_evals_constant.STARTING_PROMPT] = ( + case.user_scenario.starting_prompt + ) + if case.user_scenario.conversation_plan: + row[_evals_constant.CONVERSATION_PLAN] = ( + case.user_scenario.conversation_plan + ) + + rows.append(row) + return pd.DataFrame(rows) + + +def _extract_prompt_from_agent_data( + agent_data: types.evals.AgentData, +) -> tuple[genai_types.Content, list[types.evals.AgentEvent]]: + """Extracts the last user message and prior events from agent_data. + + The last event across all turns must be authored by ``"user"``; it is + treated as the current prompt that the agent should respond to. + Everything before it is returned as conversation history. + + Args: + agent_data: The AgentData containing conversation turns. + + Returns: + A tuple of ``(last_user_content, history_events)`` where + ``last_user_content`` is the ``Content`` of the final user event + and ``history_events`` is the ordered list of all prior + ``AgentEvent`` objects. + + Raises: + ValueError: If ``agent_data`` has no turns, no events, or the last + event is not a user event. + """ + if not agent_data.turns: + raise ValueError("agent_data must have at least one turn.") + + all_events: list[types.evals.AgentEvent] = [] + for turn in agent_data.turns: + if turn.events: + all_events.extend(turn.events) + + if not all_events: + raise ValueError("agent_data turns contain no events.") + + last_event = all_events[-1] + if last_event.author != USER_AUTHOR: + raise ValueError( + "agent_data must end with a user event, but the last event has" + f" author='{last_event.author}'." 
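+            # A conversation that ends with an agent or tool event leaves
+            # nothing for the agent to respond to, so fail fast and report
+            # the offending author.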
+ ) + + if not last_event.content: + raise ValueError("The last user event in agent_data has no content.") + + return last_event.content, all_events[:-1] + + def _resolve_dataset( api_client: BaseApiClient, dataset: Union[types.EvaluationRunDataSource, types.EvaluationDataset], @@ -264,66 +377,7 @@ def _resolve_dataset( candidate_name = _get_candidate_name(dataset, parsed_agent_info) eval_df = dataset.eval_dataset_df if eval_df is None and dataset.eval_cases: - rows = [] - for case in dataset.eval_cases: - row: dict[str, Any] = {} - if case.prompt: - row[_evals_constant.PROMPT] = ( - _evals_data_converters._get_content_text(case.prompt) - ) - - if ( - case.responses - and len(case.responses) > 0 - and case.responses[0].response - ): - row[_evals_constant.RESPONSE] = ( - _evals_data_converters._get_content_text( - case.responses[0].response - ) - ) - - if case.reference and case.reference.response: - row[_evals_constant.REFERENCE] = ( - _evals_data_converters._get_content_text( - case.reference.response - ) - ) - - if case.agent_data: - row[AGENT_DATA] = case.agent_data - - if case.intermediate_events: - row[_evals_constant.INTERMEDIATE_EVENTS] = [ - {CONTENT: event.content} - for event in case.intermediate_events - if event.content - ] - - if case.conversation_history: - history_parts = [] - for msg in case.conversation_history: - if msg.content: - role = msg.content.role or "user" - text = _evals_data_converters._get_content_text(msg.content) - history_parts.append(f"{role}: {text}") - if history_parts: - row[_evals_constant.CONVERSATION_HISTORY] = "\n".join( - history_parts - ) - - if case.user_scenario: - if case.user_scenario.starting_prompt: - row[_evals_constant.STARTING_PROMPT] = ( - case.user_scenario.starting_prompt - ) - if case.user_scenario.conversation_plan: - row[_evals_constant.CONVERSATION_PLAN] = ( - case.user_scenario.conversation_plan - ) - - rows.append(row) - eval_df = pd.DataFrame(rows) + eval_df = _eval_cases_to_dataframe(dataset.eval_cases) eval_set = _create_evaluation_set_from_dataframe( api_client, @@ -500,25 +554,54 @@ def _execute_inference_concurrently( ] = [None] * len(prompt_dataset) tasks = [] - if "request" in prompt_dataset.columns: - primary_prompt_column = "request" - elif "prompt" in prompt_dataset.columns: - primary_prompt_column = "prompt" - elif "starting_prompt" in prompt_dataset.columns: - primary_prompt_column = "starting_prompt" - else: - raise ValueError( - "Dataset must contain either 'prompt', 'request', or 'starting_prompt'." - f" Found: {prompt_dataset.columns.tolist()}" - ) + # When running with an agent and agent_data is present, we extract the + # prompt from the structured agent_data rather than requiring a flat + # prompt/request column. + has_agent_data = ( + agent is not None or agent_engine is not None + ) and AGENT_DATA in prompt_dataset.columns + + primary_prompt_column: Optional[str] = None + if not has_agent_data: + if "request" in prompt_dataset.columns: + primary_prompt_column = "request" + elif "prompt" in prompt_dataset.columns: + primary_prompt_column = "prompt" + elif "starting_prompt" in prompt_dataset.columns: + primary_prompt_column = "starting_prompt" + else: + raise ValueError( + "Dataset must contain either 'prompt', 'request', or" + " 'starting_prompt'." 
+ f" Found: {prompt_dataset.columns.tolist()}" + ) max_workers = AGENT_MAX_WORKERS if agent_engine or agent else MAX_WORKERS with tqdm(total=len(prompt_dataset), desc=progress_desc) as pbar: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: for index, row in prompt_dataset.iterrows(): - request_dict_or_raw_text = row[primary_prompt_column] try: - contents = _extract_contents_for_inference(request_dict_or_raw_text) + if ( + has_agent_data + and AGENT_DATA in row.index + and row.get(AGENT_DATA) is not None + ): + agent_data_obj = row[AGENT_DATA] + if isinstance(agent_data_obj, dict): + agent_data_obj = types.evals.AgentData.model_validate( + agent_data_obj + ) + last_user_content, _ = _extract_prompt_from_agent_data( + agent_data_obj + ) + contents = _evals_data_converters._get_content_text( + last_user_content + ) + else: + request_dict_or_raw_text = row[primary_prompt_column] + contents = _extract_contents_for_inference( + request_dict_or_raw_text + ) except ValueError as e: error_message = ( f"Failed to extract contents for prompt at index {index}: {e}. " @@ -696,8 +779,7 @@ def _run_litellm_inference( ) -> list[Optional[dict[str, Any]]]: """Runs inference using LiteLLM with concurrency.""" logger.info( - "Generating responses for %d prompts using LiteLLM for third party" - " model: %s", + "Generating responses for %d prompts using LiteLLM for third party model: %s", len(prompt_dataset), model, ) @@ -1843,6 +1925,14 @@ def _create_agent_results_dataframe( prompt_dataset_indexed = prompt_dataset.reset_index(drop=True) results_df_responses_only_indexed = results_df_raw.reset_index(drop=True) + # Drop columns from input that will be overwritten by results to avoid + # duplicate columns after concatenation (e.g. agent_data). + overlap = prompt_dataset_indexed.columns.intersection( + results_df_responses_only_indexed.columns + ) + if not overlap.empty: + prompt_dataset_indexed = prompt_dataset_indexed.drop(columns=overlap) + results_df = pd.concat( [prompt_dataset_indexed, results_df_responses_only_indexed], axis=1 ) @@ -2046,9 +2136,100 @@ def _execute_agent_run_with_retry( max_retries: int = 3, ) -> Union[list[dict[str, Any]], dict[str, Any]]: """Executes agent run over agent engine for a single prompt.""" + # Agent data with conversation history — pre-populate the remote session + # with prior turns via the Sessions API, then query with the last user + # message only. + if AGENT_DATA in row.index and row.get(AGENT_DATA) is not None: + agent_data_obj = row[AGENT_DATA] + if isinstance(agent_data_obj, dict): + agent_data_obj = types.evals.AgentData.model_validate(agent_data_obj) + if isinstance(agent_data_obj, types.evals.AgentData) and agent_data_obj.turns: + try: + last_user_content, history_events = _extract_prompt_from_agent_data( + agent_data_obj + ) + except ValueError as e: + return {"error": f"Invalid agent_data for inference: {e}"} + + user_id = str(uuid.uuid4()) + session_state = None + if "session_inputs" in row.index and row.get("session_inputs") is not None: + si = _get_session_inputs(row) + user_id = si.user_id or user_id + session_state = si.state + + try: + session_id = _create_agent_engine_session( + agent_engine=agent_engine, + user_id=user_id, + session_state=session_state, + ) + except Exception as e: # pylint: disable=broad-exception-caught + return {"error": f"Failed to create session: {e}"} + + # Pre-populate remote session with history events. 
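+            # The events are appended through the Agent Engine Sessions API
+            # one at a time, so the stream_query call below sees the prior
+            # turns as server-side session context rather than having them
+            # inlined into the message text.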
+ if agent_engine.api_resource is None: + return {"error": "agent_engine.api_resource is None."} + if agent_engine.api_client is None: + return {"error": "agent_engine.api_client is None."} + session_name = f"{agent_engine.api_resource.name}/sessions/{session_id}" + # Use a fixed base timestamp for history events so that + # replay tests produce deterministic request bodies. + base_ts = datetime.datetime(2000, 1, 1, tzinfo=datetime.timezone.utc) + for i, ag_event in enumerate(history_events): + agent_engine.api_client.sessions.events.append( + name=session_name, + author=ag_event.author or "user", + invocation_id="history", + timestamp=base_ts + datetime.timedelta(seconds=i), + config=types.AppendAgentEngineSessionEventConfig( + content=ag_event.content, + ), + ) + + last_user_text = _evals_data_converters._get_content_text(last_user_content) + for attempt in range(max_retries): + try: + responses = [] + for event in agent_engine.stream_query( # type: ignore[attr-defined] + user_id=user_id, + session_id=session_id, + message=last_user_text, + ): + if event and CONTENT in event and PARTS in event[CONTENT]: + responses.append(event) + return responses + except api_exceptions.ResourceExhausted as e: + logger.warning( + "Resource Exhausted error on attempt %d/%d: %s." + " Retrying in %s seconds...", + attempt + 1, + max_retries, + e, + 2**attempt, + ) + if attempt == max_retries - 1: + return {"error": (f"Resource exhausted after retries: {e}")} + time.sleep(2**attempt) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error( + "Unexpected error during agent engine run on attempt %d/%d: %s", + attempt + 1, + max_retries, + e, + ) + if attempt == max_retries - 1: + return {"error": f"Failed after retries: {e}"} + time.sleep(1) + return { + "error": ( + f"Failed to get agent run results after {max_retries} retries" + ) + } + try: session_inputs = _get_session_inputs(row) - user_id = session_inputs.user_id + user_id = session_inputs.user_id or str(uuid.uuid4()) session_state = session_inputs.state session_id = _create_agent_engine_session( agent_engine=agent_engine, @@ -2132,8 +2313,94 @@ async def _execute_local_agent_run_with_retry_async( logger.error("Multi-turn agent run with user simulation failed: %s", e) return {"error": f"Multi-turn agent run with user simulation failed: {e}"} + # Agent data with conversation history — pre-populate the session with + # prior turns so the agent sees them as context, then send the last user + # message as the new query. 
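+    # Unlike the remote Agent Engine branch above, history here is injected
+    # by appending ADK events directly onto the InMemorySessionService's
+    # session object. This reaches into the in-memory service's internal
+    # storage, which is fine for local runs but would not carry over to a
+    # persistent session service.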
+ if AGENT_DATA in row.index and row.get(AGENT_DATA) is not None: + from google.adk.events.event import Event as AdkEvent + + agent_data_obj = row[AGENT_DATA] + if isinstance(agent_data_obj, dict): + agent_data_obj = types.evals.AgentData.model_validate(agent_data_obj) + if isinstance(agent_data_obj, types.evals.AgentData) and agent_data_obj.turns: + try: + last_user_content, history_events = _extract_prompt_from_agent_data( + agent_data_obj + ) + except ValueError as e: + return {"error": f"Invalid agent_data for inference: {e}"} + + user_id = str(uuid.uuid4()) + session_id = str(uuid.uuid4()) + app_name = "local agent run" + if "session_inputs" in row.index and row.get("session_inputs") is not None: + session_inputs = _get_session_inputs(row) + user_id = session_inputs.user_id or user_id + app_name = session_inputs.app_name or app_name + session_service = InMemorySessionService() + await session_service.create_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + # Pre-populate session with history events so the agent sees + # prior conversation context. + internal_session = session_service.sessions[app_name][user_id][session_id] + for ag_event in history_events: + adk_event = AdkEvent( + author=ag_event.author or "user", + content=ag_event.content, + invocation_id="history", + ) + internal_session.events.append(adk_event) + + agent_runner = Runner( + agent=agent, app_name=app_name, session_service=session_service + ) + + with _temp_logger_level("google_genai.types", logging.ERROR): + for attempt in range(max_retries): + try: + events = [] + async for event in agent_runner.run_async( + user_id=user_id, + session_id=session_id, + new_message=last_user_content, + ): + if event: + event = event.model_dump(exclude_none=True) + if event and CONTENT in event and PARTS in event[CONTENT]: + events.append(event) + return events + except api_exceptions.ResourceExhausted as e: + logger.warning( + "Resource Exhausted error on attempt %d/%d: %s." + " Retrying in %s seconds...", + attempt + 1, + max_retries, + e, + 2**attempt, + ) + if attempt == max_retries - 1: + return {"error": f"Resource exhausted after retries: {e}"} + await asyncio.sleep(2**attempt) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error( + "Unexpected error during agent run on attempt %d/%d: %s", + attempt + 1, + max_retries, + e, + ) + if attempt == max_retries - 1: + return {"error": f"Failed after retries: {e}"} + await asyncio.sleep(1) + return { + "error": ( + f"Failed to get agent run results after {max_retries} retries" + ) + } + session_inputs = _get_session_inputs(row) - user_id = session_inputs.user_id + user_id = session_inputs.user_id or str(uuid.uuid4()) session_id = str(uuid.uuid4()) app_name = session_inputs.app_name or "local agent run" # TODO: Enable user to set up session service. @@ -2613,8 +2880,7 @@ def _get_content(row: dict[str, Any], column: str) -> Optional[genai_types.Conte return cast(genai_types.Content, row[column]) else: raise ValueError( - f"{column} must be a string or a Content object. Got" - f" {type(row[column])}." + f"{column} must be a string or a Content object. Got {type(row[column])}." ) diff --git a/vertexai/_genai/evals.py b/vertexai/_genai/evals.py index 5567036989..2e455f116b 100644 --- a/vertexai/_genai/evals.py +++ b/vertexai/_genai/evals.py @@ -2096,9 +2096,11 @@ def run_inference( Args: src: The source of the dataset. 
            Can be a string (path to a local file, a GCS path, or a
            BigQuery table), a Pandas DataFrame, or an
-           EvaluationDataset object. If an Evalu
-           ationDataset is provided,
-           it must have `eval_dataset_df` populated.
+           EvaluationDataset object. An EvaluationDataset may have either
+           ``eval_dataset_df`` or ``eval_cases`` populated. When ``eval_cases``
+           with ``agent_data`` is provided, the last user event in the turns is
+           used as the current prompt and prior events are replayed as session
+           history for both local ADK agents and remote Agent Engine agents.
          model: Optional type is experimental and may change in future
            versions. The model to use for inference, optional for agent
            evaluations.
            - For Google Gemini models, provide the model name string
              (e.g., "gemini-2.5-flash").
@@ -2134,11 +2136,15 @@
         config = types.EvalRunInferenceConfig.model_validate(config)

         if isinstance(src, types.EvaluationDataset):
-            if src.eval_dataset_df is None:
+            if src.eval_dataset_df is not None:
+                src = src.eval_dataset_df
+            elif src.eval_cases:
+                src = _evals_common._eval_cases_to_dataframe(src.eval_cases)
+            else:
                 raise ValueError(
-                    "EvaluationDataset must have eval_dataset_df populated."
+                    "EvaluationDataset must have eval_dataset_df or eval_cases"
+                    " populated."
                 )
-            src = src.eval_dataset_df

         agent_engine_instance = None
         agent_instance = None
@@ -2373,11 +2379,15 @@
             {rubric_group_name: [list[Rubric]]}.
         """
         if isinstance(src, types.EvaluationDataset):
-            if src.eval_dataset_df is None:
+            if src.eval_dataset_df is not None:
+                prompts_df = src.eval_dataset_df
+            elif src.eval_cases:
+                prompts_df = _evals_common._eval_cases_to_dataframe(src.eval_cases)
+            else:
                 raise ValueError(
-                    "EvaluationDataset must have eval_dataset_df populated."
+                    "EvaluationDataset must have eval_dataset_df or eval_cases"
+                    " populated."
                 )
-            prompts_df = src.eval_dataset_df
         elif isinstance(src, (str, pd.DataFrame)):
             try:
                 prompts_df = _evals_common._load_dataframe(self._api_client, src)
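
Usage sketch (illustrative, not part of the patch): a minimal N+1 inference
call built from the pieces the replay tests above exercise. `client` is
assumed to be an already-configured Vertex GenAI client exposing
`client.evals` (the replay tests receive it as a fixture); the agent name,
model, and prompt text are placeholders.

    from google.adk.agents import LlmAgent
    from vertexai._genai import types
    from google.genai import types as genai_types

    agent = LlmAgent(
        name="demo_agent",
        model="gemini-2.5-flash",
        instruction="Answer questions concisely.",
    )

    # A single turn whose last (and only) event is authored by "user": the
    # extraction helper treats it as the current prompt with empty history.
    # Earlier completed turns, when present, are replayed as session history.
    eval_case = types.EvalCase(
        agent_data=types.evals.AgentData(
            turns=[
                types.evals.ConversationTurn(
                    turn_index=0,
                    events=[
                        types.evals.AgentEvent(
                            author="user",
                            content=genai_types.Content(
                                role="user",
                                parts=[genai_types.Part(text="What is 2 + 2?")],
                            ),
                        ),
                    ],
                ),
            ],
        ),
    )

    result = client.evals.run_inference(
        agent=agent,  # a client.agent_engines.get(...) handle also works
        src=types.EvaluationDataset(eval_cases=[eval_case]),
    )
    assert "agent_data" in result.eval_dataset_df.columns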