Implement interceptor to wrap and increase x-goog-spanner-request-id attempts per retry

This monkey patches SpannerClient methods with an interceptor that
increments the x-goog-spanner-request-id attempt count on each retry. The
precondition is that callers must pass in an attempt value of 0, so that
each pass through the interceptor correctly increments the attempt field's
value.
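
For illustration, a minimal sketch of the increment, assuming only that the
attempt number is the trailing dot-separated component of the header value
(the leading components shown here are hypothetical):

    # The interceptor treats the request id as a dot-separated string whose
    # last component is the attempt number, and bumps it on every pass.
    request_id = "1.cafe.1.1.5.0"  # hypothetical id ending in attempt == 0
    parts = request_id.split(".")
    parts[-1] = str(int(parts[-1]) + 1)
    assert ".".join(parts) == "1.cafe.1.1.5.1"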
odeke-em committed Jan 15, 2025
1 parent 4d13c9b commit f92c51d
Showing 6 changed files with 160 additions and 147 deletions.
33 changes: 32 additions & 1 deletion google/cloud/spanner_v1/_helpers.py
@@ -32,10 +32,11 @@
from google.cloud.spanner_v1 import TypeCode
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import JsonObject
from google.cloud.spanner_v1.request_id_header import with_request_id
from google.cloud.spanner_v1.request_id_header import REQ_ID_HEADER_KEY, with_request_id
from google.rpc.error_details_pb2 import RetryInfo

import random
from typing import Callable

# Validation error messages
NUMERIC_MAX_SCALE_ERR_MSG = (
@@ -648,3 +649,33 @@ def reset(self):

def _metadata_with_request_id(*args, **kwargs):
return with_request_id(*args, **kwargs)


class InterceptingHeaderInjector:
    def __init__(self, original_callable: Callable):
        self._original_callable = original_callable

    def __call__(self, *args, **kwargs):
        metadata = kwargs.get("metadata", [])
        # Find all the headers that match the x-goog-spanner-request-id
        # header and on each retry increment the attempt number.
        all_metadata = []
        for key, value in metadata:
            if key == REQ_ID_HEADER_KEY:
                # Increment the attempt number: the last dot-separated
                # component of the request-id value.
                splits = value.split(".")
                splits[-1] = str(int(splits[-1]) + 1)
                value = ".".join(splits)

            all_metadata.append((key, value))

        kwargs["metadata"] = all_metadata
        return self._original_callable(*args, **kwargs)
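
A minimal usage sketch, with a hypothetical callable and header value
standing in for a real RPC:

    def fake_rpc(request=None, metadata=None):
        return metadata  # hypothetical stand-in for a SpannerClient RPC

    wrapped = InterceptingHeaderInjector(fake_rpc)
    out = wrapped(request=None, metadata=[(REQ_ID_HEADER_KEY, "1.0.0.0.7.0")])
    # The trailing attempt component has been bumped from ".0" to ".1".
    assert out == [(REQ_ID_HEADER_KEY, "1.0.0.0.7.1")]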
46 changes: 20 additions & 26 deletions google/cloud/spanner_v1/batch.py
@@ -231,18 +231,16 @@ def commit(
attempt = AtomicCounter(0)
next_nth_request = database._next_nth_request

def wrapped_method(*args, **kwargs):
all_metadata = database.metadata_with_request_id(
next_nth_request,
attempt.increment(),
metadata,
)
method = functools.partial(
api.commit,
request=request,
metadata=all_metadata,
)
return method(*args, **kwargs)
all_metadata = database.metadata_with_request_id(
next_nth_request,
attempt.increment(),
metadata,
)
method = functools.partial(
api.commit,
request=request,
metadata=all_metadata,
)

deadline = time.time() + kwargs.get(
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
@@ -361,21 +359,17 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
trace_attributes,
observability_options=observability_options,
):
attempt = AtomicCounter(0)
next_nth_request = database._next_nth_request

def wrapped_method(*args, **kwargs):
all_metadata = database.metadata_with_request_id(
next_nth_request,
attempt.increment(),
metadata,
)
method = functools.partial(
api.batch_write,
request=request,
metadata=all_metadata,
)
return method(*args, **kwargs)
all_metadata = database.metadata_with_request_id(
next_nth_request,
0,
metadata,
)
method = functools.partial(
api.batch_write,
request=request,
metadata=all_metadata,
)

response = _retry(
method,
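The attempt bookkeeping above relies on AtomicCounter from _helpers.py;
roughly, it behaves like this simplified sketch (an assumption about the
real helper's semantics, not its actual code):

    import threading

    class AtomicCounter:
        # Simplified sketch: a thread-safe counter whose increment()
        # returns the post-increment value, so a counter started at 0
        # yields 1 on its first increment() call.
        def __init__(self, start=0):
            self._lock = threading.Lock()
            self._value = start

        @property
        def value(self):
            with self._lock:
                return self._value

        def increment(self, step=1):
            with self._lock:
                self._value += step
                return self._value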
74 changes: 48 additions & 26 deletions google/cloud/spanner_v1/database.py
@@ -54,6 +54,7 @@
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
_metadata_with_request_id,
InterceptingHeaderInjector,
)
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.batch import MutationGroups
@@ -427,6 +428,43 @@ def logger(self):

@property
def spanner_api(self):
"""Helper for session-related API calls."""
        api = self._generate_spanner_api()
        if not api:
            return api

        # Now wrap each method's __call__ with our header-injecting wrapper.
        # This is how to deal with the fact that there are no proper gRPC
        # interceptors for Python, hence the remedy is to replace the
        # callables with our custom wrapper.
        for attr_name in dir(api):
            # Skip non-public (and name-mangled) attributes.
            if attr_name.startswith("_"):
                continue

            attr = getattr(api, attr_name)
            if not callable(attr):
                continue

            # We should only be looking at methods bound to SpannerClient,
            # as those are the RPC-invoking methods that need to be wrapped.
            if type(attr).__name__ != "method":
                continue

            setattr(api, attr_name, InterceptingHeaderInjector(attr))

        return api

def _generate_spanner_api(self):
"""Helper for session-related API calls."""
if self._spanner_api is None:
client_info = self._instance._client._client_info
Expand Down Expand Up @@ -757,11 +795,11 @@ def execute_pdml():
) as span:
with SessionCheckout(self._pool) as session:
add_span_event(span, "Starting BeginTransaction")
begin_txn_attempt.increment()
txn = api.begin_transaction(
session=session.name, options=txn_options,
session=session.name,
options=txn_options,
metadata=self.metadata_with_request_id(
begin_txn_nth_request, begin_txn_attempt.value, metadata
begin_txn_nth_request, begin_txn_attempt.increment(), metadata
),
)

@@ -776,36 +814,20 @@
request_options=request_options,
)

def wrapped_method(*args, **kwargs):
partial_attempt.increment()
method = functools.partial(
api.execute_streaming_sql,
metadata=self.metadata_with_request_id(
partial_nth_request, partial_attempt.value, metadata
),
)
return method(*args, **kwargs)
method = functools.partial(
api.execute_streaming_sql,
metadata=self.metadata_with_request_id(
partial_nth_request, partial_attempt.increment(), metadata
),
)

iterator = _restart_on_unavailable(
method=wrapped_method,
method=method,
trace_name="CloudSpanner.ExecuteStreamingSql",
request=request,
transaction_selector=txn_selector,
observability_options=self.observability_options,
attempt=begin_txn_attempt,
)

result_set = StreamedResultSet(iterator)
list(result_set) # consume all partials
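The spanner_api wrapping shown above is a generic bound-method patching
pattern; here is a standalone sketch with a toy client standing in for the
real SpannerClient:

    class ToyClient:
        # Hypothetical stand-in for SpannerClient.
        def commit(self, request=None, metadata=None):
            return metadata

    client = ToyClient()
    for name in dir(client):
        if name.startswith("_"):
            continue
        attr = getattr(client, name)
        if callable(attr) and type(attr).__name__ == "method":
            setattr(client, name, InterceptingHeaderInjector(attr))

    # client.commit now goes through the injector, so the request-id
    # attempt component gets bumped on each call.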
13 changes: 1 addition & 12 deletions google/cloud/spanner_v1/pool.py
@@ -262,11 +262,6 @@ def create_sessions(attempt):
return api.batch_create_sessions(
request=request,
metadata=all_metadata,
# Manually passing retry=None because otherwise any
# UNAVAILABLE retry will be retried without replenishing
# the metadata, hence this allows us to manually update
# the metadata using retry_on_unavailable.
retry=None,
)

resp = retry_on_unavailable(create_sessions)
@@ -568,6 +563,7 @@ def bind(self, database):
"CloudSpanner.PingingPool.BatchCreateSessions",
observability_options=observability_options,
) as span:
created_session_count = 0
while created_session_count < self.size:
nth_req = database._next_nth_request

@@ -578,13 +574,6 @@ def create_sessions(attempt):
return api.batch_create_sessions(
request=request,
metadata=all_metadata,
# Manually passing retry=None because otherwise any
# UNAVAILABLE retry will be retried without replenishing
# the metadata, hence this allows us to manually update
# the metadata using retry_on_unavailable.
# TODO: Figure out how to intercept and monkey patch the internals
# of the gRPC transport.
retry=None,
)

resp = retry_on_unavailable(create_sessions)
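With retry=None gone, UNAVAILABLE retries flow through retry_on_unavailable
alone. As a hedged sketch of what such a helper could look like, assuming
google.api_core's retry machinery (the real implementation is not shown in
this diff):

    from google.api_core import exceptions
    from google.api_core import retry as api_retry

    def retry_on_unavailable(func):
        # Assumption: the helper supplies an attempt number to the callable,
        # since create_sessions above takes an `attempt` parameter; the
        # starting value and exact signature are guesses.
        attempt = 0

        def call():
            nonlocal attempt
            attempt += 1
            return func(attempt)

        # Retry only when the server returns UNAVAILABLE.
        retrier = api_retry.Retry(
            predicate=api_retry.if_exception_type(exceptions.ServiceUnavailable)
        )
        return retrier(call)()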