| # Copyright 2017 gRPC authors. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Implementation of gRPC Python interceptors.""" |
| |
| import collections |
| import sys |
| |
| import grpc |
| |
| |
| class _ServicePipeline(object): |
| |
| def __init__(self, interceptors): |
| self.interceptors = tuple(interceptors) |
| |
| def _continuation(self, thunk, index): |
| return lambda context: self._intercept_at(thunk, index, context) |
| |
| def _intercept_at(self, thunk, index, context): |
| if index < len(self.interceptors): |
| interceptor = self.interceptors[index] |
| thunk = self._continuation(thunk, index + 1) |
| return interceptor.intercept_service(thunk, context) |
| else: |
| return thunk(context) |
| |
| def execute(self, thunk, context): |
| return self._intercept_at(thunk, 0, context) |
| |
| |
def service_pipeline(interceptors):
    """Build a ``_ServicePipeline`` for ``interceptors``.

    Returns None when ``interceptors`` is falsy (None or empty), so callers
    can skip interception entirely.
    """
    if not interceptors:
        return None
    return _ServicePipeline(interceptors)
| |
| |
class _ClientCallDetails(
        collections.namedtuple('_ClientCallDetails',
                               'method timeout metadata credentials'),
        grpc.ClientCallDetails):
    """Immutable record of the per-call options passed to a multi-callable.

    Fields: ``method``, ``timeout``, ``metadata`` and ``credentials``. It is
    also registered as a ``grpc.ClientCallDetails`` so interceptors can treat
    it as the public call-details type.
    """
| |
| |
| def _unwrap_client_call_details(call_details, default_details): |
| try: |
| method = call_details.method |
| except AttributeError: |
| method = default_details.method |
| |
| try: |
| timeout = call_details.timeout |
| except AttributeError: |
| timeout = default_details.timeout |
| |
| try: |
| metadata = call_details.metadata |
| except AttributeError: |
| metadata = default_details.metadata |
| |
| try: |
| credentials = call_details.credentials |
| except AttributeError: |
| credentials = default_details.credentials |
| |
| return method, timeout, metadata, credentials |
| |
| |
class _FailureOutcome(grpc.RpcError, grpc.Future, grpc.Call):
    """A completed, failed RPC outcome wrapping an interception exception.

    It is at once an ``RpcError`` (so it can be raised), a ``Future`` that
    is already done and whose ``result()`` re-raises the captured exception,
    and a ``Call`` that reports ``StatusCode.INTERNAL``. It is also its own
    iterator. Consuming it as a response stream re-raises the captured
    exception on the first step.
    """

    def __init__(self, exception, traceback):
        super(_FailureOutcome, self).__init__()
        # The exception being reported and its traceback object.
        self._exception = exception
        self._traceback = traceback

    def initial_metadata(self):
        return None

    def trailing_metadata(self):
        return None

    def code(self):
        return grpc.StatusCode.INTERNAL

    def details(self):
        return 'Exception raised while intercepting the RPC'

    def cancel(self):
        return False

    def cancelled(self):
        return False

    def is_active(self):
        return False

    def time_remaining(self):
        return None

    def running(self):
        return False

    def done(self):
        return True

    def result(self, ignored_timeout=None):
        raise self._exception

    def exception(self, ignored_timeout=None):
        return self._exception

    def traceback(self, ignored_timeout=None):
        return self._traceback

    def add_callback(self, callback):
        # Never registers the callback.
        return False

    def add_done_callback(self, fn):
        # Already done, so the callback is invoked immediately.
        fn(self)

    def __iter__(self):
        return self

    def __next__(self):
        # Python 3 iterator protocol. Without this method, iter() on this
        # object fails with "iter() returned non-iterator" instead of
        # surfacing the captured exception.
        raise self._exception

    def next(self):
        # Python 2 iterator protocol; kept for backward compatibility.
        return self.__next__()
| |
| |
class _UnaryOutcome(grpc.Call, grpc.Future):
    """A completed, successful unary-response outcome.

    Pairs the response with the underlying call object. ``Call``-style
    queries (metadata, status, activity, cancellation, callbacks) are
    delegated to that call. The ``Future`` side always reports a finished,
    successful future whose result is the stored response.
    """

    def __init__(self, response, call):
        self._response = response
        self._call = call

    # Delegated to the wrapped call object.

    def initial_metadata(self):
        return self._call.initial_metadata()

    def trailing_metadata(self):
        return self._call.trailing_metadata()

    def code(self):
        return self._call.code()

    def details(self):
        return self._call.details()

    def is_active(self):
        return self._call.is_active()

    def time_remaining(self):
        return self._call.time_remaining()

    def cancel(self):
        return self._call.cancel()

    def add_callback(self, callback):
        return self._call.add_callback(callback)

    # Future state: always finished, never cancelled, never failed.

    def cancelled(self):
        return False

    def running(self):
        return False

    def done(self):
        return True

    def result(self, ignored_timeout=None):
        return self._response

    def exception(self, ignored_timeout=None):
        return None

    def traceback(self, ignored_timeout=None):
        return None

    def add_done_callback(self, fn):
        # Already done, so the callback runs right away.
        fn(self)
