diff --git a/google/cloud/spanner_v1/metrics/metrics_interceptor.py b/google/cloud/spanner_v1/metrics/metrics_interceptor.py
index 4b55056dab..a3a3d1de60 100644
--- a/google/cloud/spanner_v1/metrics/metrics_interceptor.py
+++ b/google/cloud/spanner_v1/metrics/metrics_interceptor.py
@@ -126,10 +126,8 @@ def intercept(self, invoked_method, request_or_iterator, call_details):
             The RPC response
         """
         factory = SpannerMetricsTracerFactory()
-        if (
-            SpannerMetricsTracerFactory.current_metrics_tracer is None
-            or not factory.enabled
-        ):
+        tracer = SpannerMetricsTracerFactory.current_metrics_tracer
+        if tracer is None or not factory.enabled:
             return invoked_method(request_or_iterator, call_details)
 
         # Setup Metric Tracer attributes from call details
@@ -142,15 +140,13 @@ def intercept(self, invoked_method, request_or_iterator, call_details):
             call_details.method, SPANNER_METHOD_PREFIX
         ).replace("/", ".")
 
-        SpannerMetricsTracerFactory.current_metrics_tracer.set_method(method_name)
-        SpannerMetricsTracerFactory.current_metrics_tracer.record_attempt_start()
+        tracer.set_method(method_name)
+        tracer.record_attempt_start()
         response = invoked_method(request_or_iterator, call_details)
-        SpannerMetricsTracerFactory.current_metrics_tracer.record_attempt_completion()
+        tracer.record_attempt_completion()
 
         # Process and send GFE metrics if enabled
-        if SpannerMetricsTracerFactory.current_metrics_tracer.gfe_enabled:
+        if tracer.gfe_enabled:
             metadata = response.initial_metadata()
-            SpannerMetricsTracerFactory.current_metrics_trace.record_gfe_metrics(
-                metadata
-            )
+            tracer.record_gfe_metrics(metadata)
         return response
diff --git a/tests/unit/test_metrics_interceptor.py b/tests/unit/test_metrics_interceptor.py
index e32003537f..dfe26a34f7 100644
--- a/tests/unit/test_metrics_interceptor.py
+++ b/tests/unit/test_metrics_interceptor.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import pytest
+import threading
 from google.cloud.spanner_v1.metrics.metrics_interceptor import MetricsInterceptor
 from google.cloud.spanner_v1.metrics.spanner_metrics_tracer_factory import (
     SpannerMetricsTracerFactory,
@@ -102,6 +103,66 @@ def test_intercept_with_tracer(interceptor):
     mock_invoked_method.assert_called_once_with("request", call_details)
 
 
+def test_intercept_thread_safety(interceptor):
+    """Regression test: intercept() keeps using the tracer it started with.
+
+    Thread A begins an intercepted call while tracer A is installed; thread B
+    replaces ``SpannerMetricsTracerFactory.current_metrics_tracer`` while that
+    call is still in flight. The attempt must then be completed on tracer A.
+    """
+    tracer_a = MagicMock()
+    tracer_a.gfe_enabled = False
+    tracer_b = MagicMock()
+    tracer_b.gfe_enabled = False
+
+    call_details = MagicMock(
+        method="spanner.Commit",
+        metadata=[],
+    )
+
+    # Events force the interleaving deterministically instead of relying on sleeps.
+    rpc_in_flight = threading.Event()
+    tracer_swapped = threading.Event()
+    handshake_timeout = 5  # seconds; only elapses if the handshake never happens
+
+    def mock_invoked_method(*args, **kwargs):
+        # Keep the fake RPC open until thread B has replaced the tracer.
+        rpc_in_flight.set()
+        tracer_swapped.wait(handshake_timeout)
+        return MagicMock()
+
+    def thread_a_func():
+        SpannerMetricsTracerFactory.current_metrics_tracer = tracer_a
+        interceptor.intercept(mock_invoked_method, None, call_details)
+
+    def thread_b_func():
+        # Swap only once thread A is inside the fake RPC.
+        if rpc_in_flight.wait(handshake_timeout):
+            SpannerMetricsTracerFactory.current_metrics_tracer = tracer_b
+            tracer_swapped.set()
+
+    # The tracer is assigned on the class itself; restore it for later tests.
+    previous_tracer = SpannerMetricsTracerFactory.current_metrics_tracer
+    try:
+        t1 = threading.Thread(target=thread_a_func)
+        t2 = threading.Thread(target=thread_b_func)
+
+        t1.start()
+        t2.start()
+
+        t1.join()
+        t2.join()
+    finally:
+        SpannerMetricsTracerFactory.current_metrics_tracer = previous_tracer
+
+    # Guard against a vacuous pass: the swap must have happened mid-call.
+    assert tracer_swapped.is_set()
+
+    # Thread A started with tracer A, so it must finish with tracer A.
+    tracer_a.record_attempt_completion.assert_called_once()
+    tracer_b.record_attempt_completion.assert_not_called()
+
+
 class MockMetricTracer:
     def __init__(self):
         self.project = None