From 3c82d4e31dcacb8da9f4385aea5d4ac6f69a3a1e Mon Sep 17 00:00:00 2001
From: Emmanuel T Odeke
Date: Wed, 15 Jan 2025 01:21:50 -0800
Subject: [PATCH] More updates to monkey patched retryer

---
 google/cloud/spanner_v1/batch.py    |  46 ++++-----
 google/cloud/spanner_v1/database.py |  21 ++---
 google/cloud/spanner_v1/pool.py     |   1 +
 google/cloud/spanner_v1/snapshot.py | 140 ++++++++++++++----------
 tests/unit/test_snapshot.py         |   1 +
 5 files changed, 88 insertions(+), 121 deletions(-)

diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py
index 8ec84b75ab..094e2530db 100644
--- a/google/cloud/spanner_v1/batch.py
+++ b/google/cloud/spanner_v1/batch.py
@@ -231,18 +231,16 @@ def commit(
             attempt = AtomicCounter(0)
             next_nth_request = database._next_nth_request
 
-            def wrapped_method(*args, **kwargs):
-                all_metadata = database.metadata_with_request_id(
-                    next_nth_request,
-                    attempt.increment(),
-                    metadata,
-                )
-                method = functools.partial(
-                    api.commit,
-                    request=request,
-                    metadata=all_metadata,
-                )
-                return method(*args, **kwargs)
+            all_metadata = database.metadata_with_request_id(
+                next_nth_request,
+                attempt.increment(),
+                metadata,
+            )
+            method = functools.partial(
+                api.commit,
+                request=request,
+                metadata=all_metadata,
+            )
 
             deadline = time.time() + kwargs.get(
                 "timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
@@ -361,21 +359,17 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
             trace_attributes,
             observability_options=observability_options,
         ):
-            attempt = AtomicCounter(0)
             next_nth_request = database._next_nth_request
-
-            def wrapped_method(*args, **kwargs):
-                all_metadata = database.metadata_with_request_id(
-                    next_nth_request,
-                    attempt.increment(),
-                    metadata,
-                )
-                method = functools.partial(
-                    api.batch_write,
-                    request=request,
-                    metadata=all_metadata,
-                )
-                return method(*args, **kwargs)
+            all_metadata = database.metadata_with_request_id(
+                next_nth_request,
+                0,
+                metadata,
+            )
+            method = functools.partial(
+                api.batch_write,
+                request=request,
+                metadata=all_metadata,
+            )
 
             response = _retry(
                 method,
diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py
index f1a1327b0f..8a53a6b05f 100644
--- a/google/cloud/spanner_v1/database.py
+++ b/google/cloud/spanner_v1/database.py
@@ -795,12 +795,11 @@ def execute_pdml():
             ) as span:
                 with SessionCheckout(self._pool) as session:
                     add_span_event(span, "Starting BeginTransaction")
-                    begin_txn_attempt.increment()
                     txn = api.begin_transaction(
                         session=session.name,
                         options=txn_options,
                         metadata=self.metadata_with_request_id(
-                            begin_txn_nth_request, begin_txn_attempt.value, metadata
+                            begin_txn_nth_request, begin_txn_attempt.increment(), metadata
                         ),
                     )
 
@@ -815,23 +814,19 @@ def execute_pdml():
                         request_options=request_options,
                     )
 
-                    def wrapped_method(*args, **kwargs):
-                        partial_attempt.increment()
-                        method = functools.partial(
-                            api.execute_streaming_sql,
-                            metadata=self.metadata_with_request_id(
-                                partial_nth_request, partial_attempt.value, metadata
-                            ),
-                        )
-                        return method(*args, **kwargs)
+                    method = functools.partial(
+                        api.execute_streaming_sql,
+                        metadata=self.metadata_with_request_id(
+                            partial_nth_request, partial_attempt.increment(), metadata
+                        ),
+                    )
 
                     iterator = _restart_on_unavailable(
-                        method=wrapped_method,
+                        method=method,
                         trace_name="CloudSpanner.ExecuteStreamingSql",
                         request=request,
                         transaction_selector=txn_selector,
                         observability_options=self.observability_options,
-                        attempt=begin_txn_attempt,
                     )
 
                     result_set = StreamedResultSet(iterator)
diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py
index a6e0afd64c..401538b685 100644
--- a/google/cloud/spanner_v1/pool.py
+++ b/google/cloud/spanner_v1/pool.py
@@ -563,6 +563,7 @@ def bind(self, database):
             "CloudSpanner.PingingPool.BatchCreateSessions",
             observability_options=observability_options,
         ) as span:
+            created_session_count = 0
             while created_session_count < self.size:
                 nth_req = database._next_nth_request
 
diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py
index 6743c9bd79..f68d73df95 100644
--- a/google/cloud/spanner_v1/snapshot.py
+++ b/google/cloud/spanner_v1/snapshot.py
@@ -326,22 +326,17 @@ def read(
         )
         nth_request = getattr(database, "_next_nth_request", 0)
-        attempt = AtomicCounter(0)
-
-        def wrapped_restart(*args, **kwargs):
-            attempt.increment()
-            all_metadata = database.metadata_with_request_id(
-                nth_request, attempt.value, metadata
-            )
+        all_metadata = database.metadata_with_request_id(
+            nth_request, 1, metadata
+        )
-            restart = functools.partial(
-                api.streaming_read,
-                request=request,
-                metadata=all_metadata,
-                retry=retry,
-                timeout=timeout,
-            )
-            return restart(*args, **kwargs)
+        restart = functools.partial(
+            api.streaming_read,
+            request=request,
+            metadata=all_metadata,
+            retry=retry,
+            timeout=timeout,
+        )
 
         trace_attributes = {"table_id": table, "columns": columns}
         observability_options = getattr(database, "observability_options", None)
@@ -350,14 +345,13 @@ def wrapped_restart(*args, **kwargs):
             # lock is added to handle the inline begin for first rpc
             with self._lock:
                 iterator = _restart_on_unavailable(
-                    wrapped_restart,
+                    restart,
                     request,
                     f"CloudSpanner.{type(self).__name__}.read",
                     self._session,
                     trace_attributes,
                     transaction=self,
                     observability_options=observability_options,
-                    attempt=attempt,
                 )
                 self._read_request_count += 1
                 if self._multi_use:
@@ -373,14 +367,13 @@ def wrapped_restart(*args, **kwargs):
                     )
         else:
             iterator = _restart_on_unavailable(
-                wrapped_restart,
+                restart,
                 request,
                 f"CloudSpanner.{type(self).__name__}.read",
                 self._session,
                 trace_attributes,
                 transaction=self,
                 observability_options=observability_options,
-                attempt=attempt,
             )
             self._read_request_count += 1
 
@@ -555,20 +548,18 @@ def execute_sql(
         )
         nth_request = getattr(database, "_next_nth_request", 0)
-        attempt = AtomicCounter(0)
-
-        def wrapped_restart(*args, **kwargs):
-            restart = functools.partial(
-                api.execute_streaming_sql,
-                request=request,
-                metadata=database.metadata_with_request_id(
-                    nth_request, attempt.increment(), metadata
-                ),
-                retry=retry,
-                timeout=timeout,
-            )
-
-            return restart(*args, **kwargs)
+        if not isinstance(nth_request, int):
+            raise Exception(f"failed to get an integer back: {nth_request}")
+
+        restart = functools.partial(
+            api.execute_streaming_sql,
+            request=request,
+            metadata=database.metadata_with_request_id(
+                nth_request, 1, metadata
+            ),
+            retry=retry,
+            timeout=timeout,
+        )
 
         trace_attributes = {"db.statement": sql}
         observability_options = getattr(database, "observability_options", None)
@@ -577,7 +568,7 @@ def wrapped_restart(*args, **kwargs):
             # lock is added to handle the inline begin for first rpc
             with self._lock:
                 return self._get_streamed_result_set(
-                    wrapped_restart,
+                    restart,
                     request,
                     trace_attributes,
                     column_info,
@@ -586,7 +577,7 @@ def wrapped_restart(*args, **kwargs):
                 )
         else:
             return self._get_streamed_result_set(
-                wrapped_restart,
+                restart,
                 request,
                 trace_attributes,
                 column_info,
@@ -714,24 +705,19 @@ def partition_read(
             observability_options=getattr(database, "observability_options", None),
         ):
             nth_request = getattr(database, "_next_nth_request", 0)
-            attempt = AtomicCounter(0)
-
-            def wrapped_method(*args, **kwargs):
-                attempt.increment()
-                all_metadata = database.metadata_with_request_id(
-                    nth_request, attempt.value, metadata
-                )
-                method = functools.partial(
-                    api.partition_read,
-                    request=request,
-                    metadata=all_metadata,
-                    retry=retry,
-                    timeout=timeout,
-                )
-                return method(*args, **kwargs)
+            all_metadata = database.metadata_with_request_id(
+                nth_request, 1, metadata
+            )
+            method = functools.partial(
+                api.partition_read,
+                request=request,
+                metadata=all_metadata,
+                retry=retry,
+                timeout=timeout,
+            )
 
             response = _retry(
-                wrapped_method,
+                method,
                 allowed_exceptions={InternalServerError: _check_rst_stream_error},
             )
@@ -827,24 +813,19 @@ def partition_query(
             observability_options=getattr(database, "observability_options", None),
         ):
             nth_request = getattr(database, "_next_nth_request", 0)
-            attempt = AtomicCounter(0)
-
-            def wrapped_method(*args, **kwargs):
-                attempt.increment()
-                all_metadata = database.metadata_with_request_id(
-                    nth_request, attempt.value, metadata
-                )
-                method = functools.partial(
-                    api.partition_query,
-                    request=request,
-                    metadata=all_metadata,
-                    retry=retry,
-                    timeout=timeout,
-                )
-                return method(*args, **kwargs)
+            all_metadata = database.metadata_with_request_id(
+                nth_request, 1, metadata
+            )
+            method = functools.partial(
+                api.partition_query,
+                request=request,
+                metadata=all_metadata,
+                retry=retry,
+                timeout=timeout,
+            )
 
             response = _retry(
-                wrapped_method,
+                method,
                 allowed_exceptions={InternalServerError: _check_rst_stream_error},
             )
@@ -983,23 +964,18 @@ def begin(self):
             observability_options=getattr(database, "observability_options", None),
         ):
             nth_request = getattr(database, "_next_nth_request", 0)
-            attempt = AtomicCounter(0)
-
-            def wrapped_method(*args, **kwargs):
-                attempt.increment()
-                all_metadata = database.metadata_with_request_id(
-                    nth_request, attempt.value, metadata
-                )
-                method = functools.partial(
-                    api.begin_transaction,
-                    session=self._session.name,
-                    options=txn_selector.begin,
-                    metadata=all_metadata,
-                )
-                return method(*args, **kwargs)
+            all_metadata = database.metadata_with_request_id(
+                nth_request, 1, metadata
+            )
+            method = functools.partial(
+                api.begin_transaction,
+                session=self._session.name,
+                options=txn_selector.begin,
+                metadata=all_metadata,
+            )
 
             response = _retry(
-                wrapped_method,
+                method,
                 allowed_exceptions={InternalServerError: _check_rst_stream_error},
             )
             self._transaction_id = response.id
diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py
index e07c77d09f..1c1539ef84 100644
--- a/tests/unit/test_snapshot.py
+++ b/tests/unit/test_snapshot.py
@@ -1878,6 +1878,7 @@ def __init__(self, directed_read_options=None):
     def observability_options(self):
         return dict(db_name=self.name)
 
+    @property
     def _next_nth_request(self):
         return self._instance._client._next_nth_request
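
Reviewer note: below is a minimal, self-contained sketch of the behavior difference the hunks above encode. The names `TransientError`, `metadata_for_attempt`, `retry_call`, and `fake_rpc` are hypothetical stand-ins, not Spanner client code; they only approximate how `metadata_with_request_id` and the monkey patched retryer interact. Before this change, a per-attempt wrapper rebuilt the metadata (bumping the attempt counter) on every retry; after it, the `functools.partial` is built once, so every retried call reuses the same attempt number.

import functools


class TransientError(Exception):
    """Stands in for a retryable error such as ServiceUnavailable."""


def metadata_for_attempt(nth_request, attempt):
    # Hypothetical stand-in for Database.metadata_with_request_id().
    return [("request-id", f"{nth_request}.{attempt}")]


def retry_call(method, max_attempts=3):
    # Toy retryer: re-invokes `method` while it raises TransientError.
    for _ in range(max_attempts - 1):
        try:
            return method()
        except TransientError:
            continue
    return method()


calls = []


def fake_rpc(metadata=None):
    # Records the metadata of each attempt and fails once before succeeding.
    calls.append(metadata)
    if len(calls) == 1:
        raise TransientError("transient failure")
    return "ok"


# Old shape (removed by this patch): a wrapper rebuilds the metadata on each
# attempt, so every retry carries a fresh attempt number.
attempt = 0


def wrapped_method():
    global attempt
    attempt += 1
    return fake_rpc(metadata=metadata_for_attempt(nth_request=7, attempt=attempt))


retry_call(wrapped_method)
print(calls)  # [[('request-id', '7.1')], [('request-id', '7.2')]]

# New shape (introduced by this patch): the partial is built once, so every
# retry reuses the same metadata and the attempt number stays fixed.
calls.clear()
method = functools.partial(
    fake_rpc, metadata=metadata_for_attempt(nth_request=8, attempt=1)
)
retry_call(method)
print(calls)  # [[('request-id', '8.1')], [('request-id', '8.1')]]

Run as-is, the first print shows two different attempt numbers while the second shows the same attempt number twice, which is the distinction between the removed wrapped methods and the partials that replace them.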