| |
| |
class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Unary-unary multi-callable that routes each call through an interceptor.

    ``thunk`` maps a method name to the underlying channel's multi-callable.
    The interceptor can therefore redirect or re-parameterize the call
    through the details object it passes to the continuation.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self, request, timeout=None, metadata=None, credentials=None):
        # Blocking invocation: keep the response, discard the call object.
        return self._with_call(request, timeout, metadata, credentials)[0]

    def _with_call(self, request, timeout=None, metadata=None,
                   credentials=None):
        """Run the intercepted blocking call; return (response, outcome)."""
        original_details = _ClientCallDetails(self._method, timeout,
                                              metadata, credentials)

        def continuation(new_details, request):
            method, call_timeout, call_metadata, call_credentials = (
                _unwrap_client_call_details(new_details, original_details))
            try:
                reply, rpc = self._thunk(method).with_call(
                    request,
                    timeout=call_timeout,
                    metadata=call_metadata,
                    credentials=call_credentials)
            except grpc.RpcError:
                # RPC failures propagate unchanged to the interceptor.
                raise
            except Exception as exception:  # pylint:disable=broad-except
                # Any other failure becomes an already-failed outcome.
                return _FailureOutcome(exception, sys.exc_info()[2])
            return _UnaryOutcome(reply, rpc)

        outcome = self._interceptor.intercept_unary_unary(
            continuation, original_details, request)
        return outcome.result(), outcome

    def with_call(self, request, timeout=None, metadata=None, credentials=None):
        return self._with_call(request, timeout, metadata, credentials)

    def future(self, request, timeout=None, metadata=None, credentials=None):
        original_details = _ClientCallDetails(self._method, timeout,
                                              metadata, credentials)

        def continuation(new_details, request):
            method, call_timeout, call_metadata, call_credentials = (
                _unwrap_client_call_details(new_details, original_details))
            underlying = self._thunk(method)
            return underlying.future(
                request,
                timeout=call_timeout,
                metadata=call_metadata,
                credentials=call_credentials)

        try:
            return self._interceptor.intercept_unary_unary(
                continuation, original_details, request)
        except Exception as exception:  # pylint:disable=broad-except
            # Interception errors are reported through the returned future.
            return _FailureOutcome(exception, sys.exc_info()[2])
| |
| |
class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Unary-stream multi-callable that routes each call through an interceptor.

    ``thunk`` maps a method name to the underlying channel's unary-stream
    multi-callable.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self, request, timeout=None, metadata=None, credentials=None):
        original_details = _ClientCallDetails(self._method, timeout,
                                              metadata, credentials)

        def continuation(new_details, request):
            method, call_timeout, call_metadata, call_credentials = (
                _unwrap_client_call_details(new_details, original_details))
            underlying = self._thunk(method)
            return underlying(
                request,
                timeout=call_timeout,
                metadata=call_metadata,
                credentials=call_credentials)

        try:
            return self._interceptor.intercept_unary_stream(
                continuation, original_details, request)
        except Exception as exception:  # pylint:disable=broad-except
            # Interception errors are reported through the returned object
            # rather than raised here.
            return _FailureOutcome(exception, sys.exc_info()[2])
| |
| |
class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """Stream-unary multi-callable that routes each call through an interceptor.

    ``thunk`` maps a method name to the underlying channel's stream-unary
    multi-callable.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator,
                 timeout=None,
                 metadata=None,
                 credentials=None):
        # Blocking invocation: keep the response, discard the call object.
        return self._with_call(request_iterator, timeout, metadata,
                               credentials)[0]

    def _with_call(self,
                   request_iterator,
                   timeout=None,
                   metadata=None,
                   credentials=None):
        """Run the intercepted blocking call; return (response, outcome)."""
        original_details = _ClientCallDetails(self._method, timeout,
                                              metadata, credentials)

        def continuation(new_details, request_iterator):
            method, call_timeout, call_metadata, call_credentials = (
                _unwrap_client_call_details(new_details, original_details))
            try:
                reply, rpc = self._thunk(method).with_call(
                    request_iterator,
                    timeout=call_timeout,
                    metadata=call_metadata,
                    credentials=call_credentials)
            except grpc.RpcError:
                # RPC failures propagate unchanged to the interceptor.
                raise
            except Exception as exception:  # pylint:disable=broad-except
                # Any other failure becomes an already-failed outcome.
                return _FailureOutcome(exception, sys.exc_info()[2])
            return _UnaryOutcome(reply, rpc)

        outcome = self._interceptor.intercept_stream_unary(
            continuation, original_details, request_iterator)
        return outcome.result(), outcome

    def with_call(self,
                  request_iterator,
                  timeout=None,
                  metadata=None,
                  credentials=None):
        return self._with_call(request_iterator, timeout, metadata,
                               credentials)

    def future(self,
               request_iterator,
               timeout=None,
               metadata=None,
               credentials=None):
        original_details = _ClientCallDetails(self._method, timeout,
                                              metadata, credentials)

        def continuation(new_details, request_iterator):
            method, call_timeout, call_metadata, call_credentials = (
                _unwrap_client_call_details(new_details, original_details))
            underlying = self._thunk(method)
            return underlying.future(
                request_iterator,
                timeout=call_timeout,
                metadata=call_metadata,
                credentials=call_credentials)

        try:
            return self._interceptor.intercept_stream_unary(
                continuation, original_details, request_iterator)
        except Exception as exception:  # pylint:disable=broad-except
            # Interception errors are reported through the returned future.
            return _FailureOutcome(exception, sys.exc_info()[2])
| |
| |
class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """Stream-stream multi-callable that routes each call through an interceptor.

    ``thunk`` maps a method name to the underlying channel's stream-stream
    multi-callable.
    """

    def __init__(self, thunk, method, interceptor):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def __call__(self,
                 request_iterator,
                 timeout=None,
                 metadata=None,
                 credentials=None):
        original_details = _ClientCallDetails(self._method, timeout,
                                              metadata, credentials)

        def continuation(new_details, request_iterator):
            method, call_timeout, call_metadata, call_credentials = (
                _unwrap_client_call_details(new_details, original_details))
            underlying = self._thunk(method)
            return underlying(
                request_iterator,
                timeout=call_timeout,
                metadata=call_metadata,
                credentials=call_credentials)

        try:
            return self._interceptor.intercept_stream_stream(
                continuation, original_details, request_iterator)
        except Exception as exception:  # pylint:disable=broad-except
            # Interception errors are reported through the returned object
            # rather than raised here.
            return _FailureOutcome(exception, sys.exc_info()[2])
| |
| |
class _Channel(grpc.Channel):
    """A ``grpc.Channel`` wrapper that applies one client interceptor.

    A multi-callable is wrapped only for the RPC arities whose interceptor
    interface ``interceptor`` implements. For all other arities the wrapped
    channel's multi-callable is returned unchanged.
    """

    def __init__(self, channel, interceptor):
        self._channel = channel
        self._interceptor = interceptor

    def subscribe(self, callback, try_to_connect=False):
        self._channel.subscribe(callback, try_to_connect=try_to_connect)

    def unsubscribe(self, callback):
        self._channel.unsubscribe(callback)

    def unary_unary(self,
                    method,
                    request_serializer=None,
                    response_deserializer=None):

        def thunk(method_name):
            return self._channel.unary_unary(method_name, request_serializer,
                                             response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.UnaryUnaryClientInterceptor):
            return thunk(method)
        return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)

    def unary_stream(self,
                     method,
                     request_serializer=None,
                     response_deserializer=None):

        def thunk(method_name):
            return self._channel.unary_stream(method_name, request_serializer,
                                              response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.UnaryStreamClientInterceptor):
            return thunk(method)
        return _UnaryStreamMultiCallable(thunk, method, self._interceptor)

    def stream_unary(self,
                     method,
                     request_serializer=None,
                     response_deserializer=None):

        def thunk(method_name):
            return self._channel.stream_unary(method_name, request_serializer,
                                              response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.StreamUnaryClientInterceptor):
            return thunk(method)
        return _StreamUnaryMultiCallable(thunk, method, self._interceptor)

    def stream_stream(self,
                      method,
                      request_serializer=None,
                      response_deserializer=None):

        def thunk(method_name):
            return self._channel.stream_stream(method_name,
                                               request_serializer,
                                               response_deserializer)

        if not isinstance(self._interceptor,
                          grpc.StreamStreamClientInterceptor):
            return thunk(method)
        return _StreamStreamMultiCallable(thunk, method, self._interceptor)

    def _close(self):
        self._channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._close()
        # Never suppress an exception raised inside the ``with`` body.
        return False

    def close(self):
        self._close()
| |
| |
def intercept_channel(channel, *interceptors):
    """Wrap ``channel`` so that its RPCs pass through ``interceptors``.

    Each interceptor adds one ``_Channel`` layer. The list is walked in
    reverse, so the first interceptor given becomes the outermost layer and
    sees each call first. With no interceptors, ``channel`` is returned
    unchanged.

    Raises:
        TypeError: if an interceptor implements none of the four client
            interceptor interfaces.
    """
    # ``interceptors`` is already a tuple (from *args); reversed() accepts it
    # directly, so no intermediate list copy is needed.
    for interceptor in reversed(interceptors):
        if not isinstance(interceptor, (grpc.UnaryUnaryClientInterceptor,
                                        grpc.UnaryStreamClientInterceptor,
                                        grpc.StreamUnaryClientInterceptor,
                                        grpc.StreamStreamClientInterceptor)):
            raise TypeError('interceptor must be '
                            'grpc.UnaryUnaryClientInterceptor or '
                            'grpc.UnaryStreamClientInterceptor or '
                            'grpc.StreamUnaryClientInterceptor or '
                            'grpc.StreamStreamClientInterceptor')
        channel = _Channel(channel, interceptor)
    return channel