Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(x-goog-spanner-request-id): implement Request-ID #1264

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@
from google.cloud.spanner_v1 import TypeCode
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import JsonObject
from google.cloud.spanner_v1.request_id_header import with_request_id
from google.cloud.spanner_v1.request_id_header import REQ_ID_HEADER_KEY, with_request_id
from google.rpc.error_details_pb2 import RetryInfo

import random
from typing import Callable

# Validation error messages
NUMERIC_MAX_SCALE_ERR_MSG = (
Expand Down Expand Up @@ -641,6 +642,40 @@ def __radd__(self, n):
"""
return self.__add__(n)

def reset(self):
with self.__lock:
self.__value = 0


def _metadata_with_request_id(*args, **kwargs):
return with_request_id(*args, **kwargs)


class InterceptingHeaderInjector:
def __init__(self, original_callable: Callable):
self._original_callable = original_callable

def __call__(self, *args, **kwargs):
metadata = kwargs.get("metadata", [])
# Find all the headers that match the x-goog-spanner-request-id
# header an on each retry increment the value.
all_metadata = []
for key, value in metadata:
if key is REQ_ID_HEADER_KEY:
# Otherwise now increment the count for the attempt number.
splits = value.split(".")
attempt_plus_one = int(splits[-1]) + 1
splits[-1] = str(attempt_plus_one)
value_before = value
value = ".".join(splits)
print("incrementing value on retry from", value_before, "to", value)

all_metadata.append(
(
key,
value,
)
)

kwargs["metadata"] = all_metadata
return self._original_callable(*args, **kwargs)
21 changes: 19 additions & 2 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from google.cloud.spanner_v1._helpers import (
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
AtomicCounter,
)
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1 import RequestOptions
Expand Down Expand Up @@ -227,11 +228,20 @@ def commit(
trace_attributes,
observability_options=observability_options,
):
attempt = AtomicCounter(0)
next_nth_request = database._next_nth_request

all_metadata = database.metadata_with_request_id(
next_nth_request,
attempt.increment(),
metadata,
)
method = functools.partial(
api.commit,
request=request,
metadata=metadata,
metadata=all_metadata,
)

deadline = time.time() + kwargs.get(
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
)
Expand Down Expand Up @@ -349,11 +359,18 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
trace_attributes,
observability_options=observability_options,
):
next_nth_request = database._next_nth_request
all_metadata = database.metadata_with_request_id(
next_nth_request,
0,
metadata,
)
method = functools.partial(
api.batch_write,
request=request,
metadata=metadata,
metadata=all_metadata,
)

response = _retry(
method,
allowed_exceptions={
Expand Down
9 changes: 9 additions & 0 deletions google/cloud/spanner_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1.instance import Instance
from google.cloud.spanner_v1._helpers import AtomicCounter

_CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__)
EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST"
Expand Down Expand Up @@ -147,6 +148,8 @@ class Client(ClientWithProject):
SCOPE = (SPANNER_ADMIN_SCOPE,)
"""The scopes required for Google Cloud Spanner."""

NTH_CLIENT = AtomicCounter()

def __init__(
self,
project=None,
Expand Down Expand Up @@ -199,6 +202,12 @@ def __init__(
self._route_to_leader_enabled = route_to_leader_enabled
self._directed_read_options = directed_read_options
self._observability_options = observability_options
self._nth_client_id = Client.NTH_CLIENT.increment()
self._nth_request = AtomicCounter(0)

@property
def _next_nth_request(self):
return self._nth_request.increment()

@property
def credentials(self):
Expand Down
Loading