def test_inference_with_eval_cases_agent_engine_agent_data(client):
    """Tests N+1 inference with agent_data via remote Agent Engine."""
    # Resolve a pre-existing remote Agent Engine resource; in replay mode
    # this request is matched against a recorded response, so the resource
    # name must stay exactly as recorded.
    agent_engine = client.agent_engines.get(
        name="projects/977012026409/locations/us-central1"
        "/reasoningEngines/7188347537655332864"
    )

    # Build an incomplete ("N+1") trace: turn 0 is a completed user/model
    # exchange, turn 1 ends with a user event — so the very last event in
    # the trace is authored by "user", which triggers inference.
    eval_case = types.EvalCase(
        agent_data=types.evals.AgentData(
            turns=[
                types.evals.ConversationTurn(
                    turn_index=0,
                    events=[
                        types.evals.AgentEvent(
                            author="user",
                            content=genai_types.Content(
                                role="user",
                                parts=[genai_types.Part(text="My name is Bob.")],
                            ),
                        ),
                        types.evals.AgentEvent(
                            author="model",
                            content=genai_types.Content(
                                role="model",
                                parts=[
                                    genai_types.Part(text="Hi Bob! Nice to meet you.")
                                ],
                            ),
                        ),
                    ],
                ),
                types.evals.ConversationTurn(
                    turn_index=1,
                    events=[
                        types.evals.AgentEvent(
                            author="user",
                            content=genai_types.Content(
                                role="user",
                                parts=[genai_types.Part(text="What is my name?")],
                            ),
                        ),
                    ],
                ),
            ],
        ),
    )
    eval_dataset = types.EvaluationDataset(eval_cases=[eval_case])

    inference_result = client.evals.run_inference(
        agent=agent_engine,
        src=eval_dataset,
    )
    # The result keeps the dataset shape and preserves the agent_data
    # column alongside the generated responses.
    assert isinstance(inference_result, types.EvaluationDataset)
    assert inference_result.eval_dataset_df is not None
    assert "agent_data" in inference_result.eval_dataset_df.columns
def test_inference_with_prompt_column_local_agent(client):
    """Tests run_inference with a prompt column and a local ADK agent.

    Verifies the existing prompt-based inference path: a DataFrame with
    a 'prompt' column is passed alongside a local LlmAgent. The agent
    should respond to the prompt normally.
    """
    import pandas as pd

    # A minimal local agent: the prompt column supplies the user query.
    local_agent = LlmAgent(
        name="prompt_agent",
        model="gemini-2.5-flash",
        instruction="You are a helpful assistant. Answer questions concisely.",
    )

    # Single-row dataset exercising the 'prompt' column path.
    source_df = pd.DataFrame(
        {
            "prompt": ["What is the capital of France?"],
        }
    )
    eval_dataset = types.EvaluationDataset(eval_dataset_df=source_df)

    inference_result = client.evals.run_inference(
        agent=local_agent,
        src=eval_dataset,
    )

    assert isinstance(inference_result, types.EvaluationDataset)
    result_df = inference_result.eval_dataset_df
    assert result_df is not None
    assert "response" in result_df.columns
    # The response should be a non-empty string (actual model answer).
    response_val = result_df["response"].iloc[0]
    assert response_val is not None
    assert isinstance(response_val, str)
    assert len(response_val) > 0
