diff --git a/packages/google-cloud-spanner/google/cloud/spanner_dbapi/connection.py b/packages/google-cloud-spanner/google/cloud/spanner_dbapi/connection.py index f21c1fea05fe..318ed281adce 100644 --- a/packages/google-cloud-spanner/google/cloud/spanner_dbapi/connection.py +++ b/packages/google-cloud-spanner/google/cloud/spanner_dbapi/connection.py @@ -91,10 +91,28 @@ class Connection: the read-only transaction is semantically the same, and only indicates that the read-only transaction should end a that a new one should be started when the next statement is executed. + :type retry_aborts_internally: bool + :param retry_aborts_internally: + (Optional) Default True. When True, the connection will automatically retry aborted + transactions by replaying all statements and validating checksums of read results. + This is the recommended setting for interactive use and ORMs like Django that build + transactions incrementally through individual cursor.execute() calls. + + Set to False when the application already implements its own transaction retry logic + (e.g. by wrapping the entire transaction in a callable and re-invoking it on abort, + similar to ``Session.run_in_transaction``). In this mode, ``Aborted`` errors from + ``commit()`` will be raised directly as ``RetryAborted`` without entering the + internal statement-replay loop. This avoids nested retry loops and the associated + contention amplification under concurrent writes. + + This is equivalent to ``RETRY_ABORTS_INTERNALLY`` in the Spanner JDBC driver. + **kwargs: Initial value for connection variables. """ - def __init__(self, instance, database=None, read_only=False, **kwargs): + def __init__( + self, instance, database=None, read_only=False, retry_aborts_internally=True, **kwargs + ): self._instance = instance self._database = database self._ddl_statements = [] @@ -110,6 +128,7 @@ def __init__(self, instance, database=None, read_only=False, **kwargs): # connection close self._own_pool = True self._read_only = read_only + self._retry_aborts_internally = retry_aborts_internally self._staleness = None self.request_priority = None self._transaction_begin_marked = False @@ -248,6 +267,33 @@ def read_only(self, value): ) self._read_only = value + @property + def retry_aborts_internally(self): + """Flag: whether the connection retries aborted transactions internally. + + Returns: + bool: + True if the connection will retry aborted transactions using + statement replay with checksum validation (default). False if + aborted transactions will raise ``RetryAborted`` directly. + """ + return self._retry_aborts_internally + + @retry_aborts_internally.setter + def retry_aborts_internally(self, value): + """``retry_aborts_internally`` flag setter. + + Args: + value (bool): True to enable internal retry (default), False to disable. + """ + if self._spanner_transaction_started: + raise ValueError( + "retry_aborts_internally can't be changed while a transaction " + "is in progress. Commit or rollback the current transaction " + "and try again." + ) + self._retry_aborts_internally = value + @property def request_options(self): """Options for the next SQL operations. @@ -491,9 +537,12 @@ def commit(self): try: if self._spanner_transaction_started and not self._read_only: self._transaction.commit() - except Aborted: - self._transaction_helper.retry_transaction() - self.commit() + except Aborted as exc: + if self._retry_aborts_internally: + self._transaction_helper.retry_transaction() + self.commit() + else: + raise RetryAborted(str(exc)) from exc finally: self._reset_post_commit_or_rollback() @@ -747,6 +796,7 @@ def connect( ca_certificate=None, client_certificate=None, client_key=None, + retry_aborts_internally=True, **kwargs, ): """Creates a connection to a Google Cloud Spanner database. @@ -822,6 +872,13 @@ def connect( :param client_key: (Optional) The path to the client key file used for mTLS connection. This is intended only for experimental host spanner endpoints. This is mandatory if the experimental_host requires an mTLS connection. + + :type retry_aborts_internally: bool + :param retry_aborts_internally: + (Optional) Default True. When True, the connection will automatically retry + aborted transactions internally by replaying all statements and validating + checksums. Set to False when the application manages its own transaction retry + logic. See ``Connection.retry_aborts_internally`` for details. """ if client is None: client_info = ClientInfo( @@ -866,7 +923,12 @@ def connect( database = instance.database( database_id, pool=pool, database_role=database_role, logger=logger ) - conn = Connection(instance, database, **kwargs) + conn = Connection( + instance, + database, + retry_aborts_internally=retry_aborts_internally, + **kwargs, + ) if pool is not None: conn._own_pool = False diff --git a/packages/google-cloud-spanner/tests/unit/spanner_dbapi/test_connection.py b/packages/google-cloud-spanner/tests/unit/spanner_dbapi/test_connection.py index ae90d14de29d..09e6c6b4f6ef 100644 --- a/packages/google-cloud-spanner/tests/unit/spanner_dbapi/test_connection.py +++ b/packages/google-cloud-spanner/tests/unit/spanner_dbapi/test_connection.py @@ -355,6 +355,61 @@ def test_commit_database_error(self): with pytest.raises(ValueError): connection.commit() + def test_retry_aborts_internally_defaults_true(self): + connection = self._make_connection() + self.assertTrue(connection.retry_aborts_internally) + + def test_retry_aborts_internally_set_false(self): + connection = self._make_connection(retry_aborts_internally=False) + self.assertFalse(connection.retry_aborts_internally) + + def test_retry_aborts_internally_setter(self): + connection = self._make_connection() + connection.retry_aborts_internally = False + self.assertFalse(connection.retry_aborts_internally) + + def test_retry_aborts_internally_setter_while_transaction_active(self): + connection = self._make_connection() + connection._spanner_transaction_started = True + with pytest.raises(ValueError, match="retry_aborts_internally can't be changed"): + connection.retry_aborts_internally = False + + def test_commit_retries_internally_when_enabled(self): + from google.api_core.exceptions import Aborted + + self._under_test._transaction = mock_transaction = mock.MagicMock() + self._under_test._spanner_transaction_started = True + mock_transaction.commit = mock.MagicMock( + side_effect=[Aborted("aborted"), None] + ) + self._under_test._retry_aborts_internally = True + + with mock.patch.object( + self._under_test._transaction_helper, "retry_transaction" + ) as mock_retry, mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ): + self._under_test.commit() + + mock_retry.assert_called_once() + + def test_commit_raises_retry_aborted_when_internal_retry_disabled(self): + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.exceptions import RetryAborted + + self._under_test._transaction = mock_transaction = mock.MagicMock() + self._under_test._spanner_transaction_started = True + mock_transaction.commit = mock.MagicMock( + side_effect=Aborted("aborted") + ) + self._under_test._retry_aborts_internally = False + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ): + with pytest.raises(RetryAborted, match="aborted"): + self._under_test.commit() + @mock.patch.object(warnings, "warn") def test_rollback_spanner_transaction_not_started(self, mock_warn): self._under_test._spanner_transaction_started = False @@ -882,6 +937,31 @@ def test_connection_wo_database(self): ) self.assertTrue(connection.database is None) + def test_connect_retry_aborts_internally_default(self): + from google.cloud.spanner_dbapi import connect + + connection = connect( + "test-instance", + "test-database", + project="test-project", + credentials=AnonymousCredentials(), + client_options={"api_endpoint": "none"}, + ) + self.assertTrue(connection.retry_aborts_internally) + + def test_connect_retry_aborts_internally_false(self): + from google.cloud.spanner_dbapi import connect + + connection = connect( + "test-instance", + "test-database", + project="test-project", + credentials=AnonymousCredentials(), + client_options={"api_endpoint": "none"}, + retry_aborts_internally=False, + ) + self.assertFalse(connection.retry_aborts_internally) + def exit_ctx_func(self, exc_type, exc_value, traceback): """Context __exit__ method mock."""