Skip to content

Commit

Permalink
More updates to monkey patched retryer
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Jan 15, 2025
1 parent 17a5896 commit 3c82d4e
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 121 deletions.
46 changes: 20 additions & 26 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,18 +231,16 @@ def commit(
attempt = AtomicCounter(0)
next_nth_request = database._next_nth_request

def wrapped_method(*args, **kwargs):
all_metadata = database.metadata_with_request_id(
next_nth_request,
attempt.increment(),
metadata,
)
method = functools.partial(
api.commit,
request=request,
metadata=all_metadata,
)
return method(*args, **kwargs)
all_metadata = database.metadata_with_request_id(
next_nth_request,
attempt.increment(),
metadata,
)
method = functools.partial(
api.commit,
request=request,
metadata=all_metadata,
)

deadline = time.time() + kwargs.get(
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
Expand Down Expand Up @@ -361,21 +359,17 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
trace_attributes,
observability_options=observability_options,
):
attempt = AtomicCounter(0)
next_nth_request = database._next_nth_request

def wrapped_method(*args, **kwargs):
all_metadata = database.metadata_with_request_id(
next_nth_request,
attempt.increment(),
metadata,
)
method = functools.partial(
api.batch_write,
request=request,
metadata=all_metadata,
)
return method(*args, **kwargs)
all_metadata = database.metadata_with_request_id(
next_nth_request,
0,
metadata,
)
method = functools.partial(
api.batch_write,
request=request,
metadata=all_metadata,
)

response = _retry(
method,
Expand Down
21 changes: 8 additions & 13 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,12 +795,11 @@ def execute_pdml():
) as span:
with SessionCheckout(self._pool) as session:
add_span_event(span, "Starting BeginTransaction")
begin_txn_attempt.increment()
txn = api.begin_transaction(
session=session.name,
options=txn_options,
metadata=self.metadata_with_request_id(
begin_txn_nth_request, begin_txn_attempt.value, metadata
begin_txn_nth_request, begin_txn_attempt.increment(), metadata
),
)

Expand All @@ -815,23 +814,19 @@ def execute_pdml():
request_options=request_options,
)

def wrapped_method(*args, **kwargs):
partial_attempt.increment()
method = functools.partial(
api.execute_streaming_sql,
metadata=self.metadata_with_request_id(
partial_nth_request, partial_attempt.value, metadata
),
)
return method(*args, **kwargs)
method = functools.partial(
api.execute_streaming_sql,
metadata=self.metadata_with_request_id(
partial_nth_request, partial_attempt.increment(), metadata
),
)

iterator = _restart_on_unavailable(
method=wrapped_method,
method=method,
trace_name="CloudSpanner.ExecuteStreamingSql",
request=request,
transaction_selector=txn_selector,
observability_options=self.observability_options,
attempt=begin_txn_attempt,
)

result_set = StreamedResultSet(iterator)
Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ def bind(self, database):
"CloudSpanner.PingingPool.BatchCreateSessions",
observability_options=observability_options,
) as span:
created_session_count = 0
while created_session_count < self.size:
nth_req = database._next_nth_request

Expand Down
140 changes: 58 additions & 82 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,22 +326,17 @@ def read(
)

nth_request = getattr(database, "_next_nth_request", 0)
attempt = AtomicCounter(0)

def wrapped_restart(*args, **kwargs):
attempt.increment()
all_metadata = database.metadata_with_request_id(
nth_request, attempt.value, metadata
)
all_metadata = database.metadata_with_request_id(
nth_request, 1, metadata
)

restart = functools.partial(
api.streaming_read,
request=request,
metadata=all_metadata,
retry=retry,
timeout=timeout,
)
return restart(*args, **kwargs)
restart = functools.partial(
api.streaming_read,
request=request,
metadata=all_metadata,
retry=retry,
timeout=timeout,
)

trace_attributes = {"table_id": table, "columns": columns}
observability_options = getattr(database, "observability_options", None)
Expand All @@ -350,14 +345,13 @@ def wrapped_restart(*args, **kwargs):
# lock is added to handle the inline begin for first rpc
with self._lock:
iterator = _restart_on_unavailable(
wrapped_restart,
restart,
request,
f"CloudSpanner.{type(self).__name__}.read",
self._session,
trace_attributes,
transaction=self,
observability_options=observability_options,
attempt=attempt,
)
self._read_request_count += 1
if self._multi_use:
Expand All @@ -373,14 +367,13 @@ def wrapped_restart(*args, **kwargs):
)
else:
iterator = _restart_on_unavailable(
wrapped_restart,
restart,
request,
f"CloudSpanner.{type(self).__name__}.read",
self._session,
trace_attributes,
transaction=self,
observability_options=observability_options,
attempt=attempt,
)