def test_inference_with_completed_and_incomplete_agent_data(client):
    """Tests run_inference with a mix of completed and N+1 agent traces.

    Provides two eval_cases:
    - Row 0: completed trace (last event from agent) — BYOD, no inference.
    - Row 1: incomplete trace (last event from user) — N+1 inference.

    The completed row should return the existing agent response without
    making any API calls. The N+1 row should run inference normally.
    """
    agent = LlmAgent(
        name="mixed_agent",
        model="gemini-2.5-flash",
        instruction="You are a helpful assistant. Answer questions concisely.",
    )

    # Row 0: Completed trace — last event is from the agent.
    # The author matches the agent name, not "user", so the trace is
    # treated as already finished (BYOD path).
    completed_case = types.EvalCase(
        agent_data=types.evals.AgentData(
            turns=[
                types.evals.ConversationTurn(
                    turn_index=0,
                    events=[
                        types.evals.AgentEvent(
                            author="user",
                            content=genai_types.Content(
                                role="user",
                                parts=[genai_types.Part(text="What color is the sky?")],
                            ),
                        ),
                        types.evals.AgentEvent(
                            author="mixed_agent",
                            content=genai_types.Content(
                                role="model",
                                parts=[genai_types.Part(text="The sky is blue.")],
                            ),
                        ),
                    ],
                ),
            ],
        ),
    )

    # Row 1: N+1 trace — last event is from the user.
    # Turn 0 is a completed exchange; turn 1 ends with an unanswered
    # user query, so inference must run for this row.
    n_plus_1_case = types.EvalCase(
        agent_data=types.evals.AgentData(
            turns=[
                types.evals.ConversationTurn(
                    turn_index=0,
                    events=[
                        types.evals.AgentEvent(
                            author="user",
                            content=genai_types.Content(
                                role="user",
                                parts=[
                                    genai_types.Part(text="My favorite number is 7.")
                                ],
                            ),
                        ),
                        types.evals.AgentEvent(
                            author="mixed_agent",
                            content=genai_types.Content(
                                role="model",
                                parts=[genai_types.Part(text="Got it, 7!")],
                            ),
                        ),
                    ],
                ),
                types.evals.ConversationTurn(
                    turn_index=1,
                    events=[
                        types.evals.AgentEvent(
                            author="user",
                            content=genai_types.Content(
                                role="user",
                                parts=[
                                    genai_types.Part(text="What is my favorite number?")
                                ],
                            ),
                        ),
                    ],
                ),
            ],
        ),
    )

    eval_dataset = types.EvaluationDataset(eval_cases=[completed_case, n_plus_1_case])

    inference_result = client.evals.run_inference(
        agent=agent,
        src=eval_dataset,
    )
    assert isinstance(inference_result, types.EvaluationDataset)
    result_df = inference_result.eval_dataset_df
    assert result_df is not None
    assert len(result_df) == 2

    # Row 0 (completed trace): response should contain the existing
    # agent answer — "The sky is blue."
    row0_response = result_df["response"].iloc[0]
    assert row0_response is not None
    assert "blue" in row0_response.lower()

    # Row 1 (N+1 inference): response should be a non-empty string
    # from the model (actual inference).
    row1_response = result_df["response"].iloc[1]
    assert row1_response is not None
    assert isinstance(row1_response, str)
    assert len(row1_response) > 0
