Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Context manager for Cloud Spanner batched writes."""

import functools
from typing import List, Optional

Expand Down Expand Up @@ -242,6 +243,8 @@ def commit(
observability_options=getattr(database, "observability_options", None),
metadata=metadata,
) as span, MetricsCapture():
nth_request = getattr(database, "_next_nth_request", 0)
attempt = AtomicCounter(0)

def wrapped_method():
commit_request = CommitRequest(
Expand All @@ -256,8 +259,8 @@ def wrapped_method():
# should be increased. attempt can only be increased if
# we encounter UNAVAILABLE or INTERNAL.
call_metadata, error_augmenter = database.with_error_augmentation(
getattr(database, "_next_nth_request", 0),
1,
nth_request,
attempt.increment(),
Comment on lines -259 to +263
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the existing code was correct. There are two different things that can be retried in Spanner:

  1. Aborted transactions: When a read/write transaction is aborted, the entire transaction is retried. This should not cause `attempt` to be increased, even in this case, where the entire transaction is just a single Commit call.
  2. Unavailable: A single RPC can fail due to network errors, the server being temporarily down, etc. This is normally retried by Gax. In this case, only a single RPC (not the entire transaction) is retried. It is only in these cases that `attempt` should be increased.

metadata,
span,
)
Expand Down
62 changes: 43 additions & 19 deletions google/cloud/spanner_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
* a :class:`~google.cloud.spanner_v1.instance.Instance` owns a
:class:`~google.cloud.spanner_v1.database.Database`
"""

import grpc
import os
import logging
import warnings
import threading

from google.api_core.gapic_v1 import client_info
from google.auth.credentials import AnonymousCredentials
Expand Down Expand Up @@ -99,11 +101,50 @@ def _get_spanner_optimizer_statistics_package():

log = logging.getLogger(__name__)

# Module-level state for built-in metrics setup: the flag is set to True once
# _initialize_metrics() completes successfully, and the lock is held while
# that setup runs.
_metrics_monitor_initialized = False
_metrics_monitor_lock = threading.Lock()


def _get_spanner_enable_builtin_metrics_env():
    """Report whether the environment leaves built-in metrics enabled.

    Returns:
        bool: ``False`` only when the variable named by
        ``SPANNER_DISABLE_BUILTIN_METRICS_ENV_VAR`` is set to exactly
        ``"true"``; ``True`` for any other value or when it is unset.
    """
    disable_flag = os.getenv(SPANNER_DISABLE_BUILTIN_METRICS_ENV_VAR)
    return disable_flag != "true"


def _initialize_metrics(project, credentials):
    """Set up Spanner built-in metrics at most once per process.

    Installs an OpenTelemetry meter provider and instantiates the
    ``SpannerMetricsTracerFactory``. The module-level initialized flag is
    checked both before and after acquiring the module-level lock, so
    concurrent callers do not repeat a setup that already succeeded.

    When an emulator host is configured, a no-op meter provider is installed
    instead of one that exports to Cloud Monitoring.

    Any exception raised during setup is logged as a warning and swallowed;
    the initialized flag is then left unset, so a later call tries again.

    Args:
        project: Project ID passed to the Cloud Monitoring metrics exporter.
        credentials: Credentials passed to the Cloud Monitoring metrics
            exporter.
    """
    global _metrics_monitor_initialized

    # Fast path: skip the lock entirely once setup has succeeded.
    if _metrics_monitor_initialized:
        return

    with _metrics_monitor_lock:
        # Re-check under the lock: another thread may have finished setup
        # while this one was waiting.
        if _metrics_monitor_initialized:
            return

        provider = metrics.NoOpMeterProvider()
        try:
            if not _get_spanner_emulator_host():
                exporter = CloudMonitoringMetricsExporter(
                    project_id=project,
                    credentials=credentials,
                )
                reader = PeriodicExportingMetricReader(
                    exporter,
                    export_interval_millis=METRIC_EXPORT_INTERVAL_MS,
                )
                provider = MeterProvider(metric_readers=[reader])
            metrics.set_meter_provider(provider)
            SpannerMetricsTracerFactory()
            # Only mark as initialized after every step above succeeded.
            _metrics_monitor_initialized = True
        except Exception as exc:
            log.warning(
                "Failed to initialize Spanner built-in metrics. Error: %s",
                exc,
            )