self._read_request_count += 1
Expand Down Expand Up @@ -555,20 +548,18 @@ def execute_sql(
)

nth_request = getattr(database, "_next_nth_request", 0)
attempt = AtomicCounter(0)

def wrapped_restart(*args, **kwargs):
restart = functools.partial(
api.execute_streaming_sql,
request=request,
metadata=database.metadata_with_request_id(
nth_request, attempt.increment(), metadata
),
retry=retry,
timeout=timeout,
)

return restart(*args, **kwargs)
if not isinstance(nth_request, int):
raise Exception(f"failed to get an integer back: {nth_request}")

restart = functools.partial(
api.execute_streaming_sql,
request=request,
metadata=database.metadata_with_request_id(
nth_request, 1, metadata
),
retry=retry,
timeout=timeout,
)

trace_attributes = {"db.statement": sql}
observability_options = getattr(database, "observability_options", None)
Expand All @@ -577,7 +568,7 @@ def wrapped_restart(*args, **kwargs):
# lock is added to handle the inline begin for first rpc
with self._lock:
return self._get_streamed_result_set(
wrapped_restart,
restart,
request,
trace_attributes,
column_info,
Expand All @@ -586,7 +577,7 @@ def wrapped_restart(*args, **kwargs):
)
else:
return self._get_streamed_result_set(
wrapped_restart,
restart,
request,
trace_attributes,
column_info,
Expand Down Expand Up @@ -714,24 +705,19 @@ def partition_read(
observability_options=getattr(database, "observability_options", None),
):
nth_request = getattr(database, "_next_nth_request", 0)
attempt = AtomicCounter(0)

def wrapped_method(*args, **kwargs):
attempt.increment()
all_metadata = database.metadata_with_request_id(
nth_request, attempt.value, metadata
)
method = functools.partial(
api.partition_read,
request=request,
metadata=all_metadata,
retry=retry,
timeout=timeout,
)
return method(*args, **kwargs)
all_metadata = database.metadata_with_request_id(
nth_request, 1, metadata
)
method = functools.partial(
api.partition_read,
request=request,
metadata=all_metadata,
retry=retry,
timeout=timeout,
)

response = _retry(
wrapped_method,
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)

Expand Down Expand Up @@ -827,24 +813,19 @@ def partition_query(
observability_options=getattr(database, "observability_options", None),
):
nth_request = getattr(database, "_next_nth_request", 0)
attempt = AtomicCounter(0)

def wrapped_method(*args, **kwargs):
attempt.increment()
all_metadata = database.metadata_with_request_id(
nth_request, attempt.value, metadata
)
method = functools.partial(
api.partition_query,
request=request,
metadata=all_metadata,
retry=retry,
timeout=timeout,
)
return method(*args, **kwargs)
all_metadata = database.metadata_with_request_id(
nth_request, 1, metadata
)
method = functools.partial(
api.partition_query,
request=request,
metadata=all_metadata,
retry=retry,
timeout=timeout,
)

response = _retry(
wrapped_method,
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)

Expand Down Expand Up @@ -983,23 +964,18 @@ def begin(self):
observability_options=getattr(database, "observability_options", None),
):
nth_request = getattr(database, "_next_nth_request", 0)
attempt = AtomicCounter(0)

def wrapped_method(*args, **kwargs):
attempt.increment()
all_metadata = database.metadata_with_request_id(
nth_request, attempt.value, metadata
)
method = functools.partial(
api.begin_transaction,
session=self._session.name,
options=txn_selector.begin,
metadata=all_metadata,
)
return method(*args, **kwargs)
all_metadata = database.metadata_with_request_id(
nth_request, 1, metadata
)
method = functools.partial(
api.begin_transaction,
session=self._session.name,
options=txn_selector.begin,
metadata=all_metadata,
)

response = _retry(
wrapped_method,
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self._transaction_id = response.id
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,6 +1878,7 @@ def __init__(self, directed_read_options=None):
def observability_options(self):
return dict(db_name=self.name)

@property
def _next_nth_request(self):
    """Next request ordinal for request-id metadata, delegated to the client.

    NOTE(review): converted to a read-only property so call sites can use
    ``database._next_nth_request`` as an attribute (matching how the
    production code reads it) — confirm the real Database class exposes
    this as a property too.
    """
    # Delegates to the fake client held by the fake instance; presumably
    # the client owns the monotonically increasing request counter — verify
    # against the client test double's implementation.
    return self._instance._client._next_nth_request

Expand Down

0 comments on commit 3c82d4e

Please sign in to comment.