Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions src/google/adk/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,13 @@ async def run_async(
if ctx.end_invocation:
return

async with Aclosing(self._run_async_impl(ctx)) as agen:
async for event in agen:
yield event
try:
async with Aclosing(self._run_async_impl(ctx)) as agen:
async for event in agen:
yield event
except Exception as e:
await self._handle_agent_error_callback(ctx, e)
raise

if ctx.end_invocation:
return
Expand Down Expand Up @@ -323,9 +327,13 @@ async def run_live(
if ctx.end_invocation:
return

async with Aclosing(self._run_live_impl(ctx)) as agen:
async for event in agen:
yield event
try:
async with Aclosing(self._run_live_impl(ctx)) as agen:
async for event in agen:
yield event
except Exception as e:
await self._handle_agent_error_callback(ctx, e)
raise

if event := await self._handle_after_agent_callback(ctx):
yield event
Expand Down Expand Up @@ -545,6 +553,27 @@ async def _handle_after_agent_callback(
)
return None

async def _handle_agent_error_callback(
    self,
    invocation_context: InvocationContext,
    error: Exception,
) -> None:
    """Notifies plugins that an exception escaped this agent's execution.

    The invocation context is wrapped in a ``CallbackContext`` and handed,
    together with this agent and the error, to the plugin manager's
    ``run_on_agent_error_callback``.

    This method only notifies. It does not re-raise ``error`` and returns
    nothing, so the caller is responsible for re-raising the exception
    once this coroutine completes.

    Args:
      invocation_context: The invocation context for this agent.
      error: The exception that escaped agent execution.
    """
    # Build the callback context before touching the plugin manager so the
    # evaluation order matches a straight-line call.
    agent_callback_context = CallbackContext(invocation_context)
    notify_plugins = (
        invocation_context.plugin_manager.run_on_agent_error_callback
    )
    await notify_plugins(
        agent=self,
        callback_context=agent_callback_context,
        error=error,
    )

@override
def model_post_init(self, __context: Any) -> None:
self.__set_parent_agent_for_sub_agents()
Expand Down
38 changes: 38 additions & 0 deletions src/google/adk/plugins/base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,41 @@ async def on_tool_error_callback(
allows the original error to be raised.
"""
pass

async def on_agent_error_callback(
    self,
    *,
    agent: BaseAgent,
    callback_context: CallbackContext,
    error: Exception,
) -> None:
    """Hook invoked when an unhandled exception escapes agent execution.

    The default implementation does nothing; override it to observe agent
    failures (for example to log or record them).

    This hook is for notification only. The exception is always re-raised
    once all registered plugins have been notified, so implementations
    should NOT attempt to suppress it.

    Args:
      agent: The agent instance that encountered the error.
      callback_context: The callback context for the agent invocation.
      error: The exception that was raised during agent execution.
    """
    return None

async def on_run_error_callback(
    self,
    *,
    invocation_context: InvocationContext,
    error: Exception,
) -> None:
    """Hook invoked when an unhandled exception escapes runner execution.

    The default implementation does nothing; override it to observe
    failures of a whole invocation.

    This hook is for notification only. The exception is always re-raised
    once all registered plugins have been notified, so implementations
    should NOT attempt to suppress it.

    Args:
      invocation_context: The context for the entire invocation.
      error: The exception that was raised during runner execution.
    """
    return None
103 changes: 103 additions & 0 deletions src/google/adk/plugins/bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import logging
import mimetypes
import os
import traceback as traceback_module

# Enable gRPC fork support so child processes created via os.fork()
# can safely create new gRPC channels. Must be set before grpc's
Expand Down Expand Up @@ -1870,8 +1871,15 @@ def _get_events_schema() -> list[bigquery.SchemaField]:
"AGENT_COMPLETED": [
"CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms",
],
"AGENT_ERROR": [
"CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms",
"JSON_VALUE(content, '$.error_traceback') AS error_traceback",
],
"INVOCATION_STARTING": [],
"INVOCATION_COMPLETED": [],
"INVOCATION_ERROR": [
"JSON_VALUE(content, '$.error_traceback') AS error_traceback",
],
"STATE_DELTA": [
"JSON_QUERY(attributes, '$.state_delta') AS state_delta",
],
Expand Down Expand Up @@ -3505,3 +3513,98 @@ async def on_tool_error_callback(
parent_span_id_override=parent_span_id,
),
)

@_safe_callback
async def on_agent_error_callback(
    self,
    *,
    agent: Any,
    callback_context: CallbackContext,
    error: Exception,
) -> None:
    """Logs an AGENT_ERROR event when an agent fails with an exception.

    Pops the top span from TraceManager (expected to be the failing
    agent's span), formats the exception's traceback, truncates it to
    ``config.max_content_length`` characters when that limit is positive,
    and emits the event with the traceback stored under the
    ``error_traceback`` content key.

    Args:
      agent: The agent instance that failed.
      callback_context: The callback context.
      error: The exception that escaped agent execution.
    """
    # Pop before asking for the current span: the span left on top after
    # the pop is what gets recorded as the parent.
    failed_span_id, elapsed = TraceManager.pop_span()
    enclosing_span_id, _ = TraceManager.get_current_span_and_parent()

    tb_parts = traceback_module.format_exception(
        type(error), error, error.__traceback__
    )
    tb_text = "".join(tb_parts)

    limit = self.config.max_content_length
    # A non-positive limit leaves the traceback untouched.
    if limit > 0 and len(tb_text) > limit:
        tb_text = f"{tb_text[:limit]}... [truncated]"

    error_event = EventData(
        status="ERROR",
        error_message=str(error),
        latency_ms=elapsed,
        span_id_override=failed_span_id,
        parent_span_id_override=enclosing_span_id,
    )
    await self._log_event(
        "AGENT_ERROR",
        callback_context,
        event_data=error_event,
        raw_content={"error_traceback": tb_text},
    )

@_safe_callback
async def on_run_error_callback(
    self,
    *,
    invocation_context: "InvocationContext",
    error: Exception,
) -> None:
    """Logs an INVOCATION_ERROR event when a runner execution fails.

    Besides emitting the event (with the formatted, possibly truncated
    traceback stored under the ``error_traceback`` content key), this
    performs the cleanup that after_run_callback would normally do:
    clearing the trace stack, resetting the invocation-scoped context
    variables and flushing pending writes.

    Args:
      invocation_context: The context of the current invocation.
      error: The exception that escaped runner execution.
    """
    try:
        run_callback_ctx = CallbackContext(invocation_context)
        run_trace_id = TraceManager.get_trace_id(run_callback_ctx)

        # Pop first; the span remaining on top afterwards is recorded as
        # the parent of the popped one.
        failed_span_id, elapsed = TraceManager.pop_span()
        enclosing_span_id = TraceManager.get_current_span_id()

        tb_parts = traceback_module.format_exception(
            type(error), error, error.__traceback__
        )
        tb_text = "".join(tb_parts)

        limit = self.config.max_content_length
        # A non-positive limit leaves the traceback untouched.
        if limit > 0 and len(tb_text) > limit:
            tb_text = f"{tb_text[:limit]}... [truncated]"

        error_event = EventData(
            trace_id_override=run_trace_id,
            status="ERROR",
            error_message=str(error),
            latency_ms=elapsed,
            span_id_override=failed_span_id,
            parent_span_id_override=enclosing_span_id,
        )
        await self._log_event(
            "INVOCATION_ERROR",
            run_callback_ctx,
            event_data=error_event,
            raw_content={"error_traceback": tb_text},
        )
    finally:
        # Runs on every exit path, including when logging above raises,
        # so invocation-scoped state is always reset.
        TraceManager.clear_stack()
        _active_invocation_id_ctx.set(None)
        _root_agent_name_ctx.set(None)
        await self.flush()
57 changes: 57 additions & 0 deletions src/google/adk/plugins/plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
"after_model_callback",
"on_tool_error_callback",
"on_model_error_callback",
"on_agent_error_callback",
"on_run_error_callback",
]

logger = logging.getLogger("google_adk." + __name__)
Expand Down Expand Up @@ -306,6 +308,61 @@ async def _run_callbacks(

return None

async def run_on_agent_error_callback(
    self,
    *,
    agent: BaseAgent,
    callback_context: CallbackContext,
    error: Exception,
) -> None:
    """Runs the `on_agent_error_callback` for all plugins.

    Delegates to ``_run_notification_callbacks``, forwarding the agent,
    callback context and error as keyword arguments.

    Args:
      agent: The agent whose execution raised.
      callback_context: The callback context for the agent invocation.
      error: The exception being reported to the plugins.
    """
    notification_kwargs = {
        "agent": agent,
        "callback_context": callback_context,
        "error": error,
    }
    await self._run_notification_callbacks(
        "on_agent_error_callback", **notification_kwargs
    )

async def run_on_run_error_callback(
    self,
    *,
    invocation_context: InvocationContext,
    error: Exception,
) -> None:
    """Runs the `on_run_error_callback` for all plugins.

    Delegates to ``_run_notification_callbacks``, forwarding the
    invocation context and error as keyword arguments.

    Args:
      invocation_context: The context for the entire invocation.
      error: The exception being reported to the plugins.
    """
    notification_kwargs = {
        "invocation_context": invocation_context,
        "error": error,
    }
    await self._run_notification_callbacks(
        "on_run_error_callback", **notification_kwargs
    )

async def _run_notification_callbacks(
    self, callback_name: PluginCallbackName, **kwargs: Any
) -> None:
    """Invokes a notification-only callback on every registered plugin.

    Unlike ``_run_callbacks``, this is best-effort. Plugins are visited in
    registration order, return values are ignored, and an ``Exception``
    raised by one plugin's callback is logged (with traceback) and then
    swallowed so that the remaining plugins are still notified.

    Args:
      callback_name: The name of the callback method to execute.
      **kwargs: Keyword arguments to be passed to the callback method.
    """
    for registered_plugin in self.plugins:
        notify = getattr(registered_plugin, callback_name)
        try:
            await notify(**kwargs)
        except Exception as callback_error:
            # Deliberately not re-raised: one misbehaving plugin must not
            # prevent the others from being notified.
            logger.error(
                "Error in plugin '%s' during '%s' callback: %s",
                registered_plugin.name,
                callback_name,
                callback_error,
                exc_info=True,
            )

async def close(self) -> None:
"""Calls the close method on all registered plugins concurrently.

Expand Down
Loading