diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 7cbf931ca9..4ad20f3318 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -72,6 +72,27 @@ ) +def _safe_callback(func): + """Decorator that catches and logs exceptions in plugin callbacks. + + Prevents plugin errors from propagating to the runner and crashing + the agent run. All callback exceptions are logged and swallowed. + """ + + @functools.wraps(func) + async def wrapper(self, **kwargs): + try: + return await func(self, **kwargs) + except Exception: + logger.exception( + "BigQuery analytics plugin error in %s; skipping.", + func.__name__, + ) + return None + + return wrapper + + # gRPC Error Codes _GRPC_DEADLINE_EXCEEDED = 4 _GRPC_INTERNAL = 13 @@ -423,31 +444,44 @@ class BigQueryLoggerConfig: _root_agent_name_ctx = contextvars.ContextVar( "_bq_analytics_root_agent_name", default=None ) -_span_stack_ctx: contextvars.ContextVar[list[trace.Span]] = ( - contextvars.ContextVar("_bq_analytics_span_stack", default=None) -) -_span_token_stack_ctx: contextvars.ContextVar[list[trace.Token]] = ( - contextvars.ContextVar("_bq_analytics_span_token_stack", default=None) -) -_span_first_token_times_ctx: contextvars.ContextVar[dict[str, float]] = ( - contextvars.ContextVar("_bq_analytics_span_first_token_times", default=None) -) -_span_map_ctx: contextvars.ContextVar[dict[str, trace.Span]] = ( - contextvars.ContextVar("_bq_analytics_span_map", default=None) -) -_span_id_stack_ctx: contextvars.ContextVar[list[str]] = contextvars.ContextVar( - "_bq_analytics_span_id_stack", default=None -) -_span_start_time_ctx: contextvars.ContextVar[dict[str, int]] = ( - contextvars.ContextVar("_bq_analytics_span_start_time", default=None) -) -_span_ownership_stack_ctx: contextvars.ContextVar[list[bool]] = ( - contextvars.ContextVar("_bq_analytics_span_ownership_stack", 
default=None) + + +@dataclass +class _SpanRecord: + """A single record on the unified span stack. + + Consolidates span, token, id, ownership, and timing into one object + so all stacks stay in sync by construction. + """ + + span: trace.Span + token: Any # opentelemetry context token + span_id: str + owns_span: bool + start_time_ns: int + first_token_time: Optional[float] = None + + +_span_records_ctx: contextvars.ContextVar[list[_SpanRecord]] = ( + contextvars.ContextVar("_bq_analytics_span_records", default=None) ) class TraceManager: - """Manages OpenTelemetry-style trace and span context using contextvars.""" + """Manages OpenTelemetry-style trace and span context using contextvars. + + Uses a single stack of _SpanRecord objects to keep span, token, ID, + ownership, and timing in sync by construction. + """ + + @staticmethod + def _get_records() -> list[_SpanRecord]: + """Returns the current records stack, initializing if needed.""" + records = _span_records_ctx.get() + if records is None: + records = [] + _span_records_ctx.set(records) + return records @staticmethod def init_trace(callback_context: CallbackContext) -> None: @@ -458,29 +492,19 @@ def init_trace(callback_context: CallbackContext) -> None: except (AttributeError, ValueError): pass - if _span_first_token_times_ctx.get() is None: - _span_first_token_times_ctx.set({}) - - if _span_map_ctx.get() is None: - _span_map_ctx.set({}) - - if _span_start_time_ctx.get() is None: - _span_start_time_ctx.set({}) - - if _span_ownership_stack_ctx.get() is None: - _span_ownership_stack_ctx.set([]) + # Ensure records stack is initialized + TraceManager._get_records() @staticmethod def get_trace_id(callback_context: CallbackContext) -> Optional[str]: """Gets the trace ID from the current span or invocation_id.""" - # Prefer internal stack if available - stack = _span_stack_ctx.get() - if stack: - current_span = stack[-1] + records = _span_records_ctx.get() + if records: + current_span = records[-1].span if 
current_span.get_span_context().is_valid: return format(current_span.get_span_context().trace_id, "032x") - # Fallback to OTel context to satisfy "Trace Context Extraction" requirement + # Fallback to OTel context current_span = trace.get_current_span() if current_span.get_span_context().is_valid: return format(current_span.get_span_context().trace_id, "032x") @@ -496,45 +520,27 @@ def push_span( If OTel is not configured (returning non-recording spans), a UUID fallback is generated to ensure span_id and parent_span_id are populated in logs. """ - # Ensure init_trace logic (root agent name) runs if needed TraceManager.init_trace(callback_context) span = tracer.start_span(span_name) token = context.attach(trace.set_span_in_context(span)) - stack = _span_stack_ctx.get() or [] - new_stack = list(stack) + [span] - _span_stack_ctx.set(new_stack) - - token_stack = _span_token_stack_ctx.get() or [] - new_token_stack = list(token_stack) + [token] - _span_token_stack_ctx.set(new_token_stack) - if span.get_span_context().is_valid: span_id_str = format(span.get_span_context().span_id, "016x") else: - # Fallback: Generate a UUID-based ID if OTel span is invalid (NoOp) - # using 32-char hex to avoid collision, treated as string in BQ. 
span_id_str = uuid.uuid4().hex - id_stack = _span_id_stack_ctx.get() or [] - new_id_stack = list(id_stack) + [span_id_str] - _span_id_stack_ctx.set(new_id_stack) - - span_map = _span_map_ctx.get() or {} - new_span_map = span_map.copy() - new_span_map[span_id_str] = span - _span_map_ctx.set(new_span_map) - - # Record start time manually for fallback support (NoOpSpan lacks start_time) - start_times = _span_start_time_ctx.get() or {} - new_start_times = start_times.copy() - new_start_times[span_id_str] = time.time_ns() - _span_start_time_ctx.set(new_start_times) + record = _SpanRecord( + span=span, + token=token, + span_id=span_id_str, + owns_span=True, + start_time_ns=time.time_ns(), + ) - ownership_stack = _span_ownership_stack_ctx.get() or [] - new_ownership_stack = list(ownership_stack) + [True] - _span_ownership_stack_ctx.set(new_ownership_stack) + records = TraceManager._get_records() + new_records = list(records) + [record] + _span_records_ctx.set(new_records) return span_id_str @@ -545,137 +551,75 @@ def attach_current_span( """Attaches the current OTEL span to the stack without owning it.""" TraceManager.init_trace(callback_context) - # Get current span but don't start a new one span = trace.get_current_span() - # We still need to attach it to context to keep stacks symmetric with token token = context.attach(trace.set_span_in_context(span)) - stack = _span_stack_ctx.get() or [] - new_stack = list(stack) + [span] - _span_stack_ctx.set(new_stack) - - token_stack = _span_token_stack_ctx.get() or [] - new_token_stack = list(token_stack) + [token] - _span_token_stack_ctx.set(new_token_stack) - if span.get_span_context().is_valid: span_id_str = format(span.get_span_context().span_id, "016x") else: - # Fallback: Generate a UUID-based ID if OTel span is invalid (NoOp) span_id_str = uuid.uuid4().hex - id_stack = _span_id_stack_ctx.get() or [] - new_id_stack = list(id_stack) + [span_id_str] - _span_id_stack_ctx.set(new_id_stack) - - span_map = _span_map_ctx.get() or 
{} - new_span_map = span_map.copy() - new_span_map[span_id_str] = span - _span_map_ctx.set(new_span_map) + record = _SpanRecord( + span=span, + token=token, + span_id=span_id_str, + owns_span=False, + start_time_ns=time.time_ns(), + ) - ownership_stack = _span_ownership_stack_ctx.get() or [] - new_ownership_stack = list(ownership_stack) + [False] - _span_ownership_stack_ctx.set(new_ownership_stack) + records = TraceManager._get_records() + new_records = list(records) + [record] + _span_records_ctx.set(new_records) return span_id_str @staticmethod def pop_span() -> tuple[Optional[str], Optional[int]]: """Ends the current span and pops it from the stack.""" - stack = _span_stack_ctx.get() - token_stack = _span_token_stack_ctx.get() - - if not stack or not token_stack: + records = _span_records_ctx.get() + if not records: return None, None - new_stack = list(stack) - new_token_stack = list(token_stack) + new_records = list(records) + record = new_records.pop() + _span_records_ctx.set(new_records) - span = new_stack.pop() - token = new_token_stack.pop() + # Calculate duration + duration_ms = None + otel_start = getattr(record.span, "start_time", None) + if isinstance(otel_start, (int, float)) and otel_start: + duration_ms = int((time.time_ns() - otel_start) / 1_000_000) + else: + duration_ms = int((time.time_ns() - record.start_time_ns) / 1_000_000) - _span_stack_ctx.set(new_stack) - _span_token_stack_ctx.set(new_token_stack) + if record.owns_span: + record.span.end() - # Pop from ID stack regarding fallback support - id_stack = _span_id_stack_ctx.get() - if id_stack: - new_id_stack = list(id_stack) - span_id = new_id_stack.pop() - _span_id_stack_ctx.set(new_id_stack) - else: - # Should not happen if stacks are in sync, but robust fallback: - if span.get_span_context().is_valid: - span_id = format(span.get_span_context().span_id, "016x") - else: - span_id = "unknown-id" + context.detach(record.token) - duration_ms = None - # Try getting start time from OTel span first, 
then fallback to manual tracking - if hasattr(span, "start_time") and span.start_time: - duration_ms = int((time.time_ns() - span.start_time) / 1_000_000) - else: - start_times = _span_start_time_ctx.get() - if start_times and span_id in start_times: - start_ns = start_times[span_id] - duration_ms = int((time.time_ns() - start_ns) / 1_000_000) - - should_end = True - ownership_stack = _span_ownership_stack_ctx.get() - if ownership_stack: - new_ownership_stack = list(ownership_stack) - should_end = new_ownership_stack.pop() - _span_ownership_stack_ctx.set(new_ownership_stack) - - if should_end: - span.end() - - context.detach(token) - - first_tokens = _span_first_token_times_ctx.get() - if first_tokens: - # Copy to modify - new_first_tokens = first_tokens.copy() - new_first_tokens.pop(span_id, None) - _span_first_token_times_ctx.set(new_first_tokens) - - span_map = _span_map_ctx.get() - if span_map: - new_span_map = span_map.copy() - new_span_map.pop(span_id, None) - _span_map_ctx.set(new_span_map) - - start_times = _span_start_time_ctx.get() - if start_times: - new_start_times = start_times.copy() - new_start_times.pop(span_id, None) - _span_start_time_ctx.set(new_start_times) - - return span_id, duration_ms + return record.span_id, duration_ms @staticmethod def get_current_span_and_parent() -> tuple[Optional[str], Optional[str]]: - """Gets current span_id and parent span_id from OTEL context or fallback stack.""" - # Use internal ID stack for robust resolution (handling both OTel and fallback IDs) - id_stack = _span_id_stack_ctx.get() - if id_stack: - span_id = id_stack[-1] - parent_id = None - # Walk backwards to find a different span_id for parent - for i in range(len(id_stack) - 2, -1, -1): - if id_stack[i] != span_id: - parent_id = id_stack[i] - break - return span_id, parent_id + """Gets current span_id and parent span_id.""" + records = _span_records_ctx.get() + if not records: + return None, None - return None, None + span_id = records[-1].span_id + 
parent_id = None + for i in range(len(records) - 2, -1, -1): + if records[i].span_id != span_id: + parent_id = records[i].span_id + break + return span_id, parent_id @staticmethod def get_current_span_id() -> Optional[str]: - """Gets current span_id from OTEL context or fallback stack.""" - id_stack = _span_id_stack_ctx.get() - if id_stack: - return id_stack[-1] + """Gets current span_id.""" + records = _span_records_ctx.get() + if records: + return records[-1].span_id return None @staticmethod @@ -685,41 +629,43 @@ def get_root_agent_name() -> Optional[str]: @staticmethod def get_start_time(span_id: str) -> Optional[float]: """Gets start time of a span by ID.""" - # Try OTel Object first - span_map = _span_map_ctx.get() - if span_map: - span = span_map.get(span_id) - if ( - span - and span.get_span_context().is_valid - and hasattr(span, "start_time") - ): - return span.start_time / 1_000_000_000.0 - - # Fallback to manual start time - start_times = _span_start_time_ctx.get() - if start_times and span_id in start_times: - return start_times[span_id] / 1_000_000_000.0 - + records = _span_records_ctx.get() + if records: + for record in reversed(records): + if record.span_id == span_id: + # Try OTel span start_time first + otel_start = getattr(record.span, "start_time", None) + if ( + record.span.get_span_context().is_valid + and isinstance(otel_start, (int, float)) + and otel_start + ): + return otel_start / 1_000_000_000.0 + return record.start_time_ns / 1_000_000_000.0 return None @staticmethod def record_first_token(span_id: str) -> bool: """Records the current time as first token time if not already recorded.""" - first_tokens = _span_first_token_times_ctx.get() - - if span_id not in first_tokens: - new_first_tokens = first_tokens.copy() - new_first_tokens[span_id] = time.time() - _span_first_token_times_ctx.set(new_first_tokens) - return True + records = _span_records_ctx.get() + if records: + for record in reversed(records): + if record.span_id == span_id: + if 
record.first_token_time is None: + record.first_token_time = time.time() + return True + return False return False @staticmethod def get_first_token_time(span_id: str) -> Optional[float]: """Gets the recorded first token time.""" - first_tokens = _span_first_token_times_ctx.get() - return first_tokens.get(span_id) if first_tokens else None + records = _span_records_ctx.get() + if records: + for record in reversed(records): + if record.span_id == span_id: + return record.first_token_time + return None # ============================================================================== @@ -1566,6 +1512,22 @@ class _LoopState: batch_processor: BatchProcessor +@dataclass +class EventData: + """Typed container for structured fields passed to _log_event.""" + + span_id_override: Optional[str] = None + parent_span_id_override: Optional[str] = None + latency_ms: Optional[int] = None + time_to_first_token_ms: Optional[int] = None + model: Optional[str] = None + model_version: Optional[str] = None + usage_metadata: Any = None + status: str = "OK" + error_message: Optional[str] = None + extra_attributes: dict[str, Any] = field(default_factory=dict) + + class BigQueryAgentAnalyticsPlugin(BasePlugin): """BigQuery Agent Analytics Plugin (v2.0 using Write API). @@ -1619,35 +1581,23 @@ def __init__( self._schema = None self.arrow_schema = None - # API Compatibility: These attributes are statically defined as None to mask the - # dynamic properties from static analysis tools (preventing "breaking changes"), - # while __getattribute__ intercepts instance access to route to the logic. - batch_processor = None - write_client = None - write_stream = None - - def __getattribute__(self, name: str) -> Any: - """Intercepts attribute access to support API masking. - - Args: - name: The name of the attribute being accessed. - - Returns: - The value of the attribute. 
- """ - if name == "batch_processor": - return self._batch_processor_prop - if name == "write_client": - return self._write_client_prop - if name == "write_stream": - return self._write_stream_prop - return super().__getattribute__(name) + def _cleanup_stale_loop_states(self) -> None: + """Removes entries for event loops that have been closed.""" + stale = [loop for loop in self._loop_state_by_loop if loop.is_closed()] + for loop in stale: + logger.warning( + "Cleaning up stale loop state for closed loop %s (id=%s).", + loop, + id(loop), + ) + del self._loop_state_by_loop[loop] @property - def _batch_processor_prop(self) -> Optional["BatchProcessor"]: - """The batch processor for the current loop (backward compatibility).""" + def batch_processor(self) -> Optional["BatchProcessor"]: + """The batch processor for the current event loop.""" try: loop = asyncio.get_running_loop() + self._cleanup_stale_loop_states() if loop in self._loop_state_by_loop: return self._loop_state_by_loop[loop].batch_processor except RuntimeError: @@ -1655,8 +1605,8 @@ def _batch_processor_prop(self) -> Optional["BatchProcessor"]: return None @property - def _write_client_prop(self) -> Optional["BigQueryWriteAsyncClient"]: - """The write client for the current loop (backward compatibility).""" + def write_client(self) -> Optional["BigQueryWriteAsyncClient"]: + """The write client for the current event loop.""" try: loop = asyncio.get_running_loop() if loop in self._loop_state_by_loop: @@ -1666,9 +1616,9 @@ def _write_client_prop(self) -> Optional["BigQueryWriteAsyncClient"]: return None @property - def _write_stream_prop(self) -> Optional[str]: - """The write stream for the current loop (backward compatibility).""" - bp = self._batch_processor_prop + def write_stream(self) -> Optional[str]: + """The write stream for the current event loop.""" + bp = self.batch_processor return bp.write_stream if bp else None def _format_content_safely( @@ -1700,19 +1650,11 @@ async def _get_loop_state(self) 
-> _LoopState: The loop-specific state object containing clients and processors. """ loop = asyncio.get_running_loop() + self._cleanup_stale_loop_states() if loop in self._loop_state_by_loop: return self._loop_state_by_loop[loop] - # We DO NOT use the global client approach for multi-loop safety simpler - # or we must ensure _GLOBAL_WRITE_CLIENT usage is safe. - # The original code had a _GLOBAL_WRITE_CLIENT. - # If we want to reuse it, we must be careful. - # actually, _GLOBAL_WRITE_CLIENT is created in *A* loop. - # It cannot be shared across loops if it uses loop primitives. - # So strictly speaking, we should create a new client per loop. - # OR we assume the global client is thread-safe? - # grpc.aio clients are generally loop-bound. - # SAFE approach: Create one client per loop. + # grpc.aio clients are loop-bound, so we create one per event loop. def get_credentials(): creds, project_id = google.auth.default( @@ -1723,7 +1665,7 @@ def get_credentials(): creds, project_id = await loop.run_in_executor( self._executor, get_credentials ) - quota_project_id = getattr(creds, "quota_project_id", None) or project_id + quota_project_id = getattr(creds, "quota_project_id", None) options = ( client_options.ClientOptions(quota_project_id=quota_project_id) if quota_project_id @@ -1739,10 +1681,7 @@ def get_credentials(): client_options=options, ) - # Use the resolved write stream name if not self._write_stream_name: - # Should be set in _lazy_setup or we set it here if missing? - # _lazy_setup guarantees self.table_id etc are ready. 
self._write_stream_name = f"projects/{self.project_id}/datasets/{self.dataset_id}/tables/{self.table_id}/_default" batch_processor = BatchProcessor( @@ -1771,6 +1710,7 @@ async def flush(self) -> None: """ try: loop = asyncio.get_running_loop() + self._cleanup_stale_loop_states() if loop in self._loop_state_by_loop: await self._loop_state_by_loop[loop].batch_processor.flush() except RuntimeError: @@ -1825,62 +1765,36 @@ async def _lazy_setup(self, **kwargs) -> None: @staticmethod def _atexit_cleanup(batch_processor: "BatchProcessor") -> None: - """Clean up batch processor on script exit.""" + """Clean up batch processor on script exit. + Drains any remaining items from the queue and logs a warning. + Callers should use ``flush()`` before shutdown to ensure all + events are written; this handler only reports data that would + otherwise be silently lost. + """ try: if not batch_processor or batch_processor._shutdown: return except ReferenceError: return - # Emergency Flush: Rescue any logs remaining in the queue - remaining_items = [] + # Drain remaining items and warn — creating a new event loop and + # BQ client at interpreter exit is fragile and masks shutdown bugs. + remaining = 0 try: while True: - remaining_items.append(batch_processor._queue.get_nowait()) + batch_processor._queue.get_nowait() + remaining += 1 except (asyncio.QueueEmpty, AttributeError): pass - if remaining_items: - # We need a new loop and client to flush these - async def rescue_flush(): - client = None - try: - # Create a short-lived client just for this flush - try: - # Note: This relies on google.auth.default() working in this context. 
- # pylint: disable=g-import-not-at-top - from google.cloud.bigquery_storage_v1.services.big_query_write.async_client import BigQueryWriteAsyncClient - - # pylint: enable=g-import-not-at-top - client = BigQueryWriteAsyncClient() - except Exception as e: - logger.warning("Could not create rescue client: %s", e) - return - - # Patch batch_processor.write_client temporarily - old_client = batch_processor.write_client - batch_processor.write_client = client - try: - # Force a write - await batch_processor._write_rows_with_retry(remaining_items) - logger.info("Rescued logs flushed successfully.") - except Exception as e: - logger.error("Failed to flush rescued logs: %s", e) - finally: - batch_processor.write_client = old_client - except Exception as e: - logger.error("Rescue flush failed: %s", e) - finally: - if client: - await client.transport.close() - - try: - loop = asyncio.new_event_loop() - loop.run_until_complete(rescue_flush()) - loop.close() - except Exception as e: - logger.error("Failed to run rescue loop: %s", e) + if remaining: + logger.warning( + "%d analytics event(s) were still queued at interpreter exit " + "and could not be flushed. Call plugin.flush() before shutdown " + "to avoid data loss.", + remaining, + ) def _ensure_schema_exists(self) -> None: """Ensures the BigQuery table exists with the correct schema.""" @@ -1991,13 +1905,98 @@ async def _ensure_started(self, **kwargs) -> None: except Exception as e: logger.error("Failed to initialize BigQuery Plugin: %s", e) + @staticmethod + def _resolve_span_ids( + event_data: EventData, + ) -> tuple[str, str]: + """Reads span/parent overrides from EventData, falling back to TraceManager. 
+ + Returns: + (span_id, parent_span_id) + """ + current_span_id, current_parent_span_id = ( + TraceManager.get_current_span_and_parent() + ) + + span_id = current_span_id + if event_data.span_id_override is not None: + span_id = event_data.span_id_override + + parent_span_id = current_parent_span_id + if event_data.parent_span_id_override is not None: + parent_span_id = event_data.parent_span_id_override + + return span_id, parent_span_id + + @staticmethod + def _extract_latency( + event_data: EventData, + ) -> dict[str, Any] | None: + """Reads latency fields from EventData and returns a latency dict (or None). + + Returns: + A dict with ``total_ms`` and/or ``time_to_first_token_ms``, or + *None* if neither was present. + """ + latency_json: dict[str, Any] = {} + if event_data.latency_ms is not None: + latency_json["total_ms"] = event_data.latency_ms + if event_data.time_to_first_token_ms is not None: + latency_json["time_to_first_token_ms"] = event_data.time_to_first_token_ms + return latency_json or None + + def _enrich_attributes( + self, + event_data: EventData, + callback_context: CallbackContext, + ) -> dict[str, Any]: + """Builds the attributes dict from EventData and enrichments. + + Reads ``model``, ``model_version``, and ``usage_metadata`` from + *event_data*, copies ``extra_attributes``, then adds session metadata + and custom tags. + + Returns: + A new dict ready for JSON serialization into the attributes column. 
+ """ + attrs: dict[str, Any] = dict(event_data.extra_attributes) + + attrs["root_agent_name"] = TraceManager.get_root_agent_name() + if event_data.model: + attrs["model"] = event_data.model + if event_data.model_version: + attrs["model_version"] = event_data.model_version + if event_data.usage_metadata: + usage_dict, _ = _recursive_smart_truncate( + event_data.usage_metadata, self.config.max_content_length + ) + if isinstance(usage_dict, dict): + attrs["usage_metadata"] = usage_dict + else: + attrs["usage_metadata"] = event_data.usage_metadata + + if self.config.log_session_metadata and hasattr( + callback_context, "session" + ): + try: + metadata = getattr(callback_context.session, "metadata", None) + if metadata: + attrs["session_metadata"] = metadata + except Exception: + pass + + if self.config.custom_tags: + attrs["custom_tags"] = self.config.custom_tags + + return attrs + async def _log_event( self, event_type: str, callback_context: CallbackContext, raw_content: Any = None, is_truncated: bool = False, - **kwargs, + event_data: Optional[EventData] = None, ) -> None: """Logs an event to BigQuery. @@ -2006,7 +2005,8 @@ async def _log_event( callback_context: The callback context. raw_content: The raw content to log. is_truncated: Whether the content is already truncated. - **kwargs: Additional attributes to log. + event_data: Typed container for structured fields and extra + attributes. Defaults to ``EventData()`` when not provided. 
""" if not self.config.enabled or self._is_shutting_down: return @@ -2023,6 +2023,9 @@ async def _log_event( if not self._started: return + if event_data is None: + event_data = EventData() + timestamp = datetime.now(timezone.utc) if self.config.content_formatter: try: @@ -2031,98 +2034,24 @@ async def _log_event( logger.warning("Content formatter failed: %s", e) trace_id = TraceManager.get_trace_id(callback_context) - current_span_id, current_parent_span_id = ( - TraceManager.get_current_span_and_parent() - ) - - span_id = current_span_id - if "span_id_override" in kwargs: - val = kwargs.pop("span_id_override") - if val is not None: - span_id = val + span_id, parent_span_id = self._resolve_span_ids(event_data) - parent_span_id = current_parent_span_id - if "parent_span_id_override" in kwargs: - val = kwargs.pop("parent_span_id_override") - if val is not None: - parent_span_id = val - - # Use HybridContentParser if offloader is available, otherwise use default - # Re-initialize parser with current trace/span IDs for GCS pathing - self.parser = HybridContentParser( - self.offloader, - trace_id or "no_trace", - span_id or "no_span", - self.config.max_content_length, - connection_id=self.config.connection_id, - ) + # Update parser's trace/span IDs for GCS pathing (reuse instance) + self.parser.trace_id = trace_id or "no_trace" + self.parser.span_id = span_id or "no_span" content_json, content_parts, parser_truncated = await self.parser.parse( raw_content ) is_truncated = is_truncated or parser_truncated - total_latency = kwargs.get("latency_ms") - tfft = kwargs.get("time_to_first_token_ms") - latency_json = {} - if total_latency is not None: - latency_json["total_ms"] = total_latency - if tfft is not None: - latency_json["time_to_first_token_ms"] = tfft - kwargs.pop("latency_ms", None) - kwargs.pop("time_to_first_token_ms", None) - - # Check if content was truncated by the parser or explicitly passed - # (Already handled by parser_truncated above, but keeping for 
safety or if other logic added later) - - status = kwargs.pop("status", "OK") - error_message = kwargs.pop("error_message", None) - - # V2 Metadata Extensions - model = kwargs.pop("model", None) - model_version = kwargs.pop("model_version", None) - usage_metadata = kwargs.pop("usage_metadata", None) - - # Add new fields to attributes instead of columns - kwargs["root_agent_name"] = TraceManager.get_root_agent_name() - if model: - kwargs["model"] = model - if model_version: - kwargs["model_version"] = model_version - if usage_metadata: - # Use smart truncate to handle Pydantic, Dataclasses, and other objects - usage_dict, _ = _recursive_smart_truncate( - usage_metadata, self.config.max_content_length - ) - if isinstance(usage_dict, dict): - kwargs["usage_metadata"] = usage_dict - else: - # Fallback if it couldn't be converted to dict - kwargs["usage_metadata"] = usage_metadata - - # 6. Session Metadata - if self.config.log_session_metadata and hasattr( - callback_context, "session" - ): - try: - # Accessing session.metadata might trigger lazy loading or be a property - # Use getattr to safely check for metadata without raising AttributeError - metadata = getattr(callback_context.session, "metadata", None) - if metadata: - kwargs["session_metadata"] = metadata - except Exception: - # Ignore errors if metadata is missing or inaccessible - pass - - # 7. 
Custom Tags - if self.config.custom_tags: - kwargs["custom_tags"] = self.config.custom_tags + latency_json = self._extract_latency(event_data) + attributes = self._enrich_attributes(event_data, callback_context) - # Serialize remaining kwargs to JSON string for attributes + # Serialize attributes to JSON string try: - attributes_json = json.dumps(kwargs) + attributes_json = json.dumps(attributes) except (TypeError, ValueError): - # Fallback for non-serializable objects - attributes_json = json.dumps(kwargs, default=str) + attributes_json = json.dumps(attributes, default=str) row = { "timestamp": timestamp, @@ -2139,9 +2068,9 @@ async def _log_event( content_parts if self.config.log_multi_modal_content else [] ), "attributes": attributes_json, - "latency_ms": latency_json if latency_json else None, - "status": status, - "error_message": error_message, + "latency_ms": latency_json, + "status": event_data.status, + "error_message": event_data.error_message, "is_truncated": is_truncated, } @@ -2150,12 +2079,12 @@ async def _log_event( # --- UPDATED CALLBACKS FOR V1 PARITY --- + @_safe_callback async def on_user_message_callback( self, *, invocation_context: InvocationContext, user_message: types.Content, - **kwargs, ) -> None: """Parity with V1: Logs USER_MESSAGE_RECEIVED event. @@ -2169,29 +2098,36 @@ async def on_user_message_callback( raw_content=user_message, ) - async def on_state_change_callback( + @_safe_callback + async def on_event_callback( self, *, - callback_context: CallbackContext, - state_delta: dict[str, Any], - **kwargs, + invocation_context: InvocationContext, + event: "Event", ) -> None: - """Logs state changes (state_delta) to BigQuery. + """Logs state changes from events to BigQuery. + + Checks each event for a non-empty state_delta and logs it as a + STATE_DELTA event. This captures state changes from all sources + (tools, agents, LLM, manual), not just tool callbacks. Args: - callback_context: The callback context. 
- state_delta: The change in state to log. - **kwargs: Additional arguments. + invocation_context: The context for the current invocation. + event: The event raised by the runner. """ - await self._log_event( - "STATE_DELTA", - callback_context, - state_delta=state_delta, - **kwargs, - ) + if event.actions and event.actions.state_delta: + await self._log_event( + "STATE_DELTA", + CallbackContext(invocation_context), + event_data=EventData( + extra_attributes={"state_delta": dict(event.actions.state_delta)} + ), + ) + return None + @_safe_callback async def before_run_callback( - self, *, invocation_context: "InvocationContext", **kwargs + self, *, invocation_context: "InvocationContext" ) -> None: """Callback before the agent run starts. @@ -2200,11 +2136,13 @@ async def before_run_callback( """ await self._ensure_started() await self._log_event( - "INVOCATION_STARTING", CallbackContext(invocation_context) + "INVOCATION_STARTING", + CallbackContext(invocation_context), ) + @_safe_callback async def after_run_callback( - self, *, invocation_context: "InvocationContext", **kwargs + self, *, invocation_context: "InvocationContext" ) -> None: """Callback after the agent run completes. @@ -2212,13 +2150,15 @@ async def after_run_callback( invocation_context: The context of the current invocation. """ await self._log_event( - "INVOCATION_COMPLETED", CallbackContext(invocation_context) + "INVOCATION_COMPLETED", + CallbackContext(invocation_context), ) # Ensure all logs are flushed before the agent returns await self.flush() + @_safe_callback async def before_agent_callback( - self, *, agent: Any, callback_context: CallbackContext, **kwargs + self, *, agent: Any, callback_context: CallbackContext ) -> None: """Callback before an agent starts processing. 
@@ -2234,8 +2174,9 @@ async def before_agent_callback( raw_content=getattr(agent, "instruction", ""), ) + @_safe_callback async def after_agent_callback( - self, *, agent: Any, callback_context: CallbackContext, **kwargs + self, *, agent: Any, callback_context: CallbackContext ) -> None: """Callback after an agent completes processing. @@ -2252,17 +2193,19 @@ async def after_agent_callback( await self._log_event( "AGENT_COMPLETED", callback_context, - latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, + event_data=EventData( + latency_ms=duration, + span_id_override=span_id, + parent_span_id_override=parent_span_id, + ), ) + @_safe_callback async def before_model_callback( self, *, callback_context: CallbackContext, llm_request: LlmRequest, - **kwargs, ) -> None: """Callback before LLM call. @@ -2297,10 +2240,6 @@ async def before_model_callback( if val is not None: config_dict[field_name] = val - # Handle labels if present in config - if hasattr(llm_request.config, "labels") and llm_request.config.labels: - attributes["labels"] = llm_request.config.labels - if config_dict: attributes["llm_config"] = config_dict @@ -2310,24 +2249,23 @@ async def before_model_callback( if hasattr(llm_request, "tools_dict") and llm_request.tools_dict: attributes["tools"] = list(llm_request.tools_dict.keys()) - # Merge any additional kwargs into attributes - attributes.update(kwargs) - TraceManager.push_span(callback_context, "llm_request") await self._log_event( "LLM_REQUEST", callback_context, raw_content=llm_request, - model=llm_request.model, - **attributes, + event_data=EventData( + model=llm_request.model, + extra_attributes=attributes, + ), ) + @_safe_callback async def after_model_callback( self, *, callback_context: CallbackContext, llm_response: "LlmResponse", - **kwargs, ) -> None: """Callback after LLM call. @@ -2408,54 +2346,56 @@ async def after_model_callback( # Otherwise log_event will fetch current stack (which is parent). 
span_id = popped_span_id or span_id - extra_kwargs = {} - if tfft is not None: - extra_kwargs["time_to_first_token_ms"] = tfft - await self._log_event( "LLM_RESPONSE", callback_context, raw_content=content_str, is_truncated=is_truncated, - latency_ms=duration, - model_version=llm_response.model_version, - usage_metadata=llm_response.usage_metadata, - span_id_override=span_id if is_popped else None, - parent_span_id_override=parent_span_id - if is_popped - else None, # Use pre-pop state - **extra_kwargs, - **kwargs, + event_data=EventData( + latency_ms=duration, + time_to_first_token_ms=tfft, + model_version=llm_response.model_version, + usage_metadata=llm_response.usage_metadata, + span_id_override=span_id if is_popped else None, + parent_span_id_override=(parent_span_id if is_popped else None), + ), ) + @_safe_callback async def on_model_error_callback( - self, *, callback_context: CallbackContext, error: Exception, **kwargs + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + error: Exception, ) -> None: """Callback on LLM error. Args: callback_context: The callback context. + llm_request: The request that was sent to the model. error: The exception that occurred. - **kwargs: Additional arguments. """ span_id, duration = TraceManager.pop_span() parent_span_id, _ = TraceManager.get_current_span_and_parent() await self._log_event( "LLM_ERROR", callback_context, - error_message=str(error), - latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, + event_data=EventData( + error_message=str(error), + latency_ms=duration, + span_id_override=span_id, + parent_span_id_override=parent_span_id, + ), ) + @_safe_callback async def before_tool_callback( self, *, tool: BaseTool, tool_args: dict[str, Any], tool_context: ToolContext, - **kwargs, ) -> None: """Callback before tool execution. 
@@ -2476,6 +2416,7 @@ async def before_tool_callback( is_truncated=is_truncated, ) + @_safe_callback async def after_tool_callback( self, *, @@ -2483,7 +2424,6 @@ async def after_tool_callback( tool_args: dict[str, Any], tool_context: ToolContext, result: dict[str, Any], - **kwargs, ) -> None: """Callback after tool execution. @@ -2505,18 +2445,14 @@ async def after_tool_callback( tool_context, raw_content=content_dict, is_truncated=is_truncated, - latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, + event_data=EventData( + latency_ms=duration, + span_id_override=span_id, + parent_span_id_override=parent_span_id, + ), ) - if tool_context.actions.state_delta: - await self._log_event( - "STATE_DELTA", - tool_context, - state_delta=tool_context.actions.state_delta, - ) - + @_safe_callback async def on_tool_error_callback( self, *, @@ -2524,7 +2460,6 @@ async def on_tool_error_callback( tool_args: dict[str, Any], tool_context: ToolContext, error: Exception, - **kwargs, ) -> None: """Callback on tool error. @@ -2533,7 +2468,6 @@ async def on_tool_error_callback( tool_args: The arguments passed to the tool. tool_context: The tool context. error: The exception that occurred. - **kwargs: Additional arguments. 
""" args_truncated, is_truncated = _recursive_smart_truncate( tool_args, self.config.max_content_length @@ -2544,7 +2478,9 @@ async def on_tool_error_callback( "TOOL_ERROR", tool_context, raw_content=content_dict, - error_message=str(error), is_truncated=is_truncated, - latency_ms=duration, + event_data=EventData( + error_message=str(error), + latency_ms=duration, + ), ) diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index b11d5659dc..f637ff6b4d 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -22,6 +22,8 @@ from google.adk.agents import base_agent from google.adk.agents import callback_context as callback_context_lib from google.adk.agents import invocation_context as invocation_context_lib +from google.adk.events import event as event_lib +from google.adk.events import event_actions as event_actions_lib from google.adk.models import llm_request as llm_request_lib from google.adk.models import llm_response as llm_response_lib from google.adk.plugins import bigquery_agent_analytics_plugin @@ -1560,9 +1562,10 @@ async def test_after_tool_callback_logs_correctly( assert content_dict["result"] == {"res": "success"} @pytest.mark.asyncio - async def test_after_tool_callback_state_delta_logging( + async def test_after_tool_callback_no_state_delta_logging( self, bq_plugin_inst, mock_write_client, tool_context, dummy_arrow_schema ): + """State deltas are now logged via on_event_callback, not after_tool.""" mock_tool = mock.create_autospec( base_tool_lib.BaseTool, instance=True, spec_set=True ) @@ -1581,57 +1584,65 @@ async def test_after_tool_callback_state_delta_logging( ) await asyncio.sleep(0.01) - # We should have two events appended: TOOL_COMPLETED and STATE_DELTA - assert mock_write_client.append_rows.call_count >= 1 - - # Retrieve all flushed events + # Only TOOL_COMPLETED 
should be logged; STATE_DELTA is handled + # by on_event_callback now. rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) - assert len(rows) == 2 - - # Sort by event_type to reliably access them - rows.sort(key=lambda x: x["event_type"]) - - state_delta_event = ( - rows[0] if rows[0]["event_type"] == "STATE_DELTA" else rows[1] - ) - tool_event = ( - rows[1] if rows[1]["event_type"] == "TOOL_COMPLETED" else rows[0] - ) - - assert state_delta_event["event_type"] == "STATE_DELTA" - assert tool_event["event_type"] == "TOOL_COMPLETED" - - # Verify STATE_DELTA payload - attributes = json.loads(state_delta_event["attributes"]) - assert "state_delta" in attributes - assert attributes["state_delta"] == {"new_key": "new_value"} - assert state_delta_event["content"] is None + assert len(rows) == 1 + assert rows[0]["event_type"] == "TOOL_COMPLETED" @pytest.mark.asyncio - async def test_on_state_change_callback_logs_correctly( + async def test_on_event_callback_logs_state_delta( self, bq_plugin_inst, mock_write_client, - callback_context, + invocation_context, dummy_arrow_schema, ): + """on_event_callback logs STATE_DELTA for events with state changes.""" state_delta = {"key": "value", "new_key": 123} - bigquery_agent_analytics_plugin.TraceManager.push_span(callback_context) - await bq_plugin_inst.on_state_change_callback( - callback_context=callback_context, state_delta=state_delta + event = event_lib.Event( + author="test_agent", + actions=event_actions_lib.EventActions(state_delta=state_delta), + ) + + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + result = await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event ) + # Must return None to not modify the event + assert result is None + await asyncio.sleep(0.01) log_entry = await _get_captured_event_dict_async( mock_write_client, dummy_arrow_schema ) _assert_common_fields(log_entry, "STATE_DELTA") - # content should be None (as 
raw_content was not passed) assert log_entry["content"] is None - # state_delta should be in attributes attributes = json.loads(log_entry["attributes"]) assert attributes["state_delta"] == state_delta + @pytest.mark.asyncio + async def test_on_event_callback_ignores_empty_state_delta( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + ): + """on_event_callback should not log when state_delta is empty.""" + event = event_lib.Event( + author="test_agent", + actions=event_actions_lib.EventActions(state_delta={}), + ) + + result = await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + assert result is None + + # No events should have been logged + mock_write_client.append_rows.assert_not_called() + @pytest.mark.asyncio async def test_log_event_with_session_metadata( self, @@ -1902,6 +1913,43 @@ async def test_quota_project_id_used_in_client( _, kwargs = mock_bq_write_cls.call_args assert kwargs["client_options"].quota_project_id == "quota-project" + @pytest.mark.asyncio + async def test_no_quota_project_when_creds_lack_it( + self, + mock_bq_client, + mock_to_arrow_schema, + mock_asyncio_to_thread, + ): + """Verify no quota_project_id is set when credentials don't provide one. + + This is critical for Workload Identity Federation flows where setting + quota_project_id on the client breaks auth token refresh (issue #4370). 
+ """ + mock_creds = mock.create_autospec( + google.auth.credentials.Credentials, instance=True, spec_set=True + ) + mock_creds.quota_project_id = None + with mock.patch.object( + google.auth, + "default", + autospec=True, + return_value=(mock_creds, PROJECT_ID), + ): + with mock.patch.object( + bigquery_agent_analytics_plugin, + "BigQueryWriteAsyncClient", + autospec=True, + ) as mock_bq_write_cls: + async with managed_plugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) as plugin: + await plugin._ensure_started() + mock_bq_write_cls.assert_called_once() + _, kwargs = mock_bq_write_cls.call_args + assert kwargs["client_options"] is None + @pytest.mark.asyncio async def test_pickle_safety(self, mock_auth_default, mock_bq_client): """Test that the plugin can be pickled safely.""" @@ -2246,3 +2294,1637 @@ async def test_generation_config_logging( if "labels" in gen_config_kwargs: assert attributes.get("labels") == gen_config_kwargs["labels"] + + +class TestSafeCallbackDecorator: + """Tests that _safe_callback prevents plugin errors from propagating.""" + + @pytest.mark.asyncio + async def test_callback_exception_does_not_propagate( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + ): + """A callback that throws should return None, not crash.""" + # Force _log_event to raise + with mock.patch.object( + bq_plugin_inst, + "_log_event", + side_effect=RuntimeError("BQ network timeout"), + ): + # Should NOT raise + result = await bq_plugin_inst.on_user_message_callback( + invocation_context=invocation_context, + user_message=types.Content(parts=[types.Part(text="Test")]), + ) + assert result is None + + @pytest.mark.asyncio + async def test_callback_exception_is_logged( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + ): + """The swallowed exception should be logged with exc_info.""" + with mock.patch.object( + bq_plugin_inst, + "_log_event", + side_effect=RuntimeError("BQ write failed"), + ): + with 
mock.patch( + "google.adk.plugins.bigquery_agent_analytics_plugin.logger" + ) as mock_logger: + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context, + ) + mock_logger.exception.assert_called_once_with( + "BigQuery analytics plugin error in %s; skipping.", + "before_run_callback", + ) + + @pytest.mark.asyncio + async def test_subsequent_callbacks_work_after_failure( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + """After one callback fails, the next one should still work.""" + call_count = 0 + original_log_event = bq_plugin_inst._log_event + + async def fail_once(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("Transient error") + return await original_log_event(*args, **kwargs) + + with mock.patch.object(bq_plugin_inst, "_log_event", side_effect=fail_once): + # First call fails silently + await bq_plugin_inst.on_user_message_callback( + invocation_context=invocation_context, + user_message=types.Content(parts=[types.Part(text="Fail")]), + ) + # Second call succeeds + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context, + ) + await asyncio.sleep(0.01) + mock_write_client.append_rows.assert_called_once() + + @pytest.mark.asyncio + async def test_on_event_callback_exception_returns_none( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + ): + """on_event_callback should return None on error, not crash.""" + event = event_lib.Event( + author="test_agent", + actions=event_actions_lib.EventActions(state_delta={"key": "value"}), + ) + with mock.patch.object( + bq_plugin_inst, + "_log_event", + side_effect=Exception("serialize error"), + ): + result = await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + assert result is None + + @pytest.mark.asyncio + async def 
test_tool_callback_exception_does_not_propagate( + self, + bq_plugin_inst, + mock_write_client, + tool_context, + ): + """Tool callbacks should not crash even if plugin errors.""" + mock_tool = mock.create_autospec( + base_tool_lib.BaseTool, instance=True, spec_set=True + ) + type(mock_tool).name = mock.PropertyMock(return_value="MyTool") + with mock.patch.object( + bq_plugin_inst, + "_log_event", + side_effect=RuntimeError("BQ down"), + ): + # before_tool_callback + result = await bq_plugin_inst.before_tool_callback( + tool=mock_tool, + tool_args={"p": "v"}, + tool_context=tool_context, + ) + assert result is None + + # after_tool_callback + result = await bq_plugin_inst.after_tool_callback( + tool=mock_tool, + tool_args={"p": "v"}, + tool_context=tool_context, + result={"r": "ok"}, + ) + assert result is None + + # on_tool_error_callback + result = await bq_plugin_inst.on_tool_error_callback( + tool=mock_tool, + tool_args={"p": "v"}, + tool_context=tool_context, + error=ValueError("tool broke"), + ) + assert result is None + + @pytest.mark.asyncio + async def test_model_callback_exception_does_not_propagate( + self, + bq_plugin_inst, + mock_write_client, + callback_context, + ): + """Model callbacks should not crash even if plugin errors.""" + with mock.patch.object( + bq_plugin_inst, + "_log_event", + side_effect=RuntimeError("BQ down"), + ): + llm_request = llm_request_lib.LlmRequest( + model="gemini-pro", + contents=[types.Content(parts=[types.Part(text="Hi")])], + ) + result = await bq_plugin_inst.before_model_callback( + callback_context=callback_context, llm_request=llm_request + ) + assert result is None + + llm_response = llm_response_lib.LlmResponse( + content=types.Content(parts=[types.Part(text="Hi")]), + ) + result = await bq_plugin_inst.after_model_callback( + callback_context=callback_context, llm_response=llm_response + ) + assert result is None + + result = await bq_plugin_inst.on_model_error_callback( + callback_context=callback_context, + 
llm_request=llm_request_lib.LlmRequest(model="gemini-pro"), + error=ValueError("llm error"), + ) + assert result is None + + +class TestParserReuse: + """Tests that HybridContentParser is reused, not recreated per event.""" + + @pytest.mark.asyncio + async def test_parser_instance_is_reused( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + ): + """The same parser instance should be reused across _log_event calls.""" + parser_after_init = bq_plugin_inst.parser + assert parser_after_init is not None + + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_user_message_callback( + invocation_context=invocation_context, + user_message=types.Content(parts=[types.Part(text="Hello")]), + ) + await asyncio.sleep(0.01) + + # Parser should be the same instance, not a new one + assert bq_plugin_inst.parser is parser_after_init + + @pytest.mark.asyncio + async def test_parser_trace_id_updated_per_call( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + """trace_id and span_id on the parser should update per _log_event.""" + parser = bq_plugin_inst.parser + original_trace_id = parser.trace_id + + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_user_message_callback( + invocation_context=invocation_context, + user_message=types.Content(parts=[types.Part(text="Test")]), + ) + await asyncio.sleep(0.01) + + # After logging, trace_id/span_id should have been updated + # (they're derived from TraceManager, not the initial empty strings) + assert parser.span_id != "" + + @pytest.mark.asyncio + async def test_parser_not_recreated_with_constructor( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + ): + """HybridContentParser constructor should not be called in + _log_event.""" + with mock.patch.object( + bigquery_agent_analytics_plugin, + "HybridContentParser", + 
wraps=bigquery_agent_analytics_plugin.HybridContentParser, + ) as mock_parser_cls: + bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context) + await bq_plugin_inst.on_user_message_callback( + invocation_context=invocation_context, + user_message=types.Content(parts=[types.Part(text="Test")]), + ) + await asyncio.sleep(0.01) + # Constructor should NOT have been called during _log_event + mock_parser_cls.assert_not_called() + + +class TestPropertyAccessors: + """Tests that properties work correctly after __getattribute__ removal.""" + + @pytest.mark.asyncio + async def test_batch_processor_property_returns_processor(self, bq_plugin_inst): + """batch_processor property should return the processor for the + current loop.""" + bp = bq_plugin_inst.batch_processor + assert bp is not None + assert isinstance(bp, bigquery_agent_analytics_plugin.BatchProcessor) + + @pytest.mark.asyncio + async def test_write_client_property_returns_client(self, bq_plugin_inst): + """write_client property should return the client for the current + loop.""" + wc = bq_plugin_inst.write_client + assert wc is not None + + @pytest.mark.asyncio + async def test_write_stream_property_returns_stream(self, bq_plugin_inst): + """write_stream property should return the stream name.""" + ws = bq_plugin_inst.write_stream + assert ws is not None + assert ws == DEFAULT_STREAM_NAME + + @pytest.mark.asyncio + async def test_properties_return_none_when_no_loop_state(self): + """Properties should return None when no state exists for the + current loop.""" + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + assert plugin.batch_processor is None + assert plugin.write_client is None + assert plugin.write_stream is None + + @pytest.mark.asyncio + async def test_regular_attributes_still_accessible(self, bq_plugin_inst): + """Regular instance attributes should still be accessible.""" + assert 
bq_plugin_inst.project_id == PROJECT_ID + assert bq_plugin_inst.dataset_id == DATASET_ID + assert bq_plugin_inst.table_id == TABLE_ID + assert bq_plugin_inst.config is not None + assert bq_plugin_inst._started is True + + def test_properties_without_running_loop(self): + """Properties should return None when no event loop is running.""" + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + # No running loop → should return None, not crash + assert plugin.batch_processor is None + assert plugin.write_client is None + assert plugin.write_stream is None + + +class TestUnifiedSpanRecords: + """Tests for the unified _SpanRecord-based TraceManager.""" + + @pytest.mark.asyncio + async def test_push_pop_keeps_stacks_in_sync(self, callback_context): + """Push and pop should always leave the records stack consistent.""" + TM = bigquery_agent_analytics_plugin.TraceManager + TM.init_trace(callback_context) + + span_id_1 = TM.push_span(callback_context, "span-1") + span_id_2 = TM.push_span(callback_context, "span-2") + + # Both should be on the stack + assert TM.get_current_span_id() == span_id_2 + current, parent = TM.get_current_span_and_parent() + assert current == span_id_2 + assert parent == span_id_1 + + # Pop span-2 + popped_id, duration = TM.pop_span() + assert popped_id == span_id_2 + assert duration is not None + assert TM.get_current_span_id() == span_id_1 + + # Pop span-1 + popped_id, _ = TM.pop_span() + assert popped_id == span_id_1 + assert TM.get_current_span_id() is None + + @pytest.mark.asyncio + async def test_pop_empty_stack_returns_none(self, callback_context): + """Popping an empty stack should return (None, None).""" + TM = bigquery_agent_analytics_plugin.TraceManager + TM.init_trace(callback_context) + + span_id, duration = TM.pop_span() + assert span_id is None + assert duration is None + + @pytest.mark.asyncio + async def 
test_first_token_time_stored_in_record(self, callback_context): + """first_token_time should be stored on the span record.""" + TM = bigquery_agent_analytics_plugin.TraceManager + TM.init_trace(callback_context) + + span_id = TM.push_span(callback_context, "llm-span") + + # No first token yet + assert TM.get_first_token_time(span_id) is None + + # Record first token + assert TM.record_first_token(span_id) is True + ftt = TM.get_first_token_time(span_id) + assert ftt is not None + + # Second call should return False (already recorded) + assert TM.record_first_token(span_id) is False + + # Clean up + TM.pop_span() + + @pytest.mark.asyncio + async def test_start_time_accessible_by_span_id(self, callback_context): + """get_start_time should find the span by ID in the records.""" + TM = bigquery_agent_analytics_plugin.TraceManager + TM.init_trace(callback_context) + + span_id = TM.push_span(callback_context, "timed-span") + start = TM.get_start_time(span_id) + assert start is not None + assert start > 0 + + TM.pop_span() + + @pytest.mark.asyncio + async def test_attach_current_span_does_not_own(self, callback_context): + """attach_current_span should not end the span on pop.""" + TM = bigquery_agent_analytics_plugin.TraceManager + TM.init_trace(callback_context) + + mock_span = mock.Mock() + mock_ctx = mock.Mock() + mock_ctx.is_valid = False + mock_span.get_span_context.return_value = mock_ctx + + with mock.patch( + "opentelemetry.trace.get_current_span", return_value=mock_span + ): + span_id = TM.attach_current_span(callback_context) + assert span_id is not None + + TM.pop_span() + # Should NOT have called span.end() since we don't own it + mock_span.end.assert_not_called() + + @pytest.mark.asyncio + async def test_concurrent_tasks_have_isolated_stacks(self, callback_context): + """Concurrent async tasks should have isolated span stacks.""" + TM = bigquery_agent_analytics_plugin.TraceManager + TM.init_trace(callback_context) + + async def task_a(): + s = 
TM.push_span(callback_context, "task-a") + await asyncio.sleep(0.02) + assert TM.get_current_span_id() == s + TM.pop_span() + return s + + async def task_b(): + s = TM.push_span(callback_context, "task-b") + await asyncio.sleep(0.02) + assert TM.get_current_span_id() == s + TM.pop_span() + return s + + results = await asyncio.gather(task_a(), task_b()) + assert results[0] != results[1] + + @pytest.mark.asyncio + async def test_pop_cleans_up_record_completely(self, callback_context): + """After pop, the record should be fully removed from the stack.""" + TM = bigquery_agent_analytics_plugin.TraceManager + TM.init_trace(callback_context) + + span_id = TM.push_span(callback_context, "temp-span") + + # Record is on the stack + assert TM.get_current_span_id() == span_id + assert TM.get_start_time(span_id) is not None + + TM.pop_span() + + # Record is gone + assert TM.get_current_span_id() is None + assert TM.get_start_time(span_id) is None + assert TM.get_first_token_time(span_id) is None + + +class TestLoopStateValidation: + """Tests for loop state validation and stale loop cleanup.""" + + def _make_plugin(self): + """Creates a plugin instance without starting it.""" + return bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + + def _make_loop_state(self): + """Creates a mock _LoopState with batch_processor and write_client.""" + state = mock.MagicMock() + state.batch_processor = mock.MagicMock( + spec=bigquery_agent_analytics_plugin.BatchProcessor + ) + state.batch_processor.flush = mock.AsyncMock() + state.write_client = mock.MagicMock() + return state + + def test_cleanup_stale_loop_states_removes_closed_loops(self): + """Closed loops should be removed from _loop_state_by_loop.""" + plugin = self._make_plugin() + + closed_loop = mock.MagicMock(spec=asyncio.AbstractEventLoop) + closed_loop.is_closed.return_value = True + + plugin._loop_state_by_loop[closed_loop] = 
self._make_loop_state() + + plugin._cleanup_stale_loop_states() + + assert closed_loop not in plugin._loop_state_by_loop + + def test_cleanup_stale_loop_states_keeps_open_loops(self): + """Open loops should not be removed from _loop_state_by_loop.""" + plugin = self._make_plugin() + + open_loop = mock.MagicMock(spec=asyncio.AbstractEventLoop) + open_loop.is_closed.return_value = False + + plugin._loop_state_by_loop[open_loop] = self._make_loop_state() + + plugin._cleanup_stale_loop_states() + + assert open_loop in plugin._loop_state_by_loop + + def test_cleanup_removes_only_closed_loops(self): + """Only closed loops should be removed; open ones stay.""" + plugin = self._make_plugin() + + open_loop = mock.MagicMock(spec=asyncio.AbstractEventLoop) + open_loop.is_closed.return_value = False + closed_loop = mock.MagicMock(spec=asyncio.AbstractEventLoop) + closed_loop.is_closed.return_value = True + + plugin._loop_state_by_loop[open_loop] = self._make_loop_state() + plugin._loop_state_by_loop[closed_loop] = self._make_loop_state() + + plugin._cleanup_stale_loop_states() + + assert open_loop in plugin._loop_state_by_loop + assert closed_loop not in plugin._loop_state_by_loop + + @pytest.mark.asyncio + async def test_batch_processor_returns_processor_for_open_loop( + self, + ): + """batch_processor returns processor for the current loop.""" + plugin = self._make_plugin() + + loop = asyncio.get_running_loop() + state = self._make_loop_state() + plugin._loop_state_by_loop[loop] = state + + assert plugin.batch_processor is state.batch_processor + + # Clean up + del plugin._loop_state_by_loop[loop] + + @pytest.mark.asyncio + async def test_batch_processor_cleans_closed_loop_entry(self): + """Accessing batch_processor cleans up closed loop entries.""" + plugin = self._make_plugin() + + closed_loop = mock.MagicMock(spec=asyncio.AbstractEventLoop) + closed_loop.is_closed.return_value = True + plugin._loop_state_by_loop[closed_loop] = self._make_loop_state() + + # Accessing the 
prop should clean up the closed loop entry + _ = plugin.batch_processor + assert closed_loop not in plugin._loop_state_by_loop + + @pytest.mark.asyncio + async def test_flush_cleans_stale_states(self): + """flush() should clean up stale loop states before flushing.""" + plugin = self._make_plugin() + + closed_loop = mock.MagicMock(spec=asyncio.AbstractEventLoop) + closed_loop.is_closed.return_value = True + plugin._loop_state_by_loop[closed_loop] = self._make_loop_state() + + await plugin.flush() + + assert closed_loop not in plugin._loop_state_by_loop + + +class TestAtexitCleanup: + """Tests for the simplified _atexit_cleanup static method.""" + + def _make_batch_processor(self, queue_items=0): + bp = mock.MagicMock() + bp._shutdown = False + q = asyncio.Queue() + for i in range(queue_items): + q.put_nowait({"event": i}) + bp._queue = q + return bp + + def test_skips_none_processor(self): + """Should return immediately when batch_processor is None.""" + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._atexit_cleanup( + None + ) + + def test_skips_already_shutdown(self): + """Should return immediately when batch_processor._shutdown is True.""" + bp = self._make_batch_processor() + bp._shutdown = True + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._atexit_cleanup( + bp + ) + + def test_skips_reference_error(self): + """Should handle ReferenceError from weakref'd processor.""" + bp = mock.MagicMock() + type(bp)._shutdown = mock.PropertyMock(side_effect=ReferenceError) + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._atexit_cleanup( + bp + ) + + def test_empty_queue_no_warning(self): + """Should not warn when queue is empty.""" + bp = self._make_batch_processor(queue_items=0) + with mock.patch.object( + bigquery_agent_analytics_plugin.logger, "warning" + ) as mock_warn: + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._atexit_cleanup( + bp + ) + mock_warn.assert_not_called() + + def 
test_remaining_items_logs_warning(self): + """Should drain queue and log warning with count of lost items.""" + bp = self._make_batch_processor(queue_items=3) + with mock.patch.object( + bigquery_agent_analytics_plugin.logger, "warning" + ) as mock_warn: + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._atexit_cleanup( + bp + ) + mock_warn.assert_called_once() + # Verify the warning mentions the count + call_args = mock_warn.call_args + assert "3" in str(call_args) + + def test_queue_is_drained(self): + """Should drain all items from the queue.""" + bp = self._make_batch_processor(queue_items=5) + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._atexit_cleanup( + bp + ) + assert bp._queue.empty() + + +class TestDuplicateLabels: + """Tests that labels in before_model_callback are set exactly once.""" + + @pytest.mark.asyncio + async def test_labels_set_when_present( + self, + bq_plugin_inst, + mock_write_client, + callback_context, + dummy_arrow_schema, + ): + """Labels should appear in attributes when config has them.""" + llm_request = llm_request_lib.LlmRequest( + model="gemini-pro", + config=types.GenerateContentConfig( + labels={"env": "test"}, + ), + contents=[types.Content(role="user", parts=[types.Part(text="hi")])], + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(callback_context) + await bq_plugin_inst.before_model_callback( + callback_context=callback_context, llm_request=llm_request + ) + await asyncio.sleep(0.01) + log_entry = await _get_captured_event_dict_async( + mock_write_client, dummy_arrow_schema + ) + attributes = json.loads(log_entry["attributes"]) + assert attributes["labels"] == {"env": "test"} + + @pytest.mark.asyncio + async def test_labels_absent_when_none( + self, + bq_plugin_inst, + mock_write_client, + callback_context, + dummy_arrow_schema, + ): + """Labels should not appear in attributes when config.labels is None.""" + llm_request = llm_request_lib.LlmRequest( + model="gemini-pro", + 
config=types.GenerateContentConfig( + temperature=0.5, + ), + contents=[types.Content(role="user", parts=[types.Part(text="hi")])], + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(callback_context) + await bq_plugin_inst.before_model_callback( + callback_context=callback_context, llm_request=llm_request + ) + await asyncio.sleep(0.01) + log_entry = await _get_captured_event_dict_async( + mock_write_client, dummy_arrow_schema + ) + attributes = json.loads(log_entry["attributes"]) + assert "labels" not in attributes + + @pytest.mark.asyncio + async def test_no_config_no_labels( + self, + bq_plugin_inst, + mock_write_client, + callback_context, + dummy_arrow_schema, + ): + """Labels should not appear when llm_request has no config.""" + llm_request = llm_request_lib.LlmRequest( + model="gemini-pro", + contents=[types.Content(role="user", parts=[types.Part(text="hi")])], + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(callback_context) + await bq_plugin_inst.before_model_callback( + callback_context=callback_context, llm_request=llm_request + ) + await asyncio.sleep(0.01) + log_entry = await _get_captured_event_dict_async( + mock_write_client, dummy_arrow_schema + ) + attributes = json.loads(log_entry["attributes"]) + assert "labels" not in attributes + + +class TestResolveSpanIds: + """Tests for the _resolve_span_ids static helper.""" + + def test_uses_trace_manager_defaults(self): + """Should use TraceManager values when no overrides provided.""" + ed = bigquery_agent_analytics_plugin.EventData( + extra_attributes={"some_key": "value"} + ) + with mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ): + span_id, parent_id = ( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( + ed + ) + ) + assert span_id == "span-1" + assert parent_id == "parent-1" + + def test_span_id_override(self): + """Should use span_id_override from 
EventData.""" + ed = bigquery_agent_analytics_plugin.EventData( + span_id_override="custom-span" + ) + with mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ): + span_id, parent_id = ( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( + ed + ) + ) + assert span_id == "custom-span" + assert parent_id == "parent-1" + + def test_parent_span_id_override(self): + """Should use parent_span_id_override from EventData.""" + ed = bigquery_agent_analytics_plugin.EventData( + parent_span_id_override="custom-parent" + ) + with mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ): + span_id, parent_id = ( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( + ed + ) + ) + assert span_id == "span-1" + assert parent_id == "custom-parent" + + def test_none_override_keeps_default(self): + """None overrides should keep the TraceManager defaults.""" + ed = bigquery_agent_analytics_plugin.EventData( + span_id_override=None, parent_span_id_override=None + ) + with mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ): + span_id, parent_id = ( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( + ed + ) + ) + assert span_id == "span-1" + assert parent_id == "parent-1" + + +class TestExtractLatency: + """Tests for the _extract_latency static helper.""" + + def test_no_latency_returns_none(self): + """Should return None when no latency fields present.""" + ed = bigquery_agent_analytics_plugin.EventData( + extra_attributes={"other": "val"} + ) + result = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._extract_latency( + ed + ) + assert result is None + + def test_total_latency_only(self): + """Should extract latency_ms 
into total_ms.""" + ed = bigquery_agent_analytics_plugin.EventData(latency_ms=42.5) + result = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._extract_latency( + ed + ) + assert result == {"total_ms": 42.5} + + def test_tfft_only(self): + """Should extract time_to_first_token_ms.""" + ed = bigquery_agent_analytics_plugin.EventData(time_to_first_token_ms=10.0) + result = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._extract_latency( + ed + ) + assert result == {"time_to_first_token_ms": 10.0} + + def test_both_latencies(self): + """Should extract both latency fields.""" + ed = bigquery_agent_analytics_plugin.EventData( + latency_ms=100, time_to_first_token_ms=20 + ) + result = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._extract_latency( + ed + ) + assert result == {"total_ms": 100, "time_to_first_token_ms": 20} + + +class TestEnrichAttributes: + """Tests for the _enrich_attributes helper.""" + + def _make_plugin(self): + with ( + mock.patch( + "google.auth.default", + return_value=(mock.Mock(), PROJECT_ID), + ), + mock.patch( + "google.cloud.bigquery.Client", + ), + ): + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + ) + plugin.config.max_content_length = 10000 + plugin.config.log_session_metadata = False + plugin.config.custom_tags = None + return plugin + + def _make_callback_context(self): + ctx = mock.MagicMock() + ctx.session.metadata = {"env": "test"} + return ctx + + def test_adds_root_agent_name(self): + """Should always add root_agent_name.""" + plugin = self._make_plugin() + ed = bigquery_agent_analytics_plugin.EventData() + with mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_root_agent_name", + return_value="my-agent", + ): + attrs = plugin._enrich_attributes(ed, self._make_callback_context()) + assert attrs["root_agent_name"] == "my-agent" + + def test_includes_model(self): + """Should include model from 
EventData.""" + plugin = self._make_plugin() + ed = bigquery_agent_analytics_plugin.EventData(model="gemini-pro") + with mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_root_agent_name", + return_value="agent", + ): + attrs = plugin._enrich_attributes(ed, self._make_callback_context()) + assert attrs["model"] == "gemini-pro" + + def test_session_metadata_when_enabled(self): + """Should add session_metadata when log_session_metadata is True.""" + plugin = self._make_plugin() + plugin.config.log_session_metadata = True + ctx = self._make_callback_context() + ed = bigquery_agent_analytics_plugin.EventData() + with mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_root_agent_name", + return_value="agent", + ): + attrs = plugin._enrich_attributes(ed, ctx) + assert attrs["session_metadata"] == {"env": "test"} + + def test_session_metadata_when_disabled(self): + """Should not add session_metadata when log_session_metadata is False.""" + plugin = self._make_plugin() + plugin.config.log_session_metadata = False + ed = bigquery_agent_analytics_plugin.EventData() + with mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_root_agent_name", + return_value="agent", + ): + attrs = plugin._enrich_attributes(ed, self._make_callback_context()) + assert "session_metadata" not in attrs + + def test_custom_tags_added(self): + """Should add custom_tags when configured.""" + plugin = self._make_plugin() + plugin.config.custom_tags = {"team": "infra"} + ed = bigquery_agent_analytics_plugin.EventData() + with mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_root_agent_name", + return_value="agent", + ): + attrs = plugin._enrich_attributes(ed, self._make_callback_context()) + assert attrs["custom_tags"] == {"team": "infra"} + + def test_usage_metadata_truncated(self): + """Should smart-truncate usage_metadata.""" + plugin = self._make_plugin() + ed = bigquery_agent_analytics_plugin.EventData( + 
usage_metadata={"input_tokens": 100, "output_tokens": 50} + ) + with mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_root_agent_name", + return_value="agent", + ): + attrs = plugin._enrich_attributes(ed, self._make_callback_context()) + assert attrs["usage_metadata"] == { + "input_tokens": 100, + "output_tokens": 50, + } + + +class TestMultiSubagentToolLogging: + """Tests that tool events from different subagents are attributed correctly. + + Covers: + - Tool calls from different subagents have the correct `agent` field + - Multi-turn (different invocation_ids, same session) logs correctly + - Full callback sequence across multiple subagents in one turn + - Span hierarchy is maintained per-subagent + """ + + @staticmethod + def _make_invocation_context(agent_name, session, invocation_id="inv-001"): + """Create an InvocationContext with a specific agent name.""" + mock_a = mock.create_autospec( + base_agent.BaseAgent, instance=True, spec_set=True + ) + type(mock_a).name = mock.PropertyMock(return_value=agent_name) + type(mock_a).instruction = mock.PropertyMock( + return_value=f"{agent_name} instruction" + ) + mock_session_service = mock.create_autospec( + base_session_service_lib.BaseSessionService, + instance=True, + spec_set=True, + ) + mock_plugin_manager = mock.create_autospec( + plugin_manager_lib.PluginManager, + instance=True, + spec_set=True, + ) + return invocation_context_lib.InvocationContext( + agent=mock_a, + session=session, + invocation_id=invocation_id, + session_service=mock_session_service, + plugin_manager=mock_plugin_manager, + ) + + @staticmethod + def _make_session(session_id="session-multi", user_id="user-multi"): + mock_s = mock.create_autospec( + session_lib.Session, instance=True, spec_set=True + ) + type(mock_s).id = mock.PropertyMock(return_value=session_id) + type(mock_s).user_id = mock.PropertyMock(return_value=user_id) + type(mock_s).app_name = mock.PropertyMock(return_value="test_app") + type(mock_s).state = 
mock.PropertyMock(return_value={}) + return mock_s + + @staticmethod + def _make_tool(name): + mock_tool = mock.create_autospec( + base_tool_lib.BaseTool, instance=True, spec_set=True + ) + type(mock_tool).name = mock.PropertyMock(return_value=name) + type(mock_tool).description = mock.PropertyMock( + return_value=f"{name} description" + ) + return mock_tool + + @pytest.mark.asyncio + async def test_tool_calls_attributed_to_correct_subagent( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + dummy_arrow_schema, + mock_asyncio_to_thread, + ): + """Tool events from different subagents carry the correct agent name.""" + session = self._make_session() + + async with managed_plugin( + PROJECT_ID, DATASET_ID, table_id=TABLE_ID + ) as plugin: + await plugin._ensure_started() + mock_write_client.append_rows.reset_mock() + + # --- Subagent A: schema_explorer calls list_datasets --- + inv_ctx_a = self._make_invocation_context("schema_explorer", session) + ctx_a = tool_context_lib.ToolContext(invocation_context=inv_ctx_a) + tool_a = self._make_tool("list_dataset_ids") + + bigquery_agent_analytics_plugin.TraceManager.push_span(ctx_a, "tool") + await plugin.before_tool_callback( + tool=tool_a, + tool_args={"project_id": "my-project"}, + tool_context=ctx_a, + ) + await asyncio.sleep(0.01) + + # --- Subagent B: image_describer calls describe_this_image --- + inv_ctx_b = self._make_invocation_context("image_describer", session) + ctx_b = tool_context_lib.ToolContext(invocation_context=inv_ctx_b) + tool_b = self._make_tool("describe_this_image") + + bigquery_agent_analytics_plugin.TraceManager.push_span(ctx_b, "tool") + await plugin.before_tool_callback( + tool=tool_b, + tool_args={"image_uri": "gs://bucket/image.jpg"}, + tool_context=ctx_b, + ) + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + + assert len(rows) == 2 + + # First row: schema_explorer's tool + assert 
rows[0]["event_type"] == "TOOL_STARTING" + assert rows[0]["agent"] == "schema_explorer" + content_a = json.loads(rows[0]["content"]) + assert content_a["tool"] == "list_dataset_ids" + assert content_a["args"] == {"project_id": "my-project"} + + # Second row: image_describer's tool + assert rows[1]["event_type"] == "TOOL_STARTING" + assert rows[1]["agent"] == "image_describer" + content_b = json.loads(rows[1]["content"]) + assert content_b["tool"] == "describe_this_image" + assert content_b["args"] == {"image_uri": "gs://bucket/image.jpg"} + + @pytest.mark.asyncio + async def test_multi_turn_tool_calls_different_invocations( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + dummy_arrow_schema, + mock_asyncio_to_thread, + ): + """Multi-turn: same session, different invocation IDs, tools logged.""" + session = self._make_session() + + async with managed_plugin( + PROJECT_ID, DATASET_ID, table_id=TABLE_ID + ) as plugin: + await plugin._ensure_started() + mock_write_client.append_rows.reset_mock() + + # --- Turn 1: schema_explorer calls list_dataset_ids --- + inv_ctx_1 = self._make_invocation_context( + "schema_explorer", session, invocation_id="inv-turn1" + ) + ctx_1 = tool_context_lib.ToolContext(invocation_context=inv_ctx_1) + tool_1 = self._make_tool("list_dataset_ids") + + bigquery_agent_analytics_plugin.TraceManager.push_span(ctx_1, "tool") + await plugin.before_tool_callback( + tool=tool_1, + tool_args={"project_id": "proj"}, + tool_context=ctx_1, + ) + await asyncio.sleep(0.01) + await plugin.after_tool_callback( + tool=tool_1, + tool_args={"project_id": "proj"}, + tool_context=ctx_1, + result={"datasets": ["ds1", "ds2"]}, + ) + await asyncio.sleep(0.01) + + # --- Turn 2: query_analyst calls execute_sql --- + inv_ctx_2 = self._make_invocation_context( + "query_analyst", session, invocation_id="inv-turn2" + ) + ctx_2 = tool_context_lib.ToolContext(invocation_context=inv_ctx_2) + tool_2 = self._make_tool("execute_sql") + 
+ bigquery_agent_analytics_plugin.TraceManager.push_span(ctx_2, "tool") + await plugin.before_tool_callback( + tool=tool_2, + tool_args={"sql": "SELECT * FROM t"}, + tool_context=ctx_2, + ) + await asyncio.sleep(0.01) + await plugin.after_tool_callback( + tool=tool_2, + tool_args={"sql": "SELECT * FROM t"}, + tool_context=ctx_2, + result={"rows": [{"col": "val"}]}, + ) + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + + assert len(rows) == 4 + + # Turn 1: TOOL_STARTING + TOOL_COMPLETED for schema_explorer + assert rows[0]["event_type"] == "TOOL_STARTING" + assert rows[0]["agent"] == "schema_explorer" + assert rows[0]["invocation_id"] == "inv-turn1" + assert rows[0]["session_id"] == "session-multi" + + assert rows[1]["event_type"] == "TOOL_COMPLETED" + assert rows[1]["agent"] == "schema_explorer" + assert rows[1]["invocation_id"] == "inv-turn1" + content_1 = json.loads(rows[1]["content"]) + assert content_1["tool"] == "list_dataset_ids" + assert content_1["result"] == {"datasets": ["ds1", "ds2"]} + + # Turn 2: TOOL_STARTING + TOOL_COMPLETED for query_analyst + assert rows[2]["event_type"] == "TOOL_STARTING" + assert rows[2]["agent"] == "query_analyst" + assert rows[2]["invocation_id"] == "inv-turn2" + + assert rows[3]["event_type"] == "TOOL_COMPLETED" + assert rows[3]["agent"] == "query_analyst" + assert rows[3]["invocation_id"] == "inv-turn2" + content_2 = json.loads(rows[3]["content"]) + assert content_2["tool"] == "execute_sql" + assert content_2["result"] == {"rows": [{"col": "val"}]} + + @pytest.mark.asyncio + async def test_full_subagent_callback_sequence( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + dummy_arrow_schema, + mock_asyncio_to_thread, + ): + """Full lifecycle: agent_start → LLM → tool → tool_done → LLM → agent_done. + + Simulates a subagent that makes an LLM call, then a tool call, + then another LLM call, and completes. 
+ """ + session = self._make_session() + inv_ctx = self._make_invocation_context("schema_explorer", session) + cb_ctx = callback_context_lib.CallbackContext(invocation_context=inv_ctx) + tool_ctx = tool_context_lib.ToolContext(invocation_context=inv_ctx) + mock_agent = inv_ctx.agent + tool = self._make_tool("get_table_info") + + async with managed_plugin( + PROJECT_ID, DATASET_ID, table_id=TABLE_ID + ) as plugin: + await plugin._ensure_started() + mock_write_client.append_rows.reset_mock() + + # 1. AGENT_STARTING + await plugin.before_agent_callback( + agent=mock_agent, callback_context=cb_ctx + ) + await asyncio.sleep(0.01) + + # 2. LLM_REQUEST (agent decides to call a tool) + llm_req = llm_request_lib.LlmRequest( + model="gemini-2.5-flash", + contents=[ + types.Content(parts=[types.Part(text="What tables exist?")]) + ], + ) + await plugin.before_model_callback( + callback_context=cb_ctx, llm_request=llm_req + ) + await asyncio.sleep(0.01) + + # 3. LLM_RESPONSE (function call) + llm_resp = llm_response_lib.LlmResponse( + content=types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall( + name="get_table_info", + args={"table": "events"}, + ) + ) + ] + ) + ) + await plugin.after_model_callback( + callback_context=cb_ctx, llm_response=llm_resp + ) + await asyncio.sleep(0.01) + + # 4. TOOL_STARTING + bigquery_agent_analytics_plugin.TraceManager.push_span(tool_ctx, "tool") + await plugin.before_tool_callback( + tool=tool, + tool_args={"table": "events"}, + tool_context=tool_ctx, + ) + await asyncio.sleep(0.01) + + # 5. TOOL_COMPLETED + await plugin.after_tool_callback( + tool=tool, + tool_args={"table": "events"}, + tool_context=tool_ctx, + result={"schema": [{"name": "id", "type": "INT64"}]}, + ) + await asyncio.sleep(0.01) + + # 6. 
AGENT_COMPLETED + await plugin.after_agent_callback( + agent=mock_agent, callback_context=cb_ctx + ) + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + + assert len(rows) == 6 + + expected_sequence = [ + "AGENT_STARTING", + "LLM_REQUEST", + "LLM_RESPONSE", + "TOOL_STARTING", + "TOOL_COMPLETED", + "AGENT_COMPLETED", + ] + for i, expected_type in enumerate(expected_sequence): + assert ( + rows[i]["event_type"] == expected_type + ), f"Row {i}: expected {expected_type}, got {rows[i]['event_type']}" + assert rows[i]["agent"] == "schema_explorer" + assert rows[i]["session_id"] == "session-multi" + + # TOOL rows have correct content + tool_start = json.loads(rows[3]["content"]) + assert tool_start["tool"] == "get_table_info" + assert tool_start["args"] == {"table": "events"} + + tool_done = json.loads(rows[4]["content"]) + assert tool_done["tool"] == "get_table_info" + assert tool_done["result"] == {"schema": [{"name": "id", "type": "INT64"}]} + + # AGENT_COMPLETED and TOOL_COMPLETED should have latency + assert rows[4]["latency_ms"] is not None # TOOL_COMPLETED + assert rows[5]["latency_ms"] is not None # AGENT_COMPLETED + + @pytest.mark.asyncio + async def test_tool_error_attributed_to_subagent( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + dummy_arrow_schema, + mock_asyncio_to_thread, + ): + """TOOL_ERROR events carry the correct subagent name.""" + session = self._make_session() + inv_ctx = self._make_invocation_context("query_analyst", session) + tool_ctx = tool_context_lib.ToolContext(invocation_context=inv_ctx) + tool = self._make_tool("execute_sql") + + async with managed_plugin( + PROJECT_ID, DATASET_ID, table_id=TABLE_ID + ) as plugin: + await plugin._ensure_started() + mock_write_client.append_rows.reset_mock() + + bigquery_agent_analytics_plugin.TraceManager.push_span(tool_ctx, "tool") + await plugin.on_tool_error_callback( + tool=tool, + 
tool_args={"sql": "SELECT * FROM bad_table"}, + tool_context=tool_ctx, + error=RuntimeError("Table not found"), + ) + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + + assert len(rows) == 1 + assert rows[0]["event_type"] == "TOOL_ERROR" + assert rows[0]["agent"] == "query_analyst" + assert rows[0]["error_message"] == "Table not found" + content = json.loads(rows[0]["content"]) + assert content["tool"] == "execute_sql" + assert content["args"] == {"sql": "SELECT * FROM bad_table"} + + @pytest.mark.asyncio + async def test_multi_subagent_interleaved_tool_calls( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + dummy_arrow_schema, + mock_asyncio_to_thread, + ): + """Two subagents call tools in same invocation — agent field is correct. + + Simulates orchestrator delegating to schema_explorer first, then + image_describer, all within the same invocation. + """ + session = self._make_session() + + async with managed_plugin( + PROJECT_ID, DATASET_ID, table_id=TABLE_ID + ) as plugin: + await plugin._ensure_started() + mock_write_client.append_rows.reset_mock() + + # Subagent 1: schema_explorer — full tool cycle + inv_ctx_1 = self._make_invocation_context( + "schema_explorer", session, invocation_id="inv-shared" + ) + ctx_1 = tool_context_lib.ToolContext(invocation_context=inv_ctx_1) + tool_1 = self._make_tool("list_table_ids") + bigquery_agent_analytics_plugin.TraceManager.push_span(ctx_1, "tool") + await plugin.before_tool_callback( + tool=tool_1, + tool_args={"dataset": "analytics"}, + tool_context=ctx_1, + ) + await asyncio.sleep(0.01) + await plugin.after_tool_callback( + tool=tool_1, + tool_args={"dataset": "analytics"}, + tool_context=ctx_1, + result={"tables": ["events", "metrics"]}, + ) + await asyncio.sleep(0.01) + + # Subagent 2: image_describer — full tool cycle + inv_ctx_2 = self._make_invocation_context( + "image_describer", session, 
invocation_id="inv-shared" + ) + ctx_2 = tool_context_lib.ToolContext(invocation_context=inv_ctx_2) + tool_2 = self._make_tool("describe_this_image") + bigquery_agent_analytics_plugin.TraceManager.push_span(ctx_2, "tool") + await plugin.before_tool_callback( + tool=tool_2, + tool_args={"image_uri": "https://example.com/img.jpg"}, + tool_context=ctx_2, + ) + await asyncio.sleep(0.01) + await plugin.after_tool_callback( + tool=tool_2, + tool_args={"image_uri": "https://example.com/img.jpg"}, + tool_context=ctx_2, + result={"description": "A photo of scones"}, + ) + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + + assert len(rows) == 4 + + # schema_explorer tool events + assert rows[0]["agent"] == "schema_explorer" + assert rows[0]["event_type"] == "TOOL_STARTING" + assert rows[0]["invocation_id"] == "inv-shared" + assert json.loads(rows[0]["content"])["tool"] == "list_table_ids" + + assert rows[1]["agent"] == "schema_explorer" + assert rows[1]["event_type"] == "TOOL_COMPLETED" + assert json.loads(rows[1]["content"])["result"]["tables"] == [ + "events", + "metrics", + ] + + # image_describer tool events + assert rows[2]["agent"] == "image_describer" + assert rows[2]["event_type"] == "TOOL_STARTING" + assert rows[2]["invocation_id"] == "inv-shared" + assert json.loads(rows[2]["content"])["tool"] == "describe_this_image" + + assert rows[3]["agent"] == "image_describer" + assert rows[3]["event_type"] == "TOOL_COMPLETED" + assert ( + json.loads(rows[3]["content"])["result"]["description"] + == "A photo of scones" + ) + + # All share the same session and invocation + for row in rows: + assert row["session_id"] == "session-multi" + assert row["invocation_id"] == "inv-shared" + + @pytest.mark.asyncio + async def test_multi_turn_multi_subagent_full_sequence( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + dummy_arrow_schema, + mock_asyncio_to_thread, + ): + 
"""Multi-turn + multi-subagent: two turns, each with different subagents. + + Turn 1: user asks about data → orchestrator → schema_explorer (tool) + Turn 2: user asks about image → orchestrator → image_describer (tool) + Verifies invocation_id changes, agent name changes, session stays same. + """ + session = self._make_session() + + async with managed_plugin( + PROJECT_ID, DATASET_ID, table_id=TABLE_ID + ) as plugin: + await plugin._ensure_started() + mock_write_client.append_rows.reset_mock() + + # ===== Turn 1: schema_explorer ===== + inv_ctx_t1_orch = self._make_invocation_context( + "orchestrator", session, invocation_id="inv-t1" + ) + cb_ctx_t1_orch = callback_context_lib.CallbackContext( + invocation_context=inv_ctx_t1_orch + ) + + # Orchestrator agent_starting + await plugin.before_agent_callback( + agent=inv_ctx_t1_orch.agent, + callback_context=cb_ctx_t1_orch, + ) + await asyncio.sleep(0.01) + + # Orchestrator delegates to schema_explorer + inv_ctx_t1_sub = self._make_invocation_context( + "schema_explorer", session, invocation_id="inv-t1" + ) + cb_ctx_t1_sub = callback_context_lib.CallbackContext( + invocation_context=inv_ctx_t1_sub + ) + tool_ctx_t1 = tool_context_lib.ToolContext( + invocation_context=inv_ctx_t1_sub + ) + + await plugin.before_agent_callback( + agent=inv_ctx_t1_sub.agent, + callback_context=cb_ctx_t1_sub, + ) + await asyncio.sleep(0.01) + + # schema_explorer calls tool + tool_1 = self._make_tool("list_dataset_ids") + bigquery_agent_analytics_plugin.TraceManager.push_span( + tool_ctx_t1, "tool" + ) + await plugin.before_tool_callback( + tool=tool_1, + tool_args={"project_id": "proj"}, + tool_context=tool_ctx_t1, + ) + await asyncio.sleep(0.01) + await plugin.after_tool_callback( + tool=tool_1, + tool_args={"project_id": "proj"}, + tool_context=tool_ctx_t1, + result={"datasets": ["ds1"]}, + ) + await asyncio.sleep(0.01) + + # schema_explorer done + await plugin.after_agent_callback( + agent=inv_ctx_t1_sub.agent, + 
callback_context=cb_ctx_t1_sub, + ) + await asyncio.sleep(0.01) + + # Orchestrator done + await plugin.after_agent_callback( + agent=inv_ctx_t1_orch.agent, + callback_context=cb_ctx_t1_orch, + ) + await asyncio.sleep(0.01) + + # ===== Turn 2: image_describer ===== + inv_ctx_t2_orch = self._make_invocation_context( + "orchestrator", session, invocation_id="inv-t2" + ) + cb_ctx_t2_orch = callback_context_lib.CallbackContext( + invocation_context=inv_ctx_t2_orch + ) + + await plugin.before_agent_callback( + agent=inv_ctx_t2_orch.agent, + callback_context=cb_ctx_t2_orch, + ) + await asyncio.sleep(0.01) + + # Orchestrator delegates to image_describer + inv_ctx_t2_sub = self._make_invocation_context( + "image_describer", session, invocation_id="inv-t2" + ) + cb_ctx_t2_sub = callback_context_lib.CallbackContext( + invocation_context=inv_ctx_t2_sub + ) + tool_ctx_t2 = tool_context_lib.ToolContext( + invocation_context=inv_ctx_t2_sub + ) + + await plugin.before_agent_callback( + agent=inv_ctx_t2_sub.agent, + callback_context=cb_ctx_t2_sub, + ) + await asyncio.sleep(0.01) + + # image_describer calls tool + tool_2 = self._make_tool("describe_this_image") + bigquery_agent_analytics_plugin.TraceManager.push_span( + tool_ctx_t2, "tool" + ) + await plugin.before_tool_callback( + tool=tool_2, + tool_args={"image_uri": "gs://b/img.jpg"}, + tool_context=tool_ctx_t2, + ) + await asyncio.sleep(0.01) + await plugin.after_tool_callback( + tool=tool_2, + tool_args={"image_uri": "gs://b/img.jpg"}, + tool_context=tool_ctx_t2, + result={"desc": "Scones on a table"}, + ) + await asyncio.sleep(0.01) + + # image_describer done + await plugin.after_agent_callback( + agent=inv_ctx_t2_sub.agent, + callback_context=cb_ctx_t2_sub, + ) + await asyncio.sleep(0.01) + + # Orchestrator done + await plugin.after_agent_callback( + agent=inv_ctx_t2_orch.agent, + callback_context=cb_ctx_t2_orch, + ) + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, 
dummy_arrow_schema + ) + + # Turn 1: 6 rows (orch_start, sub_start, tool_start, tool_done, + # sub_done, orch_done) + # Turn 2: 6 rows (same pattern) + assert len(rows) == 12 + + # --- Turn 1 validation --- + t1_rows = [r for r in rows if r["invocation_id"] == "inv-t1"] + assert len(t1_rows) == 6 + + assert t1_rows[0]["event_type"] == "AGENT_STARTING" + assert t1_rows[0]["agent"] == "orchestrator" + + assert t1_rows[1]["event_type"] == "AGENT_STARTING" + assert t1_rows[1]["agent"] == "schema_explorer" + + assert t1_rows[2]["event_type"] == "TOOL_STARTING" + assert t1_rows[2]["agent"] == "schema_explorer" + assert json.loads(t1_rows[2]["content"])["tool"] == "list_dataset_ids" + + assert t1_rows[3]["event_type"] == "TOOL_COMPLETED" + assert t1_rows[3]["agent"] == "schema_explorer" + + assert t1_rows[4]["event_type"] == "AGENT_COMPLETED" + assert t1_rows[4]["agent"] == "schema_explorer" + + assert t1_rows[5]["event_type"] == "AGENT_COMPLETED" + assert t1_rows[5]["agent"] == "orchestrator" + + # --- Turn 2 validation --- + t2_rows = [r for r in rows if r["invocation_id"] == "inv-t2"] + assert len(t2_rows) == 6 + + assert t2_rows[0]["event_type"] == "AGENT_STARTING" + assert t2_rows[0]["agent"] == "orchestrator" + + assert t2_rows[1]["event_type"] == "AGENT_STARTING" + assert t2_rows[1]["agent"] == "image_describer" + + assert t2_rows[2]["event_type"] == "TOOL_STARTING" + assert t2_rows[2]["agent"] == "image_describer" + assert json.loads(t2_rows[2]["content"])["tool"] == "describe_this_image" + + assert t2_rows[3]["event_type"] == "TOOL_COMPLETED" + assert t2_rows[3]["agent"] == "image_describer" + + assert t2_rows[4]["event_type"] == "AGENT_COMPLETED" + assert t2_rows[4]["agent"] == "image_describer" + + assert t2_rows[5]["event_type"] == "AGENT_COMPLETED" + assert t2_rows[5]["agent"] == "orchestrator" + + # All rows share the same session + for row in rows: + assert row["session_id"] == "session-multi"