diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 7cbf931ca9..c56ba30461 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -2162,11 +2162,13 @@ async def on_user_message_callback( Args: invocation_context: The context of the current invocation. user_message: The message content received from the user. + **kwargs: Additional keyword arguments (e.g., custom attributes). """ await self._log_event( "USER_MESSAGE_RECEIVED", CallbackContext(invocation_context), raw_content=user_message, + **kwargs, ) async def on_state_change_callback( @@ -2197,10 +2199,13 @@ async def before_run_callback( Args: invocation_context: The context of the current invocation. + **kwargs: Additional keyword arguments (e.g., custom attributes). """ await self._ensure_started() await self._log_event( - "INVOCATION_STARTING", CallbackContext(invocation_context) + "INVOCATION_STARTING", + CallbackContext(invocation_context), + **kwargs, ) async def after_run_callback( @@ -2210,9 +2215,12 @@ async def after_run_callback( Args: invocation_context: The context of the current invocation. + **kwargs: Additional keyword arguments (e.g., custom attributes). """ await self._log_event( - "INVOCATION_COMPLETED", CallbackContext(invocation_context) + "INVOCATION_COMPLETED", + CallbackContext(invocation_context), + **kwargs, ) # Ensure all logs are flushed before the agent returns await self.flush() @@ -2225,6 +2233,7 @@ async def before_agent_callback( Args: agent: The agent instance. callback_context: The callback context. + **kwargs: Additional keyword arguments (e.g., custom attributes). """ TraceManager.init_trace(callback_context) TraceManager.push_span(callback_context, "agent") @@ -2232,6 +2241,7 @@ async def before_agent_callback( "AGENT_STARTING", callback_context, raw_content=getattr(agent, "instruction", ""), + **kwargs, ) async def after_agent_callback( @@ -2242,6 +2252,7 @@ async def after_agent_callback( Args: agent: The agent instance. callback_context: The callback context. + **kwargs: Additional keyword arguments (e.g., custom attributes). """ span_id, duration = TraceManager.pop_span() # When popping, the current stack now points to parent. @@ -2255,6 +2266,7 @@ async def after_agent_callback( latency_ms=duration, span_id_override=span_id, parent_span_id_override=parent_span_id, + **kwargs, ) async def before_model_callback( @@ -2436,7 +2448,7 @@ async def on_model_error_callback( Args: callback_context: The callback context. error: The exception that occurred. - **kwargs: Additional arguments. + **kwargs: Additional keyword arguments (e.g., custom attributes). """ span_id, duration = TraceManager.pop_span() parent_span_id, _ = TraceManager.get_current_span_and_parent() @@ -2447,6 +2459,7 @@ async def on_model_error_callback( latency_ms=duration, span_id_override=span_id, parent_span_id_override=parent_span_id, + **kwargs, ) async def before_tool_callback( @@ -2463,6 +2476,7 @@ async def before_tool_callback( tool: The tool being executed. tool_args: The arguments passed to the tool. tool_context: The tool context. + **kwargs: Additional keyword arguments (e.g., custom attributes). """ args_truncated, is_truncated = _recursive_smart_truncate( tool_args, self.config.max_content_length @@ -2474,6 +2488,7 @@ async def before_tool_callback( tool_context, raw_content=content_dict, is_truncated=is_truncated, + **kwargs, ) async def after_tool_callback( @@ -2492,6 +2507,7 @@ async def after_tool_callback( tool_args: The arguments passed to the tool. tool_context: The tool context. result: The response from the tool. + **kwargs: Additional keyword arguments (e.g., custom attributes). """ resp_truncated, is_truncated = _recursive_smart_truncate( result, self.config.max_content_length @@ -2508,6 +2524,7 @@ async def after_tool_callback( latency_ms=duration, span_id_override=span_id, parent_span_id_override=parent_span_id, + **kwargs, ) if tool_context.actions.state_delta: @@ -2533,7 +2550,7 @@ async def on_tool_error_callback( tool_args: The arguments passed to the tool. tool_context: The tool context. error: The exception that occurred. - **kwargs: Additional arguments. + **kwargs: Additional keyword arguments (e.g., custom attributes). """ args_truncated, is_truncated = _recursive_smart_truncate( tool_args, self.config.max_content_length @@ -2547,4 +2564,5 @@ async def on_tool_error_callback( error_message=str(error), is_truncated=is_truncated, latency_ms=duration, + **kwargs, )