Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/google/adk/plugins/bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,10 @@ class BigQueryLoggerConfig:
shutdown_timeout: Max time to wait for shutdown.
queue_max_size: Max size of the in-memory queue.
content_formatter: Optional custom formatter for content.
gcs_bucket_name: GCS bucket name for offloading large content.
connection_id: BigQuery connection ID for ObjectRef authorization.
default_attributes: Static key-value pairs included in every event's
attributes. Useful for service version, environment, etc.
"""

enabled: bool = True
Expand All @@ -409,6 +413,10 @@ class BigQueryLoggerConfig:
# If provided, this connection ID will be used as the authorizer for ObjectRef columns.
# Format: "location.connection_id" (e.g. "us.my-connection")
connection_id: Optional[str] = None
# If provided, these key-value pairs will be merged into every event's attributes.
# Useful for adding static metadata like service version, deployment environment, etc.
# Event-specific attributes will override these if there are conflicts.
default_attributes: Optional[dict[str, Any]] = None

# Toggle for session metadata (e.g. gchat thread-id)
log_session_metadata: bool = True
Expand Down Expand Up @@ -2117,12 +2125,18 @@ async def _log_event(
if self.config.custom_tags:
kwargs["custom_tags"] = self.config.custom_tags

# 8. Merge default_attributes first, then let event-specific kwargs override
if self.config.default_attributes:
merged_attributes = {**self.config.default_attributes, **kwargs}
else:
merged_attributes = kwargs

# Serialize remaining kwargs to JSON string for attributes
try:
attributes_json = json.dumps(kwargs)
attributes_json = json.dumps(merged_attributes)
except (TypeError, ValueError):
# Fallback for non-serializable objects
attributes_json = json.dumps(kwargs, default=str)
attributes_json = json.dumps(merged_attributes, default=str)

row = {
"timestamp": timestamp,
Expand Down
125 changes: 125 additions & 0 deletions tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from google.adk.models import llm_response as llm_response_lib
from google.adk.plugins import bigquery_agent_analytics_plugin
from google.adk.plugins import plugin_manager as plugin_manager_lib
from google.adk.plugins.bigquery_agent_analytics_plugin import BigQueryLoggerConfig
from google.adk.sessions import base_session_service as base_session_service_lib
from google.adk.sessions import session as session_lib
from google.adk.tools import base_tool as base_tool_lib
Expand Down Expand Up @@ -2175,6 +2176,130 @@ async def test_flush_mechanism(
)
assert log_entry["event_type"] == "INVOCATION_STARTING"

@pytest.mark.asyncio
async def test_default_attributes_included_in_events(
self,
mock_write_client,
invocation_context,
mock_auth_default,
mock_bq_client,
mock_to_arrow_schema,
dummy_arrow_schema,
mock_asyncio_to_thread,
):
"""Test that default_attributes are included in every logged event."""
default_attrs = {
"service_version": "1.2.3",
"environment": "production",
"deployment_id": "deploy-abc",
}
config = BigQueryLoggerConfig(default_attributes=default_attrs)
plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
PROJECT_ID, DATASET_ID, table_id=TABLE_ID, config=config
)
await plugin._ensure_started()
mock_write_client.append_rows.reset_mock()

user_message = types.Content(parts=[types.Part(text="Hello")])
bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context)
await plugin.on_user_message_callback(
invocation_context=invocation_context, user_message=user_message
)
await asyncio.sleep(0.01)

mock_write_client.append_rows.assert_called_once()
log_entry = await _get_captured_event_dict_async(
mock_write_client, dummy_arrow_schema
)

# Verify default attributes are in the attributes field
attributes = json.loads(log_entry["attributes"])
assert attributes["service_version"] == "1.2.3"
assert attributes["environment"] == "production"
assert attributes["deployment_id"] == "deploy-abc"

@pytest.mark.asyncio
async def test_default_attributes_overridden_by_event_attributes(
self,
mock_write_client,
callback_context,
mock_auth_default,
mock_bq_client,
mock_to_arrow_schema,
dummy_arrow_schema,
mock_asyncio_to_thread,
):
"""Test that event-specific attributes override default_attributes."""
default_attrs = {
"service_version": "1.2.3",
"model": "default-model",
}
config = BigQueryLoggerConfig(default_attributes=default_attrs)
plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
PROJECT_ID, DATASET_ID, table_id=TABLE_ID, config=config
)
await plugin._ensure_started()
mock_write_client.append_rows.reset_mock()

# LLM request will add its own "model" attribute which should override the default
llm_request = llm_request_lib.LlmRequest(
model="gemini-pro",
contents=[types.Content(parts=[types.Part(text="Hi")])],
)
bigquery_agent_analytics_plugin.TraceManager.push_span(callback_context)
await plugin.before_model_callback(
callback_context=callback_context, llm_request=llm_request
)
await asyncio.sleep(0.01)

mock_write_client.append_rows.assert_called_once()
log_entry = await _get_captured_event_dict_async(
mock_write_client, dummy_arrow_schema
)

attributes = json.loads(log_entry["attributes"])
# service_version should come from default_attributes
assert attributes["service_version"] == "1.2.3"
# model should be overridden by the event-specific value
assert attributes["model"] == "gemini-pro"

@pytest.mark.asyncio
async def test_default_attributes_none_does_not_affect_events(
self,
mock_write_client,
invocation_context,
mock_auth_default,
mock_bq_client,
mock_to_arrow_schema,
dummy_arrow_schema,
mock_asyncio_to_thread,
):
"""Test that when default_attributes is None, events work normally."""
config = BigQueryLoggerConfig(default_attributes=None)
plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
PROJECT_ID, DATASET_ID, table_id=TABLE_ID, config=config
)
await plugin._ensure_started()
mock_write_client.append_rows.reset_mock()

user_message = types.Content(parts=[types.Part(text="Hello")])
bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context)
await plugin.on_user_message_callback(
invocation_context=invocation_context, user_message=user_message
)
await asyncio.sleep(0.01)

mock_write_client.append_rows.assert_called_once()
log_entry = await _get_captured_event_dict_async(
mock_write_client, dummy_arrow_schema
)

# Verify event was logged successfully with normal attributes
assert log_entry["event_type"] == "USER_MESSAGE_RECEIVED"
# Attributes should only contain root_agent_name (added by plugin)
attributes = json.loads(log_entry["attributes"])
assert "root_agent_name" in attributes

@pytest.mark.asyncio
@pytest.mark.parametrize(
"gen_config_kwargs, expected_llm_config",
Expand Down