class Client(ClientWithProject):
"""Client for interacting with Cloud Spanner API.

Expand Down Expand Up @@ -252,30 +293,13 @@ def __init__(
):
warnings.warn(_EMULATOR_HOST_HTTP_SCHEME)
# Check flag to enable Spanner builtin metrics
global _metrics_monitor_initialized
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this still needed here?

if (
_get_spanner_enable_builtin_metrics_env()
and not disable_builtin_metrics
and HAS_GOOGLE_CLOUD_MONITORING_INSTALLED
):
meter_provider = metrics.NoOpMeterProvider()
try:
if not _get_spanner_emulator_host():
meter_provider = MeterProvider(
metric_readers=[
PeriodicExportingMetricReader(
CloudMonitoringMetricsExporter(
project_id=project, credentials=credentials
),
export_interval_millis=METRIC_EXPORT_INTERVAL_MS,
),
]
)
metrics.set_meter_provider(meter_provider)
SpannerMetricsTracerFactory()
except Exception as e:
log.warning(
"Failed to initialize Spanner built-in metrics. Error: %s", e
)
_initialize_metrics(project, credentials)
else:
SpannerMetricsTracerFactory(enabled=False)

Expand Down
24 changes: 18 additions & 6 deletions google/cloud/spanner_v1/metrics/metrics_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,18 @@
from .spanner_metrics_tracer_factory import SpannerMetricsTracerFactory


from contextvars import Token
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: group this with the other import



class MetricsCapture:
"""Context manager for capturing metrics in Cloud Spanner operations.

This class provides a context manager interface to automatically handle
the start and completion of metrics tracing for a given operation.
"""

_token: Token
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this? Could we add a small comment for it?


def __enter__(self):
"""Enter the runtime context related to this object.

Expand All @@ -45,11 +50,13 @@ def __enter__(self):
return self