+ row1_response = result_df["response"].iloc[1] + assert row1_response is not None + assert isinstance(row1_response, str) + assert len(row1_response) > 0 + + pytestmark = pytest_helper.setup( file=__file__, globals_for_file=globals(), diff --git a/vertexai/_genai/_evals_common.py b/vertexai/_genai/_evals_common.py index c5cfb352e8..196e7c70b0 100644 --- a/vertexai/_genai/_evals_common.py +++ b/vertexai/_genai/_evals_common.py @@ -366,6 +366,60 @@ def _extract_prompt_from_agent_data( return last_event.content, all_events[:-1] +def _is_n_plus_1_inference( + agent_data: Union[types.evals.AgentData, dict[str, Any]], +) -> bool: + """Returns True if agent_data represents an N+1 inference case. + + An N+1 case means the trace is incomplete: N prior conversation turns + exist plus 1 final user query that the agent should respond to. This + is detected by checking whether the very last event across all turns + is authored by ``"user"``. + + Returns ``False`` for completed traces (last event from the agent), + empty traces, or invalid data. + """ + if isinstance(agent_data, dict): + try: + agent_data = types.evals.AgentData.model_validate(agent_data) + except Exception: # pylint: disable=broad-exception-caught + return False + if not isinstance(agent_data, types.evals.AgentData): + return False + if not agent_data.turns: + return False + all_events: list[types.evals.AgentEvent] = [] + for turn in agent_data.turns or []: + if turn.events: + all_events.extend(turn.events) + if not all_events: + return False + return all_events[-1].author == USER_AUTHOR + + +def _extract_response_from_completed_trace( + agent_data: types.evals.AgentData, +) -> list[dict[str, Any]]: + """Extracts all events from a completed agent trace as event dicts. + + For BYOD (bring-your-own-data) use cases where the agent trace is + already complete, this returns all events formatted as a list of + dicts compatible with ``_process_single_turn_agent_response``. 
def _extract_response_from_completed_trace(
    agent_data: types.evals.AgentData,
) -> list[dict[str, Any]]:
    """Extracts all events from a completed agent trace as event dicts.

    For BYOD (bring-your-own-data) use cases where the agent trace is
    already complete, this returns all events formatted as a list of
    dicts compatible with ``_process_single_turn_agent_response``. The
    last element is the final agent response; preceding elements become
    intermediate events.
    """
    flattened: list[dict[str, Any]] = []
    for turn in agent_data.turns or []:
        for event in turn.events or []:
            # Missing authors default to "agent" so downstream processing
            # never sees an event without an author.
            entry: dict[str, Any] = {"author": event.author or "agent"}
            if event.content:
                entry[CONTENT] = event.content.model_dump(exclude_none=True)
            flattened.append(entry)
    return flattened
+ f" Found: {prompt_dataset.columns.tolist()}" + ) max_workers = AGENT_MAX_WORKERS if agent_engine or agent else MAX_WORKERS with tqdm(total=len(prompt_dataset), desc=progress_desc) as pbar: @@ -591,21 +644,41 @@ def _execute_inference_concurrently( agent_data_obj = types.evals.AgentData.model_validate( agent_data_obj ) - last_user_content, _ = _extract_prompt_from_agent_data( - agent_data_obj - ) - contents = _evals_data_converters._get_content_text( - last_user_content - ) + if _is_n_plus_1_inference(agent_data_obj): + last_user_content, _ = _extract_prompt_from_agent_data( + agent_data_obj + ) + contents = _evals_data_converters._get_content_text( + last_user_content + ) + else: + logger.info( + "Row %s has a completed agent trace" + " (last event is not from user)." + " Skipping inference and using existing" + " agent response.", + index, + ) + responses[index] = _extract_response_from_completed_trace( + agent_data_obj + ) + pbar.update(1) + continue else: + if primary_prompt_column is None: + raise ValueError( + "Row has no agent_data and dataset has no" + " 'prompt', 'request', or 'starting_prompt'" + " column." + ) request_dict_or_raw_text = row[primary_prompt_column] contents = _extract_contents_for_inference( request_dict_or_raw_text ) except ValueError as e: error_message = ( - f"Failed to extract contents for prompt at index {index}: {e}. " - "Skipping prompt." + f"Failed to extract contents for prompt at index" + f" {index}: {e}. Skipping prompt." ) logger.error(error_message) responses[index] = {"error": error_message} @@ -2136,29 +2209,55 @@ def _execute_agent_run_with_retry( max_retries: int = 3, ) -> Union[list[dict[str, Any]], dict[str, Any]]: """Executes agent run over agent engine for a single prompt.""" - # TODO(b/507976585): Support agent_data history replay for Agent Engine - # sessions. Requires appending history events to the remote session via - # the Sessions API before calling stream_query. 
- if AGENT_DATA in row.index and row.get(AGENT_DATA) is not None: - raise NotImplementedError( - "Conversation history replay from agent_data is not yet supported" - " for remote Agent Engine inference. Use a local ADK agent" - " (LlmAgent) instead, or provide a DataFrame with a 'prompt'" - " column." - ) try: - session_inputs = _get_session_inputs(row) - user_id = session_inputs.user_id - session_state = session_inputs.state + if "session_inputs" in row.index and row.get("session_inputs") is not None: + session_inputs = _get_session_inputs(row) + user_id = session_inputs.user_id or str(uuid.uuid4()) + session_state = session_inputs.state + else: + user_id = str(uuid.uuid4()) + session_state = None + except KeyError as e: + return {"error": f"Failed to get all required agent engine inputs: {e}"} + + try: session_id = _create_agent_engine_session( agent_engine=agent_engine, user_id=user_id, session_state=session_state, ) - except KeyError as e: - return {"error": f"Failed to get all required agent engine inputs: {e}"} - except Exception as e: - return {"error": f"Failed to create a new session : {e}"} + except Exception as e: # pylint: disable=broad-exception-caught + return {"error": f"Failed to create a new session: {e}"} + + # Pre-populate remote session with agent_data history (N+1 case only). 
+ if ( + AGENT_DATA in row.index + and row.get(AGENT_DATA) is not None + and _is_n_plus_1_inference(row[AGENT_DATA]) + ): + agent_data_obj = row[AGENT_DATA] + if isinstance(agent_data_obj, dict): + agent_data_obj = types.evals.AgentData.model_validate(agent_data_obj) + _, history_events = _extract_prompt_from_agent_data(agent_data_obj) + + if agent_engine.api_resource is None: + return {"error": "agent_engine.api_resource is None."} + if agent_engine.api_client is None: + return {"error": "agent_engine.api_client is None."} + session_name = f"{agent_engine.api_resource.name}/sessions/{session_id}" + base_ts = datetime.datetime(2000, 1, 1, tzinfo=datetime.timezone.utc) + for i, ag_event in enumerate(history_events): + agent_engine.api_client.sessions.events.append( + name=session_name, + author=ag_event.author or "user", + invocation_id="history", + timestamp=base_ts + datetime.timedelta(seconds=i), + config=types.AppendAgentEngineSessionEventConfig( + content=ag_event.content, + ), + ) + + # stream_query retry loop (shared for both agent_data and prompt paths). for attempt in range(max_retries): try: responses = [] @@ -2184,12 +2283,11 @@ def _execute_agent_run_with_retry( time.sleep(2**attempt) except Exception as e: # pylint: disable=broad-exception-caught logger.error( - "Unexpected error during generate_content on attempt %d/%d: %s", + "Unexpected error during agent engine run on attempt %d/%d: %s", attempt + 1, max_retries, e, ) - if attempt == max_retries - 1: return {"error": f"Failed after retries: {e}"} time.sleep(1) @@ -2232,104 +2330,48 @@ async def _execute_local_agent_run_with_retry_async( logger.error("Multi-turn agent run with user simulation failed: %s", e) return {"error": f"Multi-turn agent run with user simulation failed: {e}"} - # Agent data with conversation history — pre-populate the session with - # prior turns so the agent sees them as context, then send the last user - # message as the new query. 
- if AGENT_DATA in row.index and row.get(AGENT_DATA) is not None: + if "session_inputs" in row.index and row.get("session_inputs") is not None: + session_inputs = _get_session_inputs(row) + user_id = session_inputs.user_id or str(uuid.uuid4()) + app_name = session_inputs.app_name or "local agent run" + else: + user_id = str(uuid.uuid4()) + app_name = "local agent run" + session_id = str(uuid.uuid4()) + + session_service = InMemorySessionService() + await session_service.create_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + # Pre-populate session with agent_data history (N+1 case only). + if ( + AGENT_DATA in row.index + and row.get(AGENT_DATA) is not None + and _is_n_plus_1_inference(row[AGENT_DATA]) + ): from google.adk.events.event import Event as AdkEvent agent_data_obj = row[AGENT_DATA] if isinstance(agent_data_obj, dict): agent_data_obj = types.evals.AgentData.model_validate(agent_data_obj) - if isinstance(agent_data_obj, types.evals.AgentData) and agent_data_obj.turns: - try: - last_user_content, history_events = _extract_prompt_from_agent_data( - agent_data_obj - ) - except ValueError as e: - return {"error": f"Invalid agent_data for inference: {e}"} - - user_id = str(uuid.uuid4()) - session_id = str(uuid.uuid4()) - app_name = "local agent run" - if "session_inputs" in row.index and row.get("session_inputs") is not None: - session_inputs = _get_session_inputs(row) - user_id = session_inputs.user_id or user_id - app_name = session_inputs.app_name or app_name - session_service = InMemorySessionService() - await session_service.create_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - - # Pre-populate session with history events so the agent sees - # prior conversation context. 
- internal_session = session_service.sessions[app_name][user_id][session_id] - for ag_event in history_events: - adk_event = AdkEvent( - author=ag_event.author or "user", - content=ag_event.content, - invocation_id="history", - ) - internal_session.events.append(adk_event) - - agent_runner = Runner( - agent=agent, app_name=app_name, session_service=session_service + _, history_events = _extract_prompt_from_agent_data(agent_data_obj) + internal_session = session_service.sessions[app_name][user_id][session_id] + for ag_event in history_events: + adk_event = AdkEvent( + author=ag_event.author or "user", + content=ag_event.content, + invocation_id="history", ) + internal_session.events.append(adk_event) - with _temp_logger_level("google_genai.types", logging.ERROR): - for attempt in range(max_retries): - try: - events = [] - async for event in agent_runner.run_async( - user_id=user_id, - session_id=session_id, - new_message=last_user_content, - ): - if event: - event = event.model_dump(exclude_none=True) - if event and CONTENT in event and PARTS in event[CONTENT]: - events.append(event) - return events - except api_exceptions.ResourceExhausted as e: - logger.warning( - "Resource Exhausted error on attempt %d/%d: %s." 
- " Retrying in %s seconds...", - attempt + 1, - max_retries, - e, - 2**attempt, - ) - if attempt == max_retries - 1: - return {"error": f"Resource exhausted after retries: {e}"} - await asyncio.sleep(2**attempt) - except Exception as e: # pylint: disable=broad-exception-caught - logger.error( - "Unexpected error during agent run on attempt %d/%d: %s", - attempt + 1, - max_retries, - e, - ) - if attempt == max_retries - 1: - return {"error": f"Failed after retries: {e}"} - await asyncio.sleep(1) - return { - "error": ( - f"Failed to get agent run results after {max_retries} retries" - ) - } - - session_inputs = _get_session_inputs(row) - user_id = session_inputs.user_id or str(uuid.uuid4()) - session_id = str(uuid.uuid4()) - app_name = session_inputs.app_name or "local agent run" - # TODO: Enable user to set up session service. - session_service = InMemorySessionService() - await session_service.create_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) agent_runner = Runner( agent=agent, app_name=app_name, session_service=session_service ) + new_message_content = genai_types.Content( + role=USER_AUTHOR, + parts=[genai_types.Part(text=contents)], + ) # Avoid printing out warning from agent_runner.run() # WARNING:google_genai.types:Warning: there are non-text parts in the # response: ['function_call'], returning concatenated text result from @@ -2340,10 +2382,6 @@ async def _execute_local_agent_run_with_retry_async( for attempt in range(max_retries): try: events = [] - new_message_content = genai_types.Content( - role=USER_AUTHOR, - parts=[genai_types.Part(text=contents)], - ) async for event in agent_runner.run_async( user_id=user_id, session_id=session_id, @@ -2356,8 +2394,8 @@ async def _execute_local_agent_run_with_retry_async( return events except api_exceptions.ResourceExhausted as e: logger.warning( - "Resource Exhausted error on attempt %d/%d: %s. Retrying in %s" - " seconds...", + "Resource Exhausted error on attempt %d/%d: %s. 
Retrying" + " in %s seconds...", attempt + 1, max_retries, e, @@ -2368,12 +2406,11 @@ async def _execute_local_agent_run_with_retry_async( await asyncio.sleep(2**attempt) except Exception as e: # pylint: disable=broad-exception-caught logger.error( - "Unexpected error during generate_content on attempt %d/%d: %s", + "Unexpected error during agent run on attempt %d/%d: %s", attempt + 1, max_retries, e, ) - if attempt == max_retries - 1: return {"error": f"Failed after retries: {e}"} await asyncio.sleep(1)