diff --git a/google/cloud/spanner_v1/testing/database_test.py b/google/cloud/spanner_v1/testing/database_test.py index 54afda11e0..16ebd5fb77 100644 --- a/google/cloud/spanner_v1/testing/database_test.py +++ b/google/cloud/spanner_v1/testing/database_test.py @@ -25,6 +25,7 @@ from google.cloud.spanner_v1.testing.interceptors import ( MethodCountInterceptor, MethodAbortInterceptor, + XGoogRequestIDHeaderInterceptor, ) @@ -60,9 +61,11 @@ def __init__( self._method_count_interceptor = MethodCountInterceptor() self._method_abort_interceptor = MethodAbortInterceptor() + self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor() self._interceptors = [ self._method_count_interceptor, self._method_abort_interceptor, + self._x_goog_request_id_interceptor, ] @property diff --git a/google/cloud/spanner_v1/testing/interceptors.py b/google/cloud/spanner_v1/testing/interceptors.py index a8b015a87d..8c26065f6a 100644 --- a/google/cloud/spanner_v1/testing/interceptors.py +++ b/google/cloud/spanner_v1/testing/interceptors.py @@ -13,6 +13,8 @@ # limitations under the License. from collections import defaultdict +import threading + from grpc_interceptor import ClientInterceptor from google.api_core.exceptions import Aborted @@ -63,3 +65,30 @@ def reset(self): self._method_to_abort = None self._count = 0 self._connection = None + + +class XGoogRequestIDHeaderInterceptor(ClientInterceptor): + def __init__(self): + self._unary_req_segments = [] + self._stream_req_segments = [] + self.__lock = threading.Lock() + + def intercept(self, method, request_or_iterator, call_details): + metadata = call_details.metadata + x_goog_request_id = None + for key, value in metadata: + if key == "x-goog-spanner-request-id": + x_goog_request_id = value + break + + if not x_goog_request_id: + raise Exception(f"Missing {x_goog_request_id}") + + streaming = hasattr(request_or_iterator, "__iter__", False) + with self.__lock: + if streaming: + self._stream_req_segments.append(x_goog_request_id) + else: + self._unary_req_segments.append(x_goog_request_id) + + return method(request_or_iterator, call_details)