# Define a new metrics tracer for the new operation
SpannerMetricsTracerFactory.current_metrics_tracer = (
factory.create_metrics_tracer()
# Set the context var and keep the token for reset
tracer = factory.create_metrics_tracer()
self._token = SpannerMetricsTracerFactory._current_metrics_tracer_ctx.set(
tracer
)
if SpannerMetricsTracerFactory.current_metrics_tracer:
SpannerMetricsTracerFactory.current_metrics_tracer.record_operation_start()
if tracer:
tracer.record_operation_start()
return self

def __exit__(self, exc_type, exc_value, traceback):
Expand All @@ -70,6 +77,11 @@ def __exit__(self, exc_type, exc_value, traceback):
if not SpannerMetricsTracerFactory().enabled:
return False

if SpannerMetricsTracerFactory.current_metrics_tracer:
SpannerMetricsTracerFactory.current_metrics_tracer.record_operation_completion()
tracer = SpannerMetricsTracerFactory._current_metrics_tracer_ctx.get()
if tracer:
tracer.record_operation_completion()

# Reset the context var using the token
if getattr(self, "_token", None):
SpannerMetricsTracerFactory._current_metrics_tracer_ctx.reset(self._token)
return False # Propagate the exception if any
33 changes: 12 additions & 21 deletions google/cloud/spanner_v1/metrics/metrics_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,22 +97,17 @@ def _set_metrics_tracer_attributes(self, resources: Dict[str, str]) -> None:
Args:
resources (Dict[str, str]): A dictionary containing project, instance, and database information.
"""
if SpannerMetricsTracerFactory.current_metrics_tracer is None:
tracer = SpannerMetricsTracerFactory.get_current_tracer()
if tracer is None:
return

if resources:
if "project" in resources:
SpannerMetricsTracerFactory.current_metrics_tracer.set_project(
resources["project"]
)
tracer.set_project(resources["project"])
if "instance" in resources:
SpannerMetricsTracerFactory.current_metrics_tracer.set_instance(
resources["instance"]
)
tracer.set_instance(resources["instance"])
if "database" in resources:
SpannerMetricsTracerFactory.current_metrics_tracer.set_database(
resources["database"]
)
tracer.set_database(resources["database"])

def intercept(self, invoked_method, request_or_iterator, call_details):
"""Intercept gRPC calls to collect metrics.
Expand All @@ -126,10 +121,8 @@ def intercept(self, invoked_method, request_or_iterator, call_details):
The RPC response
"""
factory = SpannerMetricsTracerFactory()
if (
SpannerMetricsTracerFactory.current_metrics_tracer is None
or not factory.enabled
):
tracer = SpannerMetricsTracerFactory.get_current_tracer()
if tracer is None or not factory.enabled:
return invoked_method(request_or_iterator, call_details)

# Setup Metric Tracer attributes from call details
Expand All @@ -142,15 +135,13 @@ def intercept(self, invoked_method, request_or_iterator, call_details):
call_details.method, SPANNER_METHOD_PREFIX
).replace("/", ".")

SpannerMetricsTracerFactory.current_metrics_tracer.set_method(method_name)
SpannerMetricsTracerFactory.current_metrics_tracer.record_attempt_start()
tracer.set_method(method_name)
tracer.record_attempt_start()
response = invoked_method(request_or_iterator, call_details)
SpannerMetricsTracerFactory.current_metrics_tracer.record_attempt_completion()
tracer.record_attempt_completion()

# Process and send GFE metrics if enabled
if SpannerMetricsTracerFactory.current_metrics_tracer.gfe_enabled:
if tracer.gfe_enabled:
metadata = response.initial_metadata()
SpannerMetricsTracerFactory.current_metrics_trace.record_gfe_metrics(
metadata
)
tracer.record_gfe_metrics(metadata)
return response
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
import logging
from .constants import SPANNER_SERVICE_NAME
import contextvars

try:
import mmh3
Expand All @@ -43,7 +44,9 @@ class SpannerMetricsTracerFactory(MetricsTracerFactory):
"""A factory for creating SpannerMetricsTracer instances."""

_metrics_tracer_factory: "SpannerMetricsTracerFactory" = None
current_metrics_tracer: MetricsTracer = None
_current_metrics_tracer_ctx = contextvars.ContextVar(
"current_metrics_tracer", default=None
)

def __new__(
cls, enabled: bool = True, gfe_enabled: bool = False
Expand Down Expand Up @@ -80,10 +83,18 @@ def __new__(
cls._metrics_tracer_factory.gfe_enabled = gfe_enabled

if cls._metrics_tracer_factory.enabled != enabled:
cls._metrics_tracer_factory.enabeld = enabled
cls._metrics_tracer_factory.enabled = enabled

return cls._metrics_tracer_factory

@staticmethod
def get_current_tracer() -> MetricsTracer:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe add set_current_tracer and reset_current_tracer methods?

return SpannerMetricsTracerFactory._current_metrics_tracer_ctx.get()

@property
def current_metrics_tracer(self) -> MetricsTracer:
    """Return the metrics tracer held by the class-level context variable.

    Returns:
        MetricsTracer: The tracer set in the current context, or the
        context variable's default (``None``) when none has been set.
    """
    tracer_ctx = SpannerMetricsTracerFactory._current_metrics_tracer_ctx
    return tracer_ctx.get()

Comment on lines +94 to +97

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

You've introduced both a static method get_current_tracer() and a property current_metrics_tracer that do the same thing: retrieve the tracer from the context variable.

The property current_metrics_tracer is problematic because it replaces a class attribute with an instance property. Any code that previously accessed SpannerMetricsTracerFactory.current_metrics_tracer will now get a property object instead of the tracer, which is a breaking change and could lead to subtle bugs.

Since all new code in this PR uses the clear and unambiguous static method get_current_tracer(), I recommend removing the redundant and potentially confusing current_metrics_tracer property. This will make the API cleaner and prevent accidental misuse.

@staticmethod
def _generate_client_uid() -> str:
"""Generate a client UID in the form of uuidv4@pid@hostname.
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add copyright header

from unittest.mock import patch


@pytest.fixture(autouse=True)
def mock_periodic_exporting_metric_reader():
    """Patch PeriodicExportingMetricReader for every test in this directory.

    Both the name imported into ``google.cloud.spanner_v1.client`` and the
    original in ``opentelemetry.sdk.metrics.export`` are replaced with mocks
    for the duration of each test, so tests do not make real network calls.

    Yields:
        The mock that replaces the reader in the Spanner client module.
    """
    client_target = "google.cloud.spanner_v1.client.PeriodicExportingMetricReader"
    sdk_target = "opentelemetry.sdk.metrics.export.PeriodicExportingMetricReader"
    with patch(client_target) as mock_client_reader, patch(sdk_target):
        yield mock_client_reader
Loading
Loading