From e2964f0105a7744c899aeb8c319e47fc284c7cc8 Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Thu, 26 Feb 2026 22:58:14 -0800 Subject: [PATCH 01/15] fix: add fork-safety and auto-create analytics views to BQ plugin Add PID tracking to detect post-fork broken gRPC channels (#4636) and automatically create per-event-type BigQuery views that unnest JSON columns into typed, queryable columns (#4639). Co-Authored-By: Claude Opus 4.6 --- .../bigquery_agent_analytics_plugin.py | 179 +++++++++++++++ .../test_bigquery_agent_analytics_plugin.py | 204 ++++++++++++++++++ 2 files changed, 383 insertions(+) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 70a17f4001..b7066c7280 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -27,6 +27,7 @@ import json import logging import mimetypes +import os import random import time from types import MappingProxyType @@ -498,6 +499,9 @@ class BigQueryLoggerConfig: # dropped or altered). Safe to leave enabled; a version label on the # table ensures the diff runs at most once per schema version. auto_schema_upgrade: bool = True + # Automatically create per-event-type BigQuery views that unnest + # JSON columns into typed, queryable columns. + create_views: bool = True # ============================================================================== @@ -1581,6 +1585,115 @@ def _get_events_schema() -> list[bigquery.SchemaField]: ] +# ============================================================================== +# ANALYTICS VIEW DEFINITIONS +# ============================================================================== + +# Columns included in every per-event-type view. +_VIEW_COMMON_COLUMNS = ( + "timestamp", + "event_type", + "agent", + "session_id", + "invocation_id", + "user_id", + "trace_id", + "span_id", + "parent_span_id", + "status", + "error_message", + "is_truncated", +) + +# Per-event-type column extractions. Each value is a list of +# ``"SQL_EXPR AS alias"`` strings that will be appended after the +# common columns in the view SELECT. +_EVENT_VIEW_DEFS: dict[str, list[str]] = { + "USER_MESSAGE_RECEIVED": [], + "LLM_REQUEST": [ + "JSON_VALUE(attributes, '$.model') AS model", + "content AS request_content", + "JSON_QUERY(attributes, '$.llm_config') AS llm_config", + "JSON_QUERY(attributes, '$.tools') AS tools", + ], + "LLM_RESPONSE": [ + "JSON_QUERY(content, '$.response') AS response", + ( + "CAST(JSON_VALUE(content, '$.usage.prompt')" + " AS INT64) AS usage_prompt_tokens" + ), + ( + "CAST(JSON_VALUE(content, '$.usage.completion')" + " AS INT64) AS usage_completion_tokens" + ), + ( + "CAST(JSON_VALUE(content, '$.usage.total')" + " AS INT64) AS usage_total_tokens" + ), + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ( + "CAST(JSON_VALUE(latency_ms," + " '$.time_to_first_token_ms') AS INT64) AS ttft_ms" + ), + "JSON_VALUE(attributes, '$.model_version') AS model_version", + "JSON_QUERY(attributes, '$.usage_metadata') AS usage_metadata", + ], + "LLM_ERROR": [ + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ], + "TOOL_STARTING": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + "JSON_VALUE(content, '$.tool_origin') AS tool_origin", + ], + "TOOL_COMPLETED": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.result') AS tool_result", + "JSON_VALUE(content, '$.tool_origin') AS tool_origin", + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ], + "TOOL_ERROR": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + "JSON_VALUE(content, '$.tool_origin') AS tool_origin", + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ], + "AGENT_STARTING": [ + "JSON_VALUE(content, '$.text_summary') AS agent_instruction", + ], + "AGENT_COMPLETED": [ + "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", + ], + "INVOCATION_STARTING": [], + "INVOCATION_COMPLETED": [], + "STATE_DELTA": [ + "JSON_QUERY(attributes, '$.state_delta') AS state_delta", + ], + "HITL_CREDENTIAL_REQUEST": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + ], + "HITL_CONFIRMATION_REQUEST": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + ], + "HITL_INPUT_REQUEST": [ + "JSON_VALUE(content, '$.tool') AS tool_name", + "JSON_QUERY(content, '$.args') AS tool_args", + ], +} + +_VIEW_SQL_TEMPLATE = """\ +CREATE OR REPLACE VIEW `{project}.{dataset}.{view_name}` AS +SELECT + {columns} +FROM + `{project}.{dataset}.{table}` +WHERE + event_type = '{event_type}' +""" + + # ============================================================================== # MAIN PLUGIN # ============================================================================== @@ -1660,6 +1773,7 @@ def __init__( self.parser: Optional[HybridContentParser] = None self._schema = None self.arrow_schema = None + self._init_pid = os.getpid() def _cleanup_stale_loop_states(self) -> None: """Removes entries for event loops that have been closed.""" @@ -1912,6 +2026,8 @@ def _ensure_schema_exists(self) -> None: existing_table = self.client.get_table(self.full_table_id) if self.config.auto_schema_upgrade: self._maybe_upgrade_schema(existing_table) + if self.config.create_views: + self._create_analytics_views() except cloud_exceptions.NotFound: logger.info("Table %s not found, creating table.", self.full_table_id) tbl = bigquery.Table(self.full_table_id, schema=self._schema) @@ -1932,6 +2048,8 @@ def _ensure_schema_exists(self) -> None: e, exc_info=True, ) + if self.config.create_views: + self._create_analytics_views() except Exception as e: logger.error( "Error checking for table %s: %s", @@ -1980,6 +2098,44 @@ def _maybe_upgrade_schema(self, existing_table: bigquery.Table) -> None: exc_info=True, ) + def _create_analytics_views(self) -> None: + """Creates per-event-type BigQuery views (idempotent). + + Each view filters the events table by ``event_type`` and + extracts JSON columns into typed, queryable columns. Uses + ``CREATE OR REPLACE VIEW`` so it is safe to call repeatedly. + Errors are logged but never raised. + """ + for event_type, extra_cols in _EVENT_VIEW_DEFS.items(): + view_name = "v_" + event_type.lower() + columns = ",\n ".join(list(_VIEW_COMMON_COLUMNS) + extra_cols) + sql = _VIEW_SQL_TEMPLATE.format( + project=self.project_id, + dataset=self.dataset_id, + view_name=view_name, + columns=columns, + table=self.table_id, + event_type=event_type, + ) + try: + self.client.query(sql).result() + except Exception as e: + logger.error( + "Failed to create view %s: %s", + view_name, + e, + exc_info=True, + ) + + async def create_analytics_views(self) -> None: + """Public async helper to (re-)create all analytics views. + + Useful when views need to be refreshed explicitly, for example + after a schema upgrade. + """ + loop = asyncio.get_running_loop() + await loop.run_in_executor(self._executor, self._create_analytics_views) + async def shutdown(self, timeout: float | None = None) -> None: """Shuts down the plugin and releases resources. @@ -2032,12 +2188,33 @@ def __getstate__(self): state["parser"] = None state["_started"] = False state["_is_shutting_down"] = False + state["_init_pid"] = 0 return state def __setstate__(self, state): """Custom unpickling to restore state.""" self.__dict__.update(state) + def _reset_runtime_state(self) -> None: + """Resets all runtime state after a fork. + + gRPC channels and asyncio locks are not safe to use after + ``os.fork()``. This method clears them so the next call to + ``_ensure_started()`` re-initializes everything in the child + process. Pure-data fields like ``_schema`` and + ``arrow_schema`` are kept because they are safe across fork. + """ + self._setup_lock = None + self.client = None + self._loop_state_by_loop = {} + self._write_stream_name = None + self._executor = None + self.offloader = None + self.parser = None + self._started = False + self._is_shutting_down = False + self._init_pid = os.getpid() + async def __aenter__(self) -> BigQueryAgentAnalyticsPlugin: await self._ensure_started() return self @@ -2047,6 +2224,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: async def _ensure_started(self, **kwargs) -> None: """Ensures that the plugin is started and initialized.""" + if os.getpid() != self._init_pid: + self._reset_runtime_state() if not self._started: # Kept original lock name as it was not explicitly changed. if self._setup_lock is None: diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 549263fbae..48abd636cd 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -4665,3 +4665,207 @@ def regular_tool() -> str: ), f"Expected no HITL events for regular tool, got {hitl_events}" await bq_plugin.shutdown() + + +# ============================================================================== +# Fork-Safety Tests +# ============================================================================== +class TestForkSafety: + """Tests for fork-safety via PID tracking.""" + + def _make_plugin(self): + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig() + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + return plugin + + @pytest.mark.asyncio + async def test_pid_change_triggers_reinit( + self, mock_auth_default, mock_bq_client, mock_write_client + ): + """Simulating a fork by changing _init_pid forces re-init.""" + plugin = self._make_plugin() + await plugin._ensure_started() + assert plugin._started is True + + # Simulate a fork: set _init_pid to a stale value + plugin._init_pid = -1 + assert plugin._started is True # still True before check + + # _ensure_started should detect PID mismatch and reset + await plugin._ensure_started() + # After reset + re-init, _init_pid should match current + import os + + assert plugin._init_pid == os.getpid() + assert plugin._started is True + await plugin.shutdown() + + @pytest.mark.asyncio + async def test_pid_unchanged_skips_reset( + self, mock_auth_default, mock_bq_client, mock_write_client + ): + """Same PID should not trigger a reset.""" + plugin = self._make_plugin() + await plugin._ensure_started() + + # Save references to verify they are not recreated + original_client = plugin.client + original_parser = plugin.parser + + await plugin._ensure_started() + assert plugin.client is original_client + assert plugin.parser is original_parser + await plugin.shutdown() + + def test_reset_runtime_state_clears_fields(self): + """_reset_runtime_state clears all runtime fields.""" + plugin = self._make_plugin() + # Fake some runtime state + plugin._started = True + plugin._is_shutting_down = True + plugin.client = mock.MagicMock() + plugin._loop_state_by_loop = {"fake": "state"} + plugin._write_stream_name = "some/stream" + plugin._executor = mock.MagicMock() + plugin.offloader = mock.MagicMock() + plugin.parser = mock.MagicMock() + plugin._setup_lock = mock.MagicMock() + # Keep pure-data fields + plugin._schema = ["kept"] + plugin.arrow_schema = "kept_arrow" + + plugin._reset_runtime_state() + + assert plugin._started is False + assert plugin._is_shutting_down is False + assert plugin.client is None + assert plugin._loop_state_by_loop == {} + assert plugin._write_stream_name is None + assert plugin._executor is None + assert plugin.offloader is None + assert plugin.parser is None + assert plugin._setup_lock is None + # Pure-data fields are preserved + assert plugin._schema == ["kept"] + assert plugin.arrow_schema == "kept_arrow" + + import os + + assert plugin._init_pid == os.getpid() + + def test_getstate_resets_pid(self): + """Pickle state should have _init_pid = 0 to force re-init.""" + plugin = self._make_plugin() + state = plugin.__getstate__() + assert state["_init_pid"] == 0 + assert state["_started"] is False + + +# ============================================================================== +# Analytics Views Tests +# ============================================================================== +class TestAnalyticsViews: + """Tests for auto-created per-event-type BigQuery views.""" + + def _make_plugin(self, create_views=True): + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig( + create_views=create_views, + ) + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + plugin.client = mock.MagicMock() + plugin.full_table_id = f"{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}" + plugin._schema = bigquery_agent_analytics_plugin._get_events_schema() + return plugin + + def test_views_created_on_new_table(self): + """NotFound path creates all views.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + mock_query_job = mock.MagicMock() + plugin.client.query.return_value = mock_query_job + + plugin._ensure_schema_exists() + + expected_count = len(bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS) + assert plugin.client.query.call_count == expected_count + + def test_views_created_for_existing_table(self): + """Existing table path also creates views.""" + plugin = self._make_plugin(create_views=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = plugin._schema + existing.labels = { + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY: ( + bigquery_agent_analytics_plugin._SCHEMA_VERSION + ), + } + plugin.client.get_table.return_value = existing + mock_query_job = mock.MagicMock() + plugin.client.query.return_value = mock_query_job + + plugin._ensure_schema_exists() + + expected_count = len(bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS) + assert plugin.client.query.call_count == expected_count + + def test_views_not_created_when_disabled(self): + """create_views=False skips view creation.""" + plugin = self._make_plugin(create_views=False) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + + plugin._ensure_schema_exists() + + plugin.client.query.assert_not_called() + + def test_view_creation_error_logged_not_raised(self): + """Errors during view creation don't crash the plugin.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + plugin.client.query.side_effect = Exception("BQ error") + + # Should not raise + plugin._ensure_schema_exists() + + # Verify it tried to create views (and failed gracefully) + assert plugin.client.query.call_count > 0 + + def test_view_sql_contains_correct_event_filter(self): + """Each SQL has correct WHERE clause and view name.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + mock_query_job = mock.MagicMock() + plugin.client.query.return_value = mock_query_job + + plugin._ensure_schema_exists() + + calls = plugin.client.query.call_args_list + for call in calls: + sql = call[0][0] + # Each SQL should have CREATE OR REPLACE VIEW + assert "CREATE OR REPLACE VIEW" in sql + # Each SQL should filter by event_type + assert "WHERE" in sql + assert "event_type = " in sql + # View name should start with v_ + assert ".v_" in sql + + # Verify specific views exist + all_sql = " ".join(c[0][0] for c in calls) + for event_type in bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS: + view_name = "v_" + event_type.lower() + assert view_name in all_sql, f"View {view_name} not found in SQL" + + def test_config_create_views_default_true(self): + """Config create_views defaults to True.""" + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig() + assert config.create_views is True From 5fbfc45c82932fcb97aacd095146b56787f9eef4 Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Thu, 26 Feb 2026 23:12:04 -0800 Subject: [PATCH 02/15] fix: address review findings for BQ plugin fork-safety and views - Backfill _init_pid in __setstate__ for legacy pickle compatibility - Gate view creation on successful table creation (skip after failure) - Ensure plugin is started in public create_analytics_views() - Add tests for all three edge cases Co-Authored-By: Claude Opus 4.6 --- .../bigquery_agent_analytics_plugin.py | 14 ++++- .../test_bigquery_agent_analytics_plugin.py | 58 +++++++++++++++++++ 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index b7066c7280..849dfee35e 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -2037,10 +2037,13 @@ def _ensure_schema_exists(self) -> None: ) tbl.clustering_fields = self.config.clustering_fields tbl.labels = {_SCHEMA_VERSION_LABEL_KEY: _SCHEMA_VERSION} + table_ready = False try: self.client.create_table(tbl) + table_ready = True except cloud_exceptions.Conflict: - pass + # Another process created it concurrently — still usable. + table_ready = True except Exception as e: logger.error( "Could not create table %s: %s", @@ -2048,7 +2051,7 @@ def _ensure_schema_exists(self) -> None: e, exc_info=True, ) - if self.config.create_views: + if table_ready and self.config.create_views: self._create_analytics_views() except Exception as e: logger.error( @@ -2131,8 +2134,10 @@ async def create_analytics_views(self) -> None: """Public async helper to (re-)create all analytics views. Useful when views need to be refreshed explicitly, for example - after a schema upgrade. + after a schema upgrade. Ensures the plugin is initialized + before attempting view creation. """ + await self._ensure_started() loop = asyncio.get_running_loop() await loop.run_in_executor(self._executor, self._create_analytics_views) @@ -2193,6 +2198,9 @@ def __getstate__(self): def __setstate__(self, state): """Custom unpickling to restore state.""" + # Backfill keys that may be absent in pickled state from older + # code versions so _ensure_started does not raise AttributeError. + state.setdefault("_init_pid", 0) self.__dict__.update(state) def _reset_runtime_state(self) -> None: diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 48abd636cd..3512c27037 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -4765,6 +4765,30 @@ def test_getstate_resets_pid(self): assert state["_init_pid"] == 0 assert state["_started"] is False + @pytest.mark.asyncio + async def test_unpickle_legacy_state_missing_init_pid( + self, mock_auth_default, mock_bq_client, mock_write_client + ): + """Unpickling state from older code without _init_pid should not crash.""" + plugin = self._make_plugin() + state = plugin.__getstate__() + # Simulate legacy pickle state that lacks _init_pid entirely + del state["_init_pid"] + + new_plugin = ( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin.__new__( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin + ) + ) + new_plugin.__setstate__(state) + + # _init_pid should be backfilled to 0, triggering re-init + assert new_plugin._init_pid == 0 + # _ensure_started should not raise AttributeError + await new_plugin._ensure_started() + assert new_plugin._started is True + await new_plugin.shutdown() + # ============================================================================== # Analytics Views Tests @@ -4869,3 +4893,37 @@ def test_config_create_views_default_true(self): """Config create_views defaults to True.""" config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig() assert config.create_views is True + + @pytest.mark.asyncio + async def test_create_analytics_views_ensures_started( + self, mock_auth_default, mock_bq_client, mock_write_client + ): + """Public create_analytics_views() initializes plugin first.""" + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + assert plugin._started is False + + await plugin.create_analytics_views() + + # Plugin should be started after the call + assert plugin._started is True + # Views should have been created (query called) + expected_count = len(bigquery_agent_analytics_plugin._EVENT_VIEW_DEFS) + # _ensure_schema_exists also creates views, so total calls + # = schema-creation views + explicit views + assert mock_bq_client.query.call_count >= expected_count + await plugin.shutdown() + + def test_views_not_created_after_table_creation_failure(self): + """View creation is skipped when create_table raises a non-Conflict error.""" + plugin = self._make_plugin(create_views=True) + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + plugin.client.create_table.side_effect = RuntimeError("BQ down") + + plugin._ensure_schema_exists() + + # Views should NOT be attempted since table creation failed + plugin.client.query.assert_not_called() From 33f4e7e77adedfb406d9b02d5cdc0e240459dfe5 Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Thu, 26 Feb 2026 23:25:00 -0800 Subject: [PATCH 03/15] fix: fail fast in create_analytics_views() when startup fails create_analytics_views() now raises RuntimeError if _ensure_started() could not initialize the plugin, instead of silently proceeding with a None client and logging per-view errors. Co-Authored-By: Claude Opus 4.6 --- .../plugins/bigquery_agent_analytics_plugin.py | 4 ++++ .../test_bigquery_agent_analytics_plugin.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 849dfee35e..c5cad67a91 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -2138,6 +2138,10 @@ async def create_analytics_views(self) -> None: before attempting view creation. """ await self._ensure_started() + if not self._started: + raise RuntimeError( + "Plugin initialization failed; cannot create analytics views." + ) loop = asyncio.get_running_loop() await loop.run_in_executor(self._executor, self._create_analytics_views) diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 3512c27037..76f8de4ea6 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -4927,3 +4927,21 @@ def test_views_not_created_after_table_creation_failure(self): # Views should NOT be attempted since table creation failed plugin.client.query.assert_not_called() + + @pytest.mark.asyncio + async def test_create_analytics_views_raises_on_startup_failure( + self, mock_auth_default, mock_write_client + ): + """create_analytics_views() raises if plugin init fails.""" + # Make the BQ Client constructor raise so _lazy_setup fails + # before _started is set to True. + with mock.patch.object( + bigquery, "Client", side_effect=Exception("client boom") + ): + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + with pytest.raises(RuntimeError, match="Plugin initialization failed"): + await plugin.create_analytics_views() From b82b263245006194060cc62e73d4e7acd813982d Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Thu, 26 Feb 2026 23:30:35 -0800 Subject: [PATCH 04/15] fix: chain root cause in create_analytics_views() startup failure Persist the init exception as _startup_error and use `raise ... from self._startup_error` so callers get actionable context instead of a generic RuntimeError. Co-Authored-By: Claude Opus 4.6 --- src/google/adk/plugins/bigquery_agent_analytics_plugin.py | 7 ++++++- .../plugins/test_bigquery_agent_analytics_plugin.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index c5cad67a91..636c916161 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -1763,6 +1763,7 @@ def __init__( self.location = location self._started = False + self._startup_error: Optional[Exception] = None self._is_shutting_down = False self._setup_lock = None self.client = None @@ -2141,7 +2142,7 @@ async def create_analytics_views(self) -> None: if not self._started: raise RuntimeError( "Plugin initialization failed; cannot create analytics views." - ) + ) from self._startup_error loop = asyncio.get_running_loop() await loop.run_in_executor(self._executor, self._create_analytics_views) @@ -2196,6 +2197,7 @@ def __getstate__(self): state["offloader"] = None state["parser"] = None state["_started"] = False + state["_startup_error"] = None state["_is_shutting_down"] = False state["_init_pid"] = 0 return state @@ -2224,6 +2226,7 @@ def _reset_runtime_state(self) -> None: self.offloader = None self.parser = None self._started = False + self._startup_error = None self._is_shutting_down = False self._init_pid = os.getpid() @@ -2247,7 +2250,9 @@ async def _ensure_started(self, **kwargs) -> None: try: await self._lazy_setup(**kwargs) self._started = True + self._startup_error = None except Exception as e: + self._startup_error = e logger.error("Failed to initialize BigQuery Plugin: %s", e) @staticmethod diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 76f8de4ea6..d47367632e 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -4943,5 +4943,10 @@ async def test_create_analytics_views_raises_on_startup_failure( dataset_id=DATASET_ID, table_id=TABLE_ID, ) - with pytest.raises(RuntimeError, match="Plugin initialization failed"): + with pytest.raises( + RuntimeError, match="Plugin initialization failed" + ) as exc_info: await plugin.create_analytics_views() + # Root cause should be chained for debuggability + assert exc_info.value.__cause__ is not None + assert "client boom" in str(exc_info.value.__cause__) From 462e1b6869bfbb9857f7dde41f97902efe374722 Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Sat, 28 Feb 2026 00:25:21 -0800 Subject: [PATCH 05/15] fix: trace_id continuity and o11y alignment for BQ plugin (#4645) Fix trace_id fracture between INVOCATION_STARTING and AGENT_STARTING when no ambient OTel span exists (e.g. Agent Engine, custom runners). Also align BQ rows with Cloud Trace span IDs when o11y is active. Changes: - TraceManager.ensure_invocation_span(): ensures a root span is on the plugin stack before any events fire, preventing early events from falling back to invocation_id while later events get OTel hex IDs. - TraceManager.push_span(): create child spans under existing stack parent via trace.set_span_in_context() so all spans in an invocation share the same trace_id (#4645). - _resolve_ids() replaces _resolve_span_ids(): 3-layer priority for trace_id/span_id/parent_span_id resolution: 1. EventData overrides (post-pop callbacks) 2. Ambient OTel span (aligns with Cloud Trace when o11y is active) 3. Plugin's internal span stack (fallback for no-ambient paths) - EventData.trace_id_override: captured before pop in after_run_callback so INVOCATION_COMPLETED shares the same trace_id as earlier events. - on_user_message_callback / before_run_callback: call ensure_invocation_span() before logging. - after_run_callback: capture trace_id before pop, pass as override. Tests: - TestTraceIdContinuity: 5 tests covering no-ambient, ambient, cross-turn isolation, completion-event continuity, and full callback-level integration. - TestResolveIds: 6 tests replacing TestResolveSpanIds, adding ambient priority and override-beats-ambient coverage. Co-Authored-By: Claude Opus 4.6 --- .../bigquery_agent_analytics_plugin.py | 110 ++++- .../test_bigquery_agent_analytics_plugin.py | 443 ++++++++++++++++-- 2 files changed, 493 insertions(+), 60 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 636c916161..b1bbd2de40 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -604,7 +604,16 @@ def push_span( # Create the span without attaching it to the ambient context. # This avoids re-parenting framework spans like ``call_llm`` # or ``execute_tool``. See #4561. - span = tracer.start_span(span_name) + # + # If the internal stack already has a span, create the new span + # as a child so it shares the same trace_id. Without this, each + # ``start_span`` would be an independent root with its own + # trace_id — causing trace_id fracture (see #4645). + records = TraceManager._get_records() + parent_ctx = None + if records and records[-1].span.get_span_context().is_valid: + parent_ctx = trace.set_span_in_context(records[-1].span) + span = tracer.start_span(span_name, context=parent_ctx) if span.get_span_context().is_valid: span_id_str = format(span.get_span_context().span_id, "016x") @@ -618,7 +627,6 @@ def push_span( start_time_ns=time.time_ns(), ) - records = TraceManager._get_records() new_records = list(records) + [record] _span_records_ctx.set(new_records) @@ -655,6 +663,32 @@ def attach_current_span( return span_id_str + @staticmethod + def ensure_invocation_span( + callback_context: CallbackContext, + ) -> None: + """Ensures a root span exists on the plugin stack for this invocation. + + Must be called before any events are logged so that every event in + the invocation shares the same trace_id. + + * If the stack already has entries → no-op (already initialised). + * If the ambient OTel span is valid → ``attach_current_span`` (reuse + the runner's span without owning it). + * Otherwise → ``push_span("invocation")`` (create a new root span + that will be popped in ``after_run_callback``). + """ + records = _span_records_ctx.get() + if records: + return # Already initialised for this invocation. + + # Check for a valid ambient span (e.g. the Runner's invocation span). + ambient = trace.get_current_span() + if ambient.get_span_context().is_valid: + TraceManager.attach_current_span(callback_context) + else: + TraceManager.push_span(callback_context, "invocation") + @staticmethod def pop_span() -> tuple[Optional[str], Optional[int]]: """Ends the current span and pops it from the stack. @@ -1709,6 +1743,7 @@ class _LoopState: class EventData: """Typed container for structured fields passed to _log_event.""" + trace_id_override: Optional[str] = None span_id_override: Optional[str] = None parent_span_id_override: Optional[str] = None latency_ms: Optional[int] = None @@ -2256,27 +2291,50 @@ async def _ensure_started(self, **kwargs) -> None: logger.error("Failed to initialize BigQuery Plugin: %s", e) @staticmethod - def _resolve_span_ids( + def _resolve_ids( event_data: EventData, - ) -> tuple[str, str]: - """Reads span/parent overrides from EventData, falling back to TraceManager. + callback_context: CallbackContext, + ) -> tuple[Optional[str], Optional[str], Optional[str]]: + """Resolves trace_id, span_id, and parent_span_id for a log row. + + Priority order (highest first): + 1. Explicit ``EventData`` overrides (needed for post-pop callbacks). + 2. Ambient OTel span (the framework's ``start_as_current_span``). + When present this aligns BQ rows with Cloud Trace / o11y. + 3. Plugin's internal span stack (``TraceManager``). + 4. ``invocation_id`` fallback for trace_id. Returns: - (span_id, parent_span_id) + (trace_id, span_id, parent_span_id) """ - current_span_id, current_parent_span_id = ( + # --- Layer 3: plugin stack baseline --- + trace_id = TraceManager.get_trace_id(callback_context) + plugin_span_id, plugin_parent_span_id = ( TraceManager.get_current_span_and_parent() ) - - span_id = current_span_id + span_id = plugin_span_id + parent_span_id = plugin_parent_span_id + + # --- Layer 2: ambient OTel span --- + ambient = trace.get_current_span() + ambient_ctx = ambient.get_span_context() + if ambient_ctx.is_valid: + trace_id = format(ambient_ctx.trace_id, "032x") + span_id = format(ambient_ctx.span_id, "016x") + # SDK spans expose .parent; non-recording spans do not. + parent_ctx = getattr(ambient, "parent", None) + if parent_ctx is not None and parent_ctx.span_id: + parent_span_id = format(parent_ctx.span_id, "016x") + + # --- Layer 1: explicit EventData overrides --- + if event_data.trace_id_override is not None: + trace_id = event_data.trace_id_override if event_data.span_id_override is not None: span_id = event_data.span_id_override - - parent_span_id = current_parent_span_id if event_data.parent_span_id_override is not None: parent_span_id = event_data.parent_span_id_override - return span_id, parent_span_id + return trace_id, span_id, parent_span_id @staticmethod def _extract_latency( @@ -2389,8 +2447,9 @@ async def _log_event( except Exception as e: logger.warning("Content formatter failed: %s", e) - trace_id = TraceManager.get_trace_id(callback_context) - span_id, parent_span_id = self._resolve_span_ids(event_data) + trace_id, span_id, parent_span_id = self._resolve_ids( + event_data, callback_context + ) if not self.parser: logger.warning("Parser not initialized; skipping event %s.", event_type) @@ -2457,6 +2516,7 @@ async def on_user_message_callback( user_message: The message content received from the user. """ callback_ctx = CallbackContext(invocation_context) + TraceManager.ensure_invocation_span(callback_ctx) await self._log_event( "USER_MESSAGE_RECEIVED", callback_ctx, @@ -2591,9 +2651,11 @@ async def before_run_callback( invocation_context: The context of the current invocation. """ await self._ensure_started() + callback_ctx = CallbackContext(invocation_context) + TraceManager.ensure_invocation_span(callback_ctx) await self._log_event( "INVOCATION_STARTING", - CallbackContext(invocation_context), + callback_ctx, ) @_safe_callback @@ -2605,9 +2667,25 @@ async def after_run_callback( Args: invocation_context: The context of the current invocation. """ + # Capture trace_id BEFORE popping the invocation-root span so that + # INVOCATION_COMPLETED shares the same trace_id as all earlier events + # in this invocation (fixes #4645 completion-event fracture). + callback_ctx = CallbackContext(invocation_context) + trace_id = TraceManager.get_trace_id(callback_ctx) + + # Pop the invocation-root span pushed by ensure_invocation_span(). + span_id, duration = TraceManager.pop_span() + parent_span_id = TraceManager.get_current_span_id() + await self._log_event( "INVOCATION_COMPLETED", - CallbackContext(invocation_context), + callback_ctx, + event_data=EventData( + trace_id_override=trace_id, + latency_ms=duration, + span_id_override=span_id, + parent_span_id_override=parent_span_id, + ), ) # Ensure all logs are flushed before the agent returns await self.flush() diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index d47367632e..71b5b840e2 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -2152,7 +2152,7 @@ async def test_otel_integration( span_id = bigquery_agent_analytics_plugin.TraceManager.push_span( callback_context, "test_span" ) - mock_tracer.start_span.assert_called_with("test_span") + mock_tracer.start_span.assert_called_with("test_span", context=None) assert span_id == format(span_id_int, "016x") # Test get_trace_id # We need to mock trace.get_current_span() to return our mock span @@ -3018,81 +3018,149 @@ async def test_no_config_no_labels( assert "labels" not in attributes -class TestResolveSpanIds: - """Tests for the _resolve_span_ids static helper.""" +class TestResolveIds: + """Tests for the _resolve_ids static helper.""" - def test_uses_trace_manager_defaults(self): - """Should use TraceManager values when no overrides provided.""" + def _resolve(self, ed, callback_context): + return bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_ids( + ed, callback_context + ) + + def test_uses_trace_manager_defaults(self, callback_context): + """Should use TraceManager values when no overrides and no ambient.""" ed = bigquery_agent_analytics_plugin.EventData( extra_attributes={"some_key": "value"} ) - with mock.patch.object( - bigquery_agent_analytics_plugin.TraceManager, - "get_current_span_and_parent", - return_value=("span-1", "parent-1"), + with ( + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ), + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_trace_id", + return_value="trace-1", + ), ): - span_id, parent_id = ( - bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( - ed - ) - ) + trace_id, span_id, parent_id = self._resolve(ed, callback_context) + assert trace_id == "trace-1" assert span_id == "span-1" assert parent_id == "parent-1" - def test_span_id_override(self): + def test_span_id_override(self, callback_context): """Should use span_id_override from EventData.""" ed = bigquery_agent_analytics_plugin.EventData( span_id_override="custom-span" ) - with mock.patch.object( - bigquery_agent_analytics_plugin.TraceManager, - "get_current_span_and_parent", - return_value=("span-1", "parent-1"), + with ( + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ), + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_trace_id", + return_value="trace-1", + ), ): - span_id, parent_id = ( - bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( - ed - ) - ) + trace_id, span_id, parent_id = self._resolve(ed, callback_context) assert span_id == "custom-span" assert parent_id == "parent-1" - def test_parent_span_id_override(self): + def test_parent_span_id_override(self, callback_context): """Should use parent_span_id_override from EventData.""" ed = bigquery_agent_analytics_plugin.EventData( parent_span_id_override="custom-parent" ) - with mock.patch.object( - bigquery_agent_analytics_plugin.TraceManager, - "get_current_span_and_parent", - return_value=("span-1", "parent-1"), + with ( + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ), + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_trace_id", + return_value="trace-1", + ), ): - span_id, parent_id = ( - bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( - ed - ) - ) + trace_id, span_id, parent_id = self._resolve(ed, callback_context) assert span_id == "span-1" assert parent_id == "custom-parent" - def test_none_override_keeps_default(self): + def test_none_override_keeps_default(self, callback_context): """None overrides should keep the TraceManager defaults.""" ed = bigquery_agent_analytics_plugin.EventData( span_id_override=None, parent_span_id_override=None ) - with mock.patch.object( - bigquery_agent_analytics_plugin.TraceManager, - "get_current_span_and_parent", - return_value=("span-1", "parent-1"), + with ( + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_current_span_and_parent", + return_value=("span-1", "parent-1"), + ), + mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_trace_id", + return_value="trace-1", + ), ): - span_id, parent_id = ( - bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin._resolve_span_ids( - ed - ) - ) + trace_id, span_id, parent_id = self._resolve(ed, callback_context) assert span_id == "span-1" assert parent_id == "parent-1" + def test_ambient_otel_span_takes_priority(self, callback_context): + """When an ambient OTel span is valid, its IDs take priority.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + ed = bigquery_agent_analytics_plugin.EventData() + + with real_tracer.start_as_current_span("invocation") as parent_span: + with real_tracer.start_as_current_span("agent") as agent_span: + ambient_ctx = agent_span.get_span_context() + expected_trace = format(ambient_ctx.trace_id, "032x") + expected_span = format(ambient_ctx.span_id, "016x") + expected_parent = format(parent_span.get_span_context().span_id, "016x") + + trace_id, span_id, parent_id = self._resolve(ed, callback_context) + + assert trace_id == expected_trace + assert span_id == expected_span + assert parent_id == expected_parent + provider.shutdown() + + def test_override_beats_ambient(self, callback_context): + """EventData overrides take priority over ambient OTel span.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + ed = bigquery_agent_analytics_plugin.EventData( + trace_id_override="forced-trace", + span_id_override="forced-span", + parent_span_id_override="forced-parent", + ) + + with real_tracer.start_as_current_span("invocation"): + trace_id, span_id, parent_id = self._resolve(ed, callback_context) + + assert trace_id == "forced-trace" + assert span_id == "forced-span" + assert parent_id == "forced-parent" + provider.shutdown() + class TestExtractLatency: """Tests for the _extract_latency static helper.""" @@ -4950,3 +5018,290 @@ async def test_create_analytics_views_raises_on_startup_failure( # Root cause should be chained for debuggability assert exc_info.value.__cause__ is not None assert "client boom" in str(exc_info.value.__cause__) + + +# ============================================================================== +# Trace-ID Continuity Tests (Issue #4645) +# ============================================================================== +class TestTraceIdContinuity: + """Tests for trace_id continuity across all events in an invocation. + + Regression tests for https://github.com/google/adk-python/issues/4645. + + When there is no ambient OTel span (e.g. Agent Engine, custom runners), + early events (USER_MESSAGE_RECEIVED, INVOCATION_STARTING) used to fall + back to ``invocation_id`` while AGENT_STARTING got a new OTel hex + trace_id from ``push_span()``. The ``ensure_invocation_span()`` fix + guarantees a root span is always on the stack before any events fire. + """ + + @pytest.mark.asyncio + async def test_trace_id_continuity_no_ambient_span(self, callback_context): + """All events share one trace_id when no ambient OTel span exists. + + Simulates the #4645 scenario: OTel IS configured (real TracerProvider) + but the Runner's ambient span is NOT present (e.g. Agent Engine, + custom runners). + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + # Create a real TracerProvider and patch the plugin's module-level + # tracer so push_span creates valid spans with proper trace_ids. + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test-plugin") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Reset the span records contextvar for a clean invocation. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient OTel span — we do NOT start_as_current_span. + ambient = trace.get_current_span() + assert not ambient.get_span_context().is_valid + + # ensure_invocation_span should push a new span. + TM.ensure_invocation_span(callback_context) + trace_id_early = TM.get_trace_id(callback_context) + assert trace_id_early is not None + # Should NOT fall back to invocation_id — it should be + # a 32-char hex OTel trace_id. + assert trace_id_early != callback_context.invocation_id + assert len(trace_id_early) == 32 + + # Simulate agent callback: push_span("agent") + TM.push_span(callback_context, "agent") + trace_id_agent = TM.get_trace_id(callback_context) + + # Both trace_ids must be identical. + assert trace_id_early == trace_id_agent + + # Cleanup + TM.pop_span() # agent + TM.pop_span() # invocation + + provider.shutdown() + + @pytest.mark.asyncio + async def test_invocation_completed_trace_continuity_no_ambient( + self, callback_context + ): + """INVOCATION_COMPLETED must share trace_id with earlier events. + + Reproduces the completion-event fracture: after_run_callback pops + the invocation span, then _log_event would resolve trace_id via + the fallback to invocation_id. The trace_id_override ensures the + completion event keeps the same trace_id. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test-plugin") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Reset for a clean invocation; no ambient span. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + assert not trace.get_current_span().get_span_context().is_valid + + # --- Simulate the full callback lifecycle --- + # 1. before_run / on_user_message: ensure invocation span + TM.ensure_invocation_span(callback_context) + trace_id_start = TM.get_trace_id(callback_context) + + # 2. before_agent: push agent span + TM.push_span(callback_context, "agent") + assert TM.get_trace_id(callback_context) == trace_id_start + + # 3. after_agent: pop agent span + TM.pop_span() + + # 4. after_run: capture trace_id THEN pop invocation span + trace_id_before_pop = TM.get_trace_id(callback_context) + assert trace_id_before_pop == trace_id_start + + TM.pop_span() + + # After popping, get_trace_id falls back to invocation_id + trace_id_after_pop = TM.get_trace_id(callback_context) + assert trace_id_after_pop == callback_context.invocation_id + + # The trace_id_override preserves continuity + assert trace_id_before_pop == trace_id_start + assert trace_id_before_pop != trace_id_after_pop + + provider.shutdown() + + @pytest.mark.asyncio + async def test_callbacks_emit_same_trace_id_no_ambient( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """Full callback path: all emitted rows share one trace_id. + + Exercises the real before_run → before_agent → after_agent → + after_run callback chain via the plugin instance, then checks + every emitted BQ row has the same trace_id. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test-plugin") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Reset span records for a clean invocation. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient span — simulates Agent Engine / custom runner. + assert not trace.get_current_span().get_span_context().is_valid + + # Run the full callback lifecycle. + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + await asyncio.sleep(0.01) + + # Collect all emitted rows. + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + event_types = [r["event_type"] for r in rows] + assert "INVOCATION_STARTING" in event_types + assert "INVOCATION_COMPLETED" in event_types + + # Every row must share the same trace_id. + trace_ids = {r["trace_id"] for r in rows} + assert len(trace_ids) == 1, ( + "Expected 1 unique trace_id across all events, got" + f" {len(trace_ids)}: {trace_ids}" + ) + # Should be a 32-char hex OTel trace, not the invocation_id. + sole_trace_id = trace_ids.pop() + assert sole_trace_id != invocation_context.invocation_id + assert len(sole_trace_id) == 32 + + provider.shutdown() + + @pytest.mark.asyncio + async def test_trace_id_continuity_with_ambient_span(self, callback_context): + """All events share one trace_id when an ambient OTel span exists.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + # Set up a real OTel tracer. + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Reset the span records contextvar. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + with real_tracer.start_as_current_span("runner_invocation"): + ambient = trace.get_current_span() + assert ambient.get_span_context().is_valid + ambient_trace_id = format(ambient.get_span_context().trace_id, "032x") + + # ensure_invocation_span should attach the ambient span. + TM.ensure_invocation_span(callback_context) + trace_id_early = TM.get_trace_id(callback_context) + assert trace_id_early == ambient_trace_id + + # Simulate agent callback: push_span("agent") + TM.push_span(callback_context, "agent") + trace_id_agent = TM.get_trace_id(callback_context) + assert trace_id_agent == ambient_trace_id + + # Cleanup + TM.pop_span() # agent + TM.pop_span() # invocation (attached, not owned) + + provider.shutdown() + + @pytest.mark.asyncio + async def test_invocation_root_span_isolated_across_turns( + self, callback_context + ): + """Each invocation gets its own root span; turns don't leak.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + exporter = InMemorySpanExporter() + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # --- Turn 1 --- + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + TM.ensure_invocation_span(callback_context) + trace_id_turn1 = TM.get_trace_id(callback_context) + + TM.push_span(callback_context, "agent") + assert TM.get_trace_id(callback_context) == trace_id_turn1 + TM.pop_span() # agent + TM.pop_span() # invocation + + # After popping, the stack should be empty. + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert not records + + # --- Turn 2 --- + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + TM.ensure_invocation_span(callback_context) + trace_id_turn2 = TM.get_trace_id(callback_context) + + TM.push_span(callback_context, "agent") + assert TM.get_trace_id(callback_context) == trace_id_turn2 + TM.pop_span() # agent + TM.pop_span() # invocation + + # The two turns must have DIFFERENT trace_ids (different + # root spans). + assert trace_id_turn1 != trace_id_turn2 + + provider.shutdown() From 33cc4238b5592bf9447ff72406ef33b4bd29dd6a Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Sat, 28 Feb 2026 00:30:54 -0800 Subject: [PATCH 06/15] fix: clear stale parent_span_id when ambient span is root When the ambient OTel span is a root (no parent), _resolve_ids() was leaving the stale plugin-stack parent_span_id in place, producing invalid self-parent lineage (span_id == parent_span_id). Reset parent_span_id to None when ambient takes over, then only set it if the ambient span has a valid parent. Co-Authored-By: Claude Opus 4.6 --- .../bigquery_agent_analytics_plugin.py | 3 ++ .../test_bigquery_agent_analytics_plugin.py | 36 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index b1bbd2de40..a586b36d52 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -2321,6 +2321,9 @@ def _resolve_ids( if ambient_ctx.is_valid: trace_id = format(ambient_ctx.trace_id, "032x") span_id = format(ambient_ctx.span_id, "016x") + # Reset parent — stale plugin-stack parent must not leak through + # when the ambient span is a root (no parent). + parent_span_id = None # SDK spans expose .parent; non-recording spans do not. parent_ctx = getattr(ambient, "parent", None) if parent_ctx is not None and parent_ctx.span_id: diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 71b5b840e2..4b620628f5 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -3161,6 +3161,42 @@ def test_override_beats_ambient(self, callback_context): assert parent_id == "forced-parent" provider.shutdown() + def test_ambient_root_span_no_self_parent(self, callback_context): + """Ambient root span (no parent) must not produce self-parent.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + # Seed the plugin stack with a span so there's a stale parent. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "plugin-child" + ) + + ed = bigquery_agent_analytics_plugin.EventData() + + # Single root ambient span — no parent. + with real_tracer.start_as_current_span("root_invocation") as root: + trace_id, span_id, parent_id = self._resolve(ed, callback_context) + root_span_id = format(root.get_span_context().span_id, "016x") + + # span_id should be the ambient root's span_id + assert span_id == root_span_id + # parent must be None — not the stale plugin parent, not self + assert parent_id is None + assert span_id != parent_id + + # Cleanup + bigquery_agent_analytics_plugin.TraceManager.pop_span() + provider.shutdown() + class TestExtractLatency: """Tests for the _extract_latency static helper.""" From b5bdc30fa7e0800d6840f64246b7cf12d6352302 Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Sat, 28 Feb 2026 08:04:52 -0800 Subject: [PATCH 07/15] fix: move trace_id_override to end of EventData to preserve positional API The trace_id_override field was added at the top of EventData, which shifted all existing positional parameters and broke the public API surface. Move it after extra_attributes to preserve backward compatibility. Co-Authored-By: Claude Opus 4.6 --- src/google/adk/plugins/bigquery_agent_analytics_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index a586b36d52..ee75ccfe73 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -1743,7 +1743,6 @@ class _LoopState: class EventData: """Typed container for structured fields passed to _log_event.""" - trace_id_override: Optional[str] = None span_id_override: Optional[str] = None parent_span_id_override: Optional[str] = None latency_ms: Optional[int] = None @@ -1754,6 +1753,7 @@ class EventData: status: str = "OK" error_message: Optional[str] = None extra_attributes: dict[str, Any] = field(default_factory=dict) + trace_id_override: Optional[str] = None class BigQueryAgentAnalyticsPlugin(BasePlugin): From 5da0421a1e5178e833313b7047f105a046d5ff7d Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Sat, 28 Feb 2026 08:20:48 -0800 Subject: [PATCH 08/15] fix: make EventData kw_only to prevent positional parameter breakage Add kw_only=True to the EventData dataclass so that field ordering can never break callers. All existing usages already pass keyword arguments. Co-Authored-By: Claude Opus 4.6 --- src/google/adk/plugins/bigquery_agent_analytics_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index ee75ccfe73..bfcc51a09b 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -1739,7 +1739,7 @@ class _LoopState: batch_processor: BatchProcessor -@dataclass +@dataclass(kw_only=True) class EventData: """Typed container for structured fields passed to _log_event.""" From ee908ff3dfb2c752102d29ac1adcf5e5ca585c7e Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Sat, 28 Feb 2026 08:50:49 -0800 Subject: [PATCH 09/15] fix: span-ID consistency under ambient OTel and stack leak safety MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit P1 — Span-ID inconsistency: *_STARTING events used ambient framework span IDs via _resolve_ids Layer 2, while *_COMPLETED events bypassed Layer 2 by passing explicit span_id_override from the plugin's popped span. Now completion callbacks (after_agent, after_model, on_model_error, after_tool, on_tool_error, after_run) check for ambient OTel and pass None overrides when present, letting _resolve_ids use the framework's ambient span — keeping STARTING/COMPLETED pairs consistent. P2 — Stack leak on abnormal exit: Added TraceManager.clear_stack() to end owned spans and reset the stack. ensure_invocation_span() now clears stale records instead of no-op'ing, and after_run_callback calls clear_stack() as a safety net. Bonus: on_tool_error_callback previously discarded the span_id from pop_span(); now captures it for correct TOOL_ERROR span attribution. Co-Authored-By: Claude Opus 4.6 --- .../bigquery_agent_analytics_plugin.py | 78 +++- .../test_bigquery_agent_analytics_plugin.py | 424 ++++++++++++++++++ 2 files changed, 487 insertions(+), 15 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index bfcc51a09b..f7348c269b 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -680,7 +680,14 @@ def ensure_invocation_span( """ records = _span_records_ctx.get() if records: - return # Already initialised for this invocation. + # Stale records from a previous invocation that wasn't cleaned + # up (e.g. exception skipped after_run_callback). Clear and + # re-init. + logger.debug( + "Clearing %d stale span records from previous invocation.", + len(records), + ) + TraceManager.clear_stack() # Check for a valid ambient span (e.g. the Runner's invocation span). ambient = trace.get_current_span() @@ -717,6 +724,17 @@ def pop_span() -> tuple[Optional[str], Optional[int]]: return record.span_id, duration_ms + @staticmethod + def clear_stack() -> None: + """Clears all span records. Safety net for cross-invocation cleanup.""" + records = _span_records_ctx.get() + if records: + # End any owned spans to avoid OTel resource leaks. + for record in reversed(records): + if record.owns_span: + record.span.end() + _span_records_ctx.set([]) + @staticmethod def get_current_span_and_parent() -> tuple[Optional[str], Optional[str]]: """Gets current span_id and parent span_id.""" @@ -2680,16 +2698,24 @@ async def after_run_callback( span_id, duration = TraceManager.pop_span() parent_span_id = TraceManager.get_current_span_id() + # Only override span IDs when no ambient OTel span exists. + # When ambient exists, _resolve_ids Layer 2 uses the framework's + # span IDs, keeping STARTING/COMPLETED pairs consistent. + has_ambient = trace.get_current_span().get_span_context().is_valid + await self._log_event( "INVOCATION_COMPLETED", callback_ctx, event_data=EventData( trace_id_override=trace_id, latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=(None if has_ambient else parent_span_id), ), ) + # Safety net: clear any remaining stack entries from this + # invocation to prevent leaks into the next one. + TraceManager.clear_stack() # Ensure all logs are flushed before the agent returns await self.flush() @@ -2722,18 +2748,20 @@ async def after_agent_callback( callback_context: The callback context. """ span_id, duration = TraceManager.pop_span() - # When popping, the current stack now points to parent. - # The event we are logging ("AGENT_COMPLETED") belongs to the span we just popped. - # So we must override span_id to be the popped span, and parent to be current top of stack. parent_span_id, _ = TraceManager.get_current_span_and_parent() + # Only override span IDs when no ambient OTel span exists. + # When ambient exists, _resolve_ids Layer 2 uses the framework's + # span IDs, keeping STARTING/COMPLETED pairs consistent. + has_ambient = trace.get_current_span().get_span_context().is_valid + await self._log_event( "AGENT_COMPLETED", callback_context, event_data=EventData( latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=(None if has_ambient else parent_span_id), ), ) @@ -2883,6 +2911,12 @@ async def after_model_callback( # Otherwise log_event will fetch current stack (which is parent). span_id = popped_span_id or span_id + # Only override span IDs when no ambient OTel span exists. + # When ambient exists, _resolve_ids Layer 2 uses the framework's + # span IDs, keeping LLM_REQUEST/LLM_RESPONSE pairs consistent. + has_ambient = trace.get_current_span().get_span_context().is_valid + use_override = is_popped and not has_ambient + await self._log_event( "LLM_RESPONSE", callback_context, @@ -2893,8 +2927,8 @@ async def after_model_callback( time_to_first_token_ms=tfft, model_version=llm_response.model_version, usage_metadata=llm_response.usage_metadata, - span_id_override=span_id if is_popped else None, - parent_span_id_override=(parent_span_id if is_popped else None), + span_id_override=span_id if use_override else None, + parent_span_id_override=(parent_span_id if use_override else None), ), ) @@ -2915,14 +2949,18 @@ async def on_model_error_callback( """ span_id, duration = TraceManager.pop_span() parent_span_id, _ = TraceManager.get_current_span_and_parent() + + # Only override span IDs when no ambient OTel span exists. + has_ambient = trace.get_current_span().get_span_context().is_valid + await self._log_event( "LLM_ERROR", callback_context, event_data=EventData( error_message=str(error), latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=(None if has_ambient else parent_span_id), ), ) @@ -2987,10 +3025,13 @@ async def after_tool_callback( span_id, duration = TraceManager.pop_span() parent_span_id, _ = TraceManager.get_current_span_and_parent() + # Only override span IDs when no ambient OTel span exists. + has_ambient = trace.get_current_span().get_span_context().is_valid + event_data = EventData( latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=(None if has_ambient else parent_span_id), ) await self._log_event( "TOOL_COMPLETED", @@ -3026,7 +3067,12 @@ async def on_tool_error_callback( "args": args_truncated, "tool_origin": tool_origin, } - _, duration = TraceManager.pop_span() + span_id, duration = TraceManager.pop_span() + parent_span_id, _ = TraceManager.get_current_span_and_parent() + + # Only override span IDs when no ambient OTel span exists. + has_ambient = trace.get_current_span().get_span_context().is_valid + await self._log_event( "TOOL_ERROR", tool_context, @@ -3035,5 +3081,7 @@ async def on_tool_error_callback( event_data=EventData( error_message=str(error), latency_ms=duration, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=(None if has_ambient else parent_span_id), ), ) diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 4b620628f5..1af576ac07 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -3197,6 +3197,42 @@ def test_ambient_root_span_no_self_parent(self, callback_context): bigquery_agent_analytics_plugin.TraceManager.pop_span() provider.shutdown() + def test_ambient_span_used_for_completed_event(self, callback_context): + """Completed event with overrides should use ambient when present. + + When an ambient OTel span is valid, passing None overrides lets + _resolve_ids Layer 2 pick the ambient span — matching the + STARTING event's span_id. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with real_tracer.start_as_current_span("invoke_agent") as agent_span: + expected_span = format(agent_span.get_span_context().span_id, "016x") + + # Simulate STARTING: no overrides → ambient Layer 2 wins. + ed_starting = bigquery_agent_analytics_plugin.EventData() + _, span_starting, _ = self._resolve(ed_starting, callback_context) + + # Simulate COMPLETED: None overrides (ambient check passed). + ed_completed = bigquery_agent_analytics_plugin.EventData( + span_id_override=None, + parent_span_id_override=None, + latency_ms=42, + ) + _, span_completed, _ = self._resolve(ed_completed, callback_context) + + assert span_starting == expected_span + assert span_completed == expected_span + assert span_starting == span_completed + + provider.shutdown() + class TestExtractLatency: """Tests for the _extract_latency static helper.""" @@ -5341,3 +5377,391 @@ async def test_invocation_root_span_isolated_across_turns( assert trace_id_turn1 != trace_id_turn2 provider.shutdown() + + +class TestSpanIdConsistency: + """Tests that STARTING/COMPLETED event pairs share span IDs. + + Span-ID resolution contract: + - When OTel is active: BQ rows use the same trace/span/parent IDs as + Cloud Trace (ambient framework spans). STARTING and COMPLETED events + in the same lifecycle share the same span_id. + - When OTel is not active: BQ rows use the plugin's internal span + stack. STARTING gets the current top-of-stack; COMPLETED gets the + popped span. + """ + + @pytest.mark.asyncio + async def test_starting_completed_same_span_with_ambient( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """With ambient OTel, STARTING and COMPLETED get the same span_id.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # Simulate the framework's ambient spans. + with real_tracer.start_as_current_span("invocation"): + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + with real_tracer.start_as_current_span("invoke_agent"): + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + agent_starting = [r for r in rows if r["event_type"] == "AGENT_STARTING"] + agent_completed = [ + r for r in rows if r["event_type"] == "AGENT_COMPLETED" + ] + + assert len(agent_starting) == 1 + assert len(agent_completed) == 1 + + # Both events must share the same span_id (the ambient + # invoke_agent span) — no plugin-synthetic override. + assert agent_starting[0]["span_id"] == agent_completed[0]["span_id"] + assert ( + agent_starting[0]["parent_span_id"] + == agent_completed[0]["parent_span_id"] + ) + + provider.shutdown() + + @pytest.mark.asyncio + async def test_starting_completed_use_plugin_span_without_ambient( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """Without ambient OTel, COMPLETED gets the popped plugin span.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient OTel span. + assert not trace.get_current_span().get_span_context().is_valid + + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + agent_starting = [r for r in rows if r["event_type"] == "AGENT_STARTING"] + agent_completed = [ + r for r in rows if r["event_type"] == "AGENT_COMPLETED" + ] + + assert len(agent_starting) == 1 + assert len(agent_completed) == 1 + + # AGENT_STARTING gets the top-of-stack span; AGENT_COMPLETED + # gets the popped span via override — they should match. + assert agent_starting[0]["span_id"] == agent_completed[0]["span_id"] + + provider.shutdown() + + @pytest.mark.asyncio + async def test_tool_error_captures_span_id( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + dummy_arrow_schema, + ): + """on_tool_error_callback uses the popped span_id (bonus fix).""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + mock_tool = mock.create_autospec(base_tool_lib.BaseTool, instance=True) + type(mock_tool).name = mock.PropertyMock(return_value="my_tool") + tool_ctx = tool_context_lib.ToolContext( + invocation_context=invocation_context + ) + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient OTel — plugin span stack provides IDs. + assert not trace.get_current_span().get_span_context().is_valid + + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + # Push tool span via before_tool_callback + await bq_plugin_inst.before_tool_callback( + tool=mock_tool, + tool_args={"a": 1}, + tool_context=tool_ctx, + ) + # Error callback should pop the tool span and use its ID + await bq_plugin_inst.on_tool_error_callback( + tool=mock_tool, + tool_args={"a": 1}, + tool_context=tool_ctx, + error=RuntimeError("boom"), + ) + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + tool_starting = [r for r in rows if r["event_type"] == "TOOL_STARTING"] + tool_error = [r for r in rows if r["event_type"] == "TOOL_ERROR"] + + assert len(tool_starting) == 1 + assert len(tool_error) == 1 + + # The TOOL_ERROR event must have the same span_id as + # TOOL_STARTING (both correspond to the same tool span). + assert tool_starting[0]["span_id"] == tool_error[0]["span_id"] + assert tool_error[0]["span_id"] is not None + + provider.shutdown() + + +class TestStackLeakSafety: + """Tests for stack leak safety (P2). + + Ensures the plugin's internal span stack doesn't leak records + across invocations when after_run_callback is skipped. + """ + + def test_ensure_invocation_span_clears_stale_records(self, callback_context): + """Pre-populated stack is cleared and re-initialized.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # Simulate stale records from incomplete previous invocation. + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + TM.push_span(callback_context, "stale-invocation") + TM.push_span(callback_context, "stale-agent") + + stale_records = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert len(stale_records) == 2 + + # ensure_invocation_span should clear stale and re-init. + TM.ensure_invocation_span(callback_context) + + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + # Should have exactly 1 fresh entry (the new invocation span). + assert len(records) == 1 + # The fresh span should NOT be one of the stale ones. + assert records[0].span_id != stale_records[0].span_id + assert records[0].span_id != stale_records[1].span_id + + provider.shutdown() + + def test_clear_stack_ends_owned_spans(self, callback_context): + """clear_stack() ends all owned spans.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + exporter = InMemorySpanExporter() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + TM.push_span(callback_context, "span-a") + TM.push_span(callback_context, "span-b") + + records = list(bigquery_agent_analytics_plugin._span_records_ctx.get()) + assert all(r.owns_span for r in records) + + TM.clear_stack() + + # Stack must be empty after clear. + result = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert result == [] + + # Both owned spans should have been ended (exported). + exported = exporter.get_finished_spans() + assert len(exported) == 2 + + provider.shutdown() + + @pytest.mark.asyncio + async def test_after_run_callback_clears_remaining_stack( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """after_run_callback clears any leftover stack entries.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # No ambient span. + assert not trace.get_current_span().get_span_context().is_valid + + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + # Push an agent span but DON'T pop it (simulate missing + # after_agent_callback due to exception). + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + # Stack now has [invocation, agent]. + + # after_run_callback should pop invocation + clear remaining. + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + # Stack must be empty. + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert records == [] + + provider.shutdown() + + @pytest.mark.asyncio + async def test_next_invocation_clean_after_incomplete_previous( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """Next invocation starts clean even if previous was incomplete.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + + # --- Incomplete invocation 1: no after_run_callback --- + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + # Skip after_agent and after_run — simulates exception. + + stale = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert len(stale) >= 2 # invocation + agent + + # --- Invocation 2: before_run_callback triggers + # ensure_invocation_span which should clear stale state --- + mock_write_client.append_rows.reset_mock() + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + # Should have exactly 1 fresh invocation span. + assert len(records) == 1 + + # Cleanup + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + provider.shutdown() From e228b35c5fa36a5de3035bb9718cb83923e07a25 Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Sat, 28 Feb 2026 09:04:28 -0800 Subject: [PATCH 10/15] fix: make ensure_invocation_span idempotent within same invocation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous clear-on-non-empty behavior broke trace continuity between on_user_message_callback and before_run_callback in the no-ambient OTel path: both call ensure_invocation_span(), and the second call would clear the stack and create a new root span with a different trace_id. Fix: track the active invocation_id in a contextvar. If the stack has records belonging to the current invocation, return early (idempotent). If the stack has records from a different invocation (stale leak), clear and re-init. Reset the active invocation_id in after_run_callback. Also updates the docstring to match the new behavior and adds regression tests for the USER_MESSAGE_RECEIVED → INVOCATION_STARTING boundary. Co-Authored-By: Claude Opus 4.6 --- .../bigquery_agent_analytics_plugin.py | 28 +++- .../test_bigquery_agent_analytics_plugin.py | 131 +++++++++++++++++- 2 files changed, 148 insertions(+), 11 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index f7348c269b..18352344e3 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -512,6 +512,13 @@ class BigQueryLoggerConfig: "_bq_analytics_root_agent_name", default=None ) +# Tracks the invocation_id that owns the current span stack so that +# ensure_invocation_span() can distinguish "same invocation re-entry" +# (idempotent) from "stale records from a previous invocation" (clear). +_active_invocation_id_ctx: contextvars.ContextVar[Optional[str]] = ( + contextvars.ContextVar("_bq_analytics_active_invocation_id", default=None) +) + @dataclass class _SpanRecord: @@ -672,14 +679,22 @@ def ensure_invocation_span( Must be called before any events are logged so that every event in the invocation shares the same trace_id. - * If the stack already has entries → no-op (already initialised). - * If the ambient OTel span is valid → ``attach_current_span`` (reuse - the runner's span without owning it). - * Otherwise → ``push_span("invocation")`` (create a new root span - that will be popped in ``after_run_callback``). + * If the stack has entries for the *current* invocation → no-op + (idempotent within the same invocation). + * If the stack has entries from a *different* invocation → clear + stale records and re-initialise (safety net for abnormal exit). + * If the ambient OTel span is valid → ``attach_current_span`` + (reuse the runner's span without owning it). + * Otherwise → ``push_span("invocation")`` (create a new root + span that will be popped in ``after_run_callback``). """ + current_inv = callback_context.invocation_id + active_inv = _active_invocation_id_ctx.get() + records = _span_records_ctx.get() if records: + if active_inv == current_inv: + return # Already initialised for this invocation. # Stale records from a previous invocation that wasn't cleaned # up (e.g. exception skipped after_run_callback). Clear and # re-init. @@ -689,6 +704,8 @@ def ensure_invocation_span( ) TraceManager.clear_stack() + _active_invocation_id_ctx.set(current_inv) + # Check for a valid ambient span (e.g. the Runner's invocation span). ambient = trace.get_current_span() if ambient.get_span_context().is_valid: @@ -2716,6 +2733,7 @@ async def after_run_callback( # Safety net: clear any remaining stack entries from this # invocation to prevent leaks into the next one. TraceManager.clear_stack() + _active_invocation_id_ctx.set(None) # Ensure all logs are flushed before the agent returns await self.flush() diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 1af576ac07..fe8d27aa1b 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -5591,7 +5591,7 @@ class TestStackLeakSafety: """ def test_ensure_invocation_span_clears_stale_records(self, callback_context): - """Pre-populated stack is cleared and re-initialized.""" + """Pre-populated stack from a different invocation is cleared.""" from opentelemetry.sdk.trace import TracerProvider as SdkProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter @@ -5607,13 +5607,18 @@ def test_ensure_invocation_span_clears_stale_records(self, callback_context): ): # Simulate stale records from incomplete previous invocation. bigquery_agent_analytics_plugin._span_records_ctx.set(None) + # Mark the stale records as belonging to a different invocation. + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set( + "old-inv-stale" + ) TM.push_span(callback_context, "stale-invocation") TM.push_span(callback_context, "stale-agent") stale_records = bigquery_agent_analytics_plugin._span_records_ctx.get() assert len(stale_records) == 2 - # ensure_invocation_span should clear stale and re-init. + # ensure_invocation_span with the *current* invocation_id should + # detect the mismatch, clear stale records, and re-init. TM.ensure_invocation_span(callback_context) records = bigquery_agent_analytics_plugin._span_records_ctx.get() @@ -5719,6 +5724,7 @@ async def test_next_invocation_clean_after_incomplete_previous( callback_context, mock_agent, dummy_arrow_schema, + mock_session, ): """Next invocation starts clean even if previous was incomplete.""" from opentelemetry.sdk.trace import TracerProvider as SdkProvider @@ -5735,6 +5741,7 @@ async def test_next_invocation_clean_after_incomplete_previous( bigquery_agent_analytics_plugin, "tracer", real_tracer ): bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) # --- Incomplete invocation 1: no after_run_callback --- await bq_plugin_inst.before_run_callback( @@ -5748,20 +5755,132 @@ async def test_next_invocation_clean_after_incomplete_previous( stale = bigquery_agent_analytics_plugin._span_records_ctx.get() assert len(stale) >= 2 # invocation + agent - # --- Invocation 2: before_run_callback triggers - # ensure_invocation_span which should clear stale state --- + # --- Invocation 2 with a different invocation_id --- mock_write_client.append_rows.reset_mock() - await bq_plugin_inst.before_run_callback( - invocation_context=invocation_context + inv_ctx_2 = invocation_context_lib.InvocationContext( + agent=mock_agent, + session=mock_session, + invocation_id="inv-NEW-002", + session_service=invocation_context.session_service, + plugin_manager=invocation_context.plugin_manager, ) + await bq_plugin_inst.before_run_callback(invocation_context=inv_ctx_2) records = bigquery_agent_analytics_plugin._span_records_ctx.get() # Should have exactly 1 fresh invocation span. assert len(records) == 1 # Cleanup + await bq_plugin_inst.after_run_callback(invocation_context=inv_ctx_2) + + provider.shutdown() + + def test_ensure_invocation_span_idempotent_same_invocation( + self, callback_context + ): + """Calling ensure_invocation_span twice in the same invocation is a no-op.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + TM = bigquery_agent_analytics_plugin.TraceManager + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + + # First call: creates invocation span. + TM.ensure_invocation_span(callback_context) + records_after_first = list( + bigquery_agent_analytics_plugin._span_records_ctx.get() + ) + assert len(records_after_first) == 1 + first_span_id = records_after_first[0].span_id + + # Second call (same invocation): must be a no-op. + TM.ensure_invocation_span(callback_context) + records_after_second = ( + bigquery_agent_analytics_plugin._span_records_ctx.get() + ) + assert len(records_after_second) == 1 + assert records_after_second[0].span_id == first_span_id + + # Cleanup + TM.pop_span() + + provider.shutdown() + + @pytest.mark.asyncio + async def test_user_message_then_before_run_same_trace_no_ambient( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + dummy_arrow_schema, + ): + """Regression: on_user_message → before_run must share one trace_id. + + Without the invocation-ID guard, the second ensure_invocation_span() + call would clear the stack and create a new root span with a + different trace_id, fracturing USER_MESSAGE_RECEIVED from + INVOCATION_STARTING. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + + # No ambient span. + assert not trace.get_current_span().get_span_context().is_valid + + user_msg = types.Content(parts=[types.Part(text="hello")], role="user") + await bq_plugin_inst.on_user_message_callback( + invocation_context=invocation_context, + user_message=user_msg, + ) + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + await bq_plugin_inst.after_agent_callback( + agent=mock_agent, callback_context=callback_context + ) await bq_plugin_inst.after_run_callback( invocation_context=invocation_context ) + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + event_types = [r["event_type"] for r in rows] + assert "USER_MESSAGE_RECEIVED" in event_types + assert "INVOCATION_STARTING" in event_types + + # Every row must share the same trace_id. + trace_ids = {r["trace_id"] for r in rows} + assert len(trace_ids) == 1, ( + "Expected 1 unique trace_id across all events, got" + f" {len(trace_ids)}: {trace_ids}" + ) provider.shutdown() From b652f819ddfb6ff51038e0724fe7fef9f9e72e77 Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Sat, 28 Feb 2026 09:19:11 -0800 Subject: [PATCH 11/15] fix: refresh root_agent_name on each invocation, not just first init_trace() previously only set _root_agent_name_ctx when it was None, so the second invocation with a different root agent would inherit the first's name. Now it sets unconditionally. after_run_callback also resets _root_agent_name_ctx alongside the other invocation cleanup. Also adds a NOTE comment acknowledging that trace contextvars are module-global (not plugin-instance-scoped) as a known limitation. Co-Authored-By: Claude Opus 4.6 --- .../bigquery_agent_analytics_plugin.py | 19 ++- .../test_bigquery_agent_analytics_plugin.py | 116 ++++++++++++++++++ 2 files changed, 129 insertions(+), 6 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 18352344e3..5c8f3592ec 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -507,6 +507,11 @@ class BigQueryLoggerConfig: # ============================================================================== # HELPER: TRACE MANAGER (Async-Safe with ContextVars) # ============================================================================== +# NOTE: These contextvars are module-global, not plugin-instance-scoped. +# Multiple BigQueryAgentAnalyticsPlugin instances in the same execution +# context will share trace state. This is acceptable for the expected +# single-plugin-per-process deployment, but should be revisited if +# multi-instance support is needed (e.g. scope by plugin instance ID). _root_agent_name_ctx = contextvars.ContextVar( "_bq_analytics_root_agent_name", default=None @@ -564,12 +569,13 @@ def _get_records() -> list[_SpanRecord]: @staticmethod def init_trace(callback_context: CallbackContext) -> None: - if _root_agent_name_ctx.get() is None: - try: - root_agent = callback_context._invocation_context.agent.root_agent - _root_agent_name_ctx.set(root_agent.name) - except (AttributeError, ValueError): - pass + # Always refresh root_agent_name — it can change between + # invocations (e.g. different root agents in the same task). + try: + root_agent = callback_context._invocation_context.agent.root_agent + _root_agent_name_ctx.set(root_agent.name) + except (AttributeError, ValueError): + pass # Ensure records stack is initialized TraceManager._get_records() @@ -2734,6 +2740,7 @@ async def after_run_callback( # invocation to prevent leaks into the next one. TraceManager.clear_stack() _active_invocation_id_ctx.set(None) + _root_agent_name_ctx.set(None) # Ensure all logs are flushed before the agent returns await self.flush() diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index fe8d27aa1b..62fa7beda4 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -5884,3 +5884,119 @@ async def test_user_message_then_before_run_same_trace_no_ambient( ) provider.shutdown() + + +class TestRootAgentNameAcrossInvocations: + """Regression: root_agent_name must refresh across invocations.""" + + @pytest.mark.asyncio + async def test_root_agent_name_updates_between_invocations( + self, + bq_plugin_inst, + mock_write_client, + mock_session, + dummy_arrow_schema, + ): + """Two invocations with different root agents must log correct names. + + Previously init_trace() only set _root_agent_name_ctx when it was + None, so the second invocation would inherit the first's root agent. + """ + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + mock_session_service = mock.create_autospec( + base_session_service_lib.BaseSessionService, + instance=True, + spec_set=True, + ) + mock_plugin_manager = mock.create_autospec( + plugin_manager_lib.PluginManager, + instance=True, + spec_set=True, + ) + + def _make_inv_ctx(agent_name, inv_id): + agent = mock.create_autospec( + base_agent.BaseAgent, instance=True, spec_set=True + ) + type(agent).name = mock.PropertyMock(return_value=agent_name) + type(agent).instruction = mock.PropertyMock(return_value="") + # root_agent returns itself (no parent). + agent.root_agent = agent + return invocation_context_lib.InvocationContext( + agent=agent, + session=mock_session, + invocation_id=inv_id, + session_service=mock_session_service, + plugin_manager=mock_plugin_manager, + ) + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + # --- Invocation 1: root agent = "RootA" --- + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + bigquery_agent_analytics_plugin._root_agent_name_ctx.set(None) + + inv1 = _make_inv_ctx("RootA", "inv-001") + cb1 = callback_context_lib.CallbackContext(inv1) + await bq_plugin_inst.before_run_callback(invocation_context=inv1) + await bq_plugin_inst.before_agent_callback( + agent=inv1.agent, callback_context=cb1 + ) + await bq_plugin_inst.after_agent_callback( + agent=inv1.agent, callback_context=cb1 + ) + await bq_plugin_inst.after_run_callback(invocation_context=inv1) + await asyncio.sleep(0.01) + + rows_inv1 = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + + # --- Invocation 2: root agent = "RootB" --- + mock_write_client.append_rows.reset_mock() + + inv2 = _make_inv_ctx("RootB", "inv-002") + cb2 = callback_context_lib.CallbackContext(inv2) + await bq_plugin_inst.before_run_callback(invocation_context=inv2) + await bq_plugin_inst.before_agent_callback( + agent=inv2.agent, callback_context=cb2 + ) + await bq_plugin_inst.after_agent_callback( + agent=inv2.agent, callback_context=cb2 + ) + await bq_plugin_inst.after_run_callback(invocation_context=inv2) + await asyncio.sleep(0.01) + + rows_inv2 = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + + # Parse root_agent_name from the attributes JSON column. + def _get_root_names(rows): + names = set() + for r in rows: + attrs = r.get("attributes") + if attrs: + parsed = json.loads(attrs) if isinstance(attrs, str) else attrs + if "root_agent_name" in parsed: + names.add(parsed["root_agent_name"]) + return names + + names_inv1 = _get_root_names(rows_inv1) + names_inv2 = _get_root_names(rows_inv2) + + # Invocation 1 should only have "RootA". + assert names_inv1 == {"RootA"}, f"Expected {{'RootA'}}, got {names_inv1}" + # Invocation 2 must have "RootB", NOT stale "RootA". + assert names_inv2 == {"RootB"}, f"Expected {{'RootB'}}, got {names_inv2}" + + provider.shutdown() From cd0f5ac17a28cdb39dd0706d6a10b57134011a04 Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Sat, 28 Feb 2026 09:24:39 -0800 Subject: [PATCH 12/15] docs: clarify why module-global contextvars are safe in practice PluginManager enforces name-uniqueness (no duplicate BQ plugins on same Runner), and concurrent asyncio tasks get isolated contextvar copies. Updated the NOTE comment to explain both guards. Co-Authored-By: Claude Opus 4.6 --- .../adk/plugins/bigquery_agent_analytics_plugin.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 5c8f3592ec..a245850151 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -508,10 +508,14 @@ class BigQueryLoggerConfig: # HELPER: TRACE MANAGER (Async-Safe with ContextVars) # ============================================================================== # NOTE: These contextvars are module-global, not plugin-instance-scoped. -# Multiple BigQueryAgentAnalyticsPlugin instances in the same execution -# context will share trace state. This is acceptable for the expected -# single-plugin-per-process deployment, but should be revisited if -# multi-instance support is needed (e.g. scope by plugin instance ID). +# This is safe in practice for two reasons: +# 1. PluginManager enforces name-uniqueness, preventing two BQ plugin +# instances on the same Runner. +# 2. Concurrent asyncio tasks (e.g. two Runners in asyncio.gather) each +# get an isolated contextvar copy, so they don't interfere. +# The only problematic case would be two plugin instances interleaved +# within the *same* asyncio task without task boundaries — which the +# framework's PluginManager already prevents. _root_agent_name_ctx = contextvars.ContextVar( "_bq_analytics_root_agent_name", default=None From 199e0175c693b24bcdefbfa16248282236dee3b5 Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Sat, 28 Feb 2026 09:26:36 -0800 Subject: [PATCH 13/15] fix: make after_run_callback cleanup exception-safe with try/finally The cleanup code (clear_stack, reset _active_invocation_id_ctx and _root_agent_name_ctx) ran after _log_event, so a failure in _log_event would skip cleanup silently (_safe_callback swallows the exception). Wrapping with try/finally ensures invocation state is always reset. Co-Authored-By: Claude Opus 4.6 --- .../bigquery_agent_analytics_plugin.py | 60 +++++++++-------- .../test_bigquery_agent_analytics_plugin.py | 67 +++++++++++++++++++ 2 files changed, 98 insertions(+), 29 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index a245850151..beb26c24d7 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -2715,38 +2715,40 @@ async def after_run_callback( Args: invocation_context: The context of the current invocation. """ - # Capture trace_id BEFORE popping the invocation-root span so that - # INVOCATION_COMPLETED shares the same trace_id as all earlier events - # in this invocation (fixes #4645 completion-event fracture). - callback_ctx = CallbackContext(invocation_context) - trace_id = TraceManager.get_trace_id(callback_ctx) + try: + # Capture trace_id BEFORE popping the invocation-root span so + # that INVOCATION_COMPLETED shares the same trace_id as all + # earlier events in this invocation (fixes #4645). + callback_ctx = CallbackContext(invocation_context) + trace_id = TraceManager.get_trace_id(callback_ctx) - # Pop the invocation-root span pushed by ensure_invocation_span(). - span_id, duration = TraceManager.pop_span() - parent_span_id = TraceManager.get_current_span_id() + # Pop the invocation-root span pushed by ensure_invocation_span. + span_id, duration = TraceManager.pop_span() + parent_span_id = TraceManager.get_current_span_id() - # Only override span IDs when no ambient OTel span exists. - # When ambient exists, _resolve_ids Layer 2 uses the framework's - # span IDs, keeping STARTING/COMPLETED pairs consistent. - has_ambient = trace.get_current_span().get_span_context().is_valid + # Only override span IDs when no ambient OTel span exists. + # When ambient exists, _resolve_ids Layer 2 uses the framework's + # span IDs, keeping STARTING/COMPLETED pairs consistent. + has_ambient = trace.get_current_span().get_span_context().is_valid - await self._log_event( - "INVOCATION_COMPLETED", - callback_ctx, - event_data=EventData( - trace_id_override=trace_id, - latency_ms=duration, - span_id_override=None if has_ambient else span_id, - parent_span_id_override=(None if has_ambient else parent_span_id), - ), - ) - # Safety net: clear any remaining stack entries from this - # invocation to prevent leaks into the next one. - TraceManager.clear_stack() - _active_invocation_id_ctx.set(None) - _root_agent_name_ctx.set(None) - # Ensure all logs are flushed before the agent returns - await self.flush() + await self._log_event( + "INVOCATION_COMPLETED", + callback_ctx, + event_data=EventData( + trace_id_override=trace_id, + latency_ms=duration, + span_id_override=None if has_ambient else span_id, + parent_span_id_override=(None if has_ambient else parent_span_id), + ), + ) + finally: + # Cleanup must run even if _log_event raises, otherwise + # stale invocation metadata leaks into the next invocation. + TraceManager.clear_stack() + _active_invocation_id_ctx.set(None) + _root_agent_name_ctx.set(None) + # Ensure all logs are flushed before the agent returns. + await self.flush() @_safe_callback async def before_agent_callback( diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 62fa7beda4..f91a955ff1 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -6000,3 +6000,70 @@ def _get_root_names(rows): assert names_inv2 == {"RootB"}, f"Expected {{'RootB'}}, got {names_inv2}" provider.shutdown() + + +class TestAfterRunCleanupExceptionSafety: + """after_run_callback cleanup must execute even if _log_event fails.""" + + @pytest.mark.asyncio + async def test_cleanup_runs_when_log_event_raises( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + callback_context, + mock_agent, + ): + """Stale state is cleared even when _log_event raises.""" + from opentelemetry.sdk.trace import TracerProvider as SdkProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + provider = SdkProvider() + provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter())) + real_tracer = provider.get_tracer("test") + + with mock.patch.object( + bigquery_agent_analytics_plugin, "tracer", real_tracer + ): + bigquery_agent_analytics_plugin._span_records_ctx.set(None) + bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None) + bigquery_agent_analytics_plugin._root_agent_name_ctx.set(None) + + # Run a normal before_run to initialise state. + await bq_plugin_inst.before_run_callback( + invocation_context=invocation_context + ) + await bq_plugin_inst.before_agent_callback( + agent=mock_agent, callback_context=callback_context + ) + + # Verify state is populated. + assert bigquery_agent_analytics_plugin._span_records_ctx.get() + assert ( + bigquery_agent_analytics_plugin._active_invocation_id_ctx.get() + is not None + ) + + # Make _log_event raise inside after_run_callback. + with mock.patch.object( + bq_plugin_inst, + "_log_event", + side_effect=RuntimeError("boom"), + ): + # _safe_callback swallows the exception, but cleanup in + # the finally block must still execute. + await bq_plugin_inst.after_run_callback( + invocation_context=invocation_context + ) + + # All invocation state must be cleaned up despite the error. + records = bigquery_agent_analytics_plugin._span_records_ctx.get() + assert records == [] or records is None + assert ( + bigquery_agent_analytics_plugin._active_invocation_id_ctx.get() + is None + ) + assert bigquery_agent_analytics_plugin._root_agent_name_ctx.get() is None + + provider.shutdown() From 54b563fde766c8d3a342692939361dc57c8f55df Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Sat, 28 Feb 2026 10:19:07 -0800 Subject: [PATCH 14/15] fix: use direct class imports matching codebase convention in BQ plugin tests Import CallbackContext and InvocationContext directly from their modules instead of importing the module and using qualified names. This matches the import pattern used throughout the test suite. Co-Authored-By: Claude Opus 4.6 --- .../test_bigquery_agent_analytics_plugin.py | 38 +++++++------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index f91a955ff1..5d87a17cd9 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -20,8 +20,8 @@ from unittest import mock from google.adk.agents import base_agent -from google.adk.agents import callback_context as callback_context_lib -from google.adk.agents import invocation_context as invocation_context_lib +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.invocation_context import InvocationContext from google.adk.events import event as event_lib from google.adk.events import event_actions as event_actions_lib from google.adk.models import llm_request as llm_request_lib @@ -83,7 +83,7 @@ def invocation_context(mock_agent, mock_session): mock_plugin_manager = mock.create_autospec( plugin_manager_lib.PluginManager, instance=True, spec_set=True ) - return invocation_context_lib.InvocationContext( + return InvocationContext( agent=mock_agent, session=mock_session, invocation_id="inv-789", @@ -94,9 +94,7 @@ def invocation_context(mock_agent, mock_session): @pytest.fixture def callback_context(invocation_context): - return callback_context_lib.CallbackContext( - invocation_context=invocation_context - ) + return CallbackContext(invocation_context=invocation_context) @pytest.fixture @@ -3422,7 +3420,7 @@ def _make_invocation_context(agent_name, session, invocation_id="inv-001"): instance=True, spec_set=True, ) - return invocation_context_lib.InvocationContext( + return InvocationContext( agent=mock_a, session=session, invocation_id=invocation_id, @@ -3628,7 +3626,7 @@ async def test_full_subagent_callback_sequence( """ session = self._make_session() inv_ctx = self._make_invocation_context("schema_explorer", session) - cb_ctx = callback_context_lib.CallbackContext(invocation_context=inv_ctx) + cb_ctx = CallbackContext(invocation_context=inv_ctx) tool_ctx = tool_context_lib.ToolContext(invocation_context=inv_ctx) mock_agent = inv_ctx.agent tool = self._make_tool("get_table_info") @@ -3906,9 +3904,7 @@ async def test_multi_turn_multi_subagent_full_sequence( inv_ctx_t1_orch = self._make_invocation_context( "orchestrator", session, invocation_id="inv-t1" ) - cb_ctx_t1_orch = callback_context_lib.CallbackContext( - invocation_context=inv_ctx_t1_orch - ) + cb_ctx_t1_orch = CallbackContext(invocation_context=inv_ctx_t1_orch) # Orchestrator agent_starting await plugin.before_agent_callback( @@ -3921,9 +3917,7 @@ async def test_multi_turn_multi_subagent_full_sequence( inv_ctx_t1_sub = self._make_invocation_context( "schema_explorer", session, invocation_id="inv-t1" ) - cb_ctx_t1_sub = callback_context_lib.CallbackContext( - invocation_context=inv_ctx_t1_sub - ) + cb_ctx_t1_sub = CallbackContext(invocation_context=inv_ctx_t1_sub) tool_ctx_t1 = tool_context_lib.ToolContext( invocation_context=inv_ctx_t1_sub ) @@ -3971,9 +3965,7 @@ async def test_multi_turn_multi_subagent_full_sequence( inv_ctx_t2_orch = self._make_invocation_context( "orchestrator", session, invocation_id="inv-t2" ) - cb_ctx_t2_orch = callback_context_lib.CallbackContext( - invocation_context=inv_ctx_t2_orch - ) + cb_ctx_t2_orch = CallbackContext(invocation_context=inv_ctx_t2_orch) await plugin.before_agent_callback( agent=inv_ctx_t2_orch.agent, @@ -3985,9 +3977,7 @@ async def test_multi_turn_multi_subagent_full_sequence( inv_ctx_t2_sub = self._make_invocation_context( "image_describer", session, invocation_id="inv-t2" ) - cb_ctx_t2_sub = callback_context_lib.CallbackContext( - invocation_context=inv_ctx_t2_sub - ) + cb_ctx_t2_sub = CallbackContext(invocation_context=inv_ctx_t2_sub) tool_ctx_t2 = tool_context_lib.ToolContext( invocation_context=inv_ctx_t2_sub ) @@ -5757,7 +5747,7 @@ async def test_next_invocation_clean_after_incomplete_previous( # --- Invocation 2 with a different invocation_id --- mock_write_client.append_rows.reset_mock() - inv_ctx_2 = invocation_context_lib.InvocationContext( + inv_ctx_2 = InvocationContext( agent=mock_agent, session=mock_session, invocation_id="inv-NEW-002", @@ -5929,7 +5919,7 @@ def _make_inv_ctx(agent_name, inv_id): type(agent).instruction = mock.PropertyMock(return_value="") # root_agent returns itself (no parent). agent.root_agent = agent - return invocation_context_lib.InvocationContext( + return InvocationContext( agent=agent, session=mock_session, invocation_id=inv_id, @@ -5946,7 +5936,7 @@ def _make_inv_ctx(agent_name, inv_id): bigquery_agent_analytics_plugin._root_agent_name_ctx.set(None) inv1 = _make_inv_ctx("RootA", "inv-001") - cb1 = callback_context_lib.CallbackContext(inv1) + cb1 = CallbackContext(inv1) await bq_plugin_inst.before_run_callback(invocation_context=inv1) await bq_plugin_inst.before_agent_callback( agent=inv1.agent, callback_context=cb1 @@ -5965,7 +5955,7 @@ def _make_inv_ctx(agent_name, inv_id): mock_write_client.append_rows.reset_mock() inv2 = _make_inv_ctx("RootB", "inv-002") - cb2 = callback_context_lib.CallbackContext(inv2) + cb2 = CallbackContext(inv2) await bq_plugin_inst.before_run_callback(invocation_context=inv2) await bq_plugin_inst.before_agent_callback( agent=inv2.agent, callback_context=cb2 From 1c68990c154443ceed7171f9401cdfdea7a84555 Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Sat, 28 Feb 2026 10:29:39 -0800 Subject: [PATCH 15/15] refactor: convert relative imports to absolute in BQ analytics plugin Replace all relative imports (e.g. from ..agents.callback_context) with absolute imports (e.g. from google.adk.agents.callback_context) for clarity and consistency with the test file. Co-Authored-By: Claude Opus 4.6 --- .../plugins/bigquery_agent_analytics_plugin.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index beb26c24d7..fabe6f278c 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -39,6 +39,13 @@ import uuid import weakref +from google.adk.agents.callback_context import CallbackContext +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.tool_context import ToolContext +from google.adk.version import __version__ from google.api_core import client_options from google.api_core.exceptions import InternalServerError from google.api_core.exceptions import ServiceUnavailable @@ -55,16 +62,8 @@ from opentelemetry import trace import pyarrow as pa -from ..agents.callback_context import CallbackContext -from ..models.llm_request import LlmRequest -from ..models.llm_response import LlmResponse -from ..tools.base_tool import BaseTool -from ..tools.tool_context import ToolContext -from ..version import __version__ -from .base_plugin import BasePlugin - if TYPE_CHECKING: - from ..agents.invocation_context import InvocationContext + from google.adk.agents.invocation_context import InvocationContext logger: logging.Logger = logging.getLogger("google_adk." + __name__) tracer = trace.get_tracer(