blob: 621fcf2174fbd26122b739f914cb7c27e1f1d4f4 [file] [log] [blame]
# Copyright 2016, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Translates gRPC's client-side API into gRPC's client-side Beta API."""
import grpc
from grpc._cython import cygrpc
from grpc.beta import interfaces
from grpc.framework.common import cardinality
from grpc.framework.foundation import future
from grpc.framework.interfaces.face import face
# Maps a gRPC status code to the (Beta abortion kind, Beta exception class)
# used to report it. Codes absent from this table fall back to LOCAL_FAILURE
# in _abortion and to the generic face.AbortionError in _abortion_error.
_STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS = dict((
    (grpc.StatusCode.CANCELLED,
     (face.Abortion.Kind.CANCELLED, face.CancellationError)),
    (grpc.StatusCode.UNKNOWN,
     (face.Abortion.Kind.REMOTE_FAILURE, face.RemoteError)),
    (grpc.StatusCode.DEADLINE_EXCEEDED,
     (face.Abortion.Kind.EXPIRED, face.ExpirationError)),
    (grpc.StatusCode.UNIMPLEMENTED,
     (face.Abortion.Kind.LOCAL_FAILURE, face.LocalError)),
))
def _fully_qualified_method(group, method):
return b'/{}/{}'.format(group, method)
def _effective_metadata(metadata, metadata_transformer):
non_none_metadata = () if metadata is None else metadata
if metadata_transformer is None:
return non_none_metadata
else:
return metadata_transformer(non_none_metadata)
def _credentials(grpc_call_options):
return None if grpc_call_options is None else grpc_call_options.credentials
def _abortion(rpc_error_call):
    """Builds a Beta ``face.Abortion`` value describing a terminated call.

    Args:
      rpc_error_call: A gRPC call object whose status code, initial and
        trailing metadata, and details are read.

    Returns:
      A face.Abortion whose kind comes from the status-code table. Codes
      without an entry map to LOCAL_FAILURE.
    """
    code = rpc_error_call.code()
    pair = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(code)
    error_kind = face.Abortion.Kind.LOCAL_FAILURE if pair is None else pair[0]
    return face.Abortion(
        error_kind, rpc_error_call.initial_metadata(),
        rpc_error_call.trailing_metadata(), code,
        # Previously referenced the undefined name `rpc_error_code`, which
        # raised NameError on every call.
        rpc_error_call.details())
def _abortion_error(rpc_error_call):
    """Converts a failed gRPC call into the matching Beta exception instance.

    Args:
      rpc_error_call: A gRPC call object whose status code, initial and
        trailing metadata, and details are read.

    Returns:
      An instance of the exception class mapped from the status code, or of
      face.AbortionError when the code has no dedicated mapping. The
      exception is returned, not raised.
    """
    status_code = rpc_error_call.code()
    mapping = _STATUS_CODE_TO_ABORTION_KIND_AND_ABORTION_ERROR_CLASS.get(
        status_code)
    if mapping is None:
        error_class = face.AbortionError
    else:
        _, error_class = mapping
    return error_class(
        rpc_error_call.initial_metadata(),
        rpc_error_call.trailing_metadata(),
        status_code,
        rpc_error_call.details())
class _InvocationProtocolContext(interfaces.GRPCInvocationContext):
    """Beta invocation protocol context; compression control is not yet supported."""

    def disable_next_request_compression(self):
        """Does nothing: per-request compression control is unimplemented."""
        # TODO(https://github.com/grpc/grpc/issues/4078): design, implement.
class _Rendezvous(future.Future, face.Call):
    """Adapts gRPC call objects to the Beta Future, Call and iterator APIs.

    Within this module it is constructed with ``(None, None, call)`` for
    blocking calls made ``with_call``, ``(f, None, f)`` for future calls, and
    ``(None, it, it)`` for response-streaming calls. Only the methods backed
    by a non-None object are usable for a given instance.
    """

    def __init__(self, response_future, response_iterator, call):
        self._future = response_future
        self._iterator = response_iterator
        self._call = call

    def cancel(self):
        return self._call.cancel()

    def cancelled(self):
        return self._future.cancelled()

    def running(self):
        return self._future.running()

    def done(self):
        return self._future.done()

    def result(self, timeout=None):
        """Returns the response, translating gRPC errors to Beta errors."""
        try:
            return self._future.result(timeout=timeout)
        except grpc.RpcError as rpc_error_call:
            raise _abortion_error(rpc_error_call)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def exception(self, timeout=None):
        """Returns the Beta exception for a failed RPC, or None on success."""
        try:
            rpc_error_call = self._future.exception(timeout=timeout)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()
        # A successful RPC has no exception. Passing None on to
        # _abortion_error would fail with AttributeError on None.code().
        if rpc_error_call is None:
            return None
        return _abortion_error(rpc_error_call)

    def traceback(self, timeout=None):
        try:
            return self._future.traceback(timeout=timeout)
        except grpc.FutureTimeoutError:
            raise future.TimeoutError()
        except grpc.FutureCancelledError:
            raise future.CancelledError()

    def add_done_callback(self, fn):
        # The Beta API passes this rendezvous (not the underlying gRPC future)
        # to the callback.
        self._future.add_done_callback(lambda ignored_callback: fn(self))

    def __iter__(self):
        return self

    def _next(self):
        try:
            return next(self._iterator)
        except grpc.RpcError as rpc_error_call:
            raise _abortion_error(rpc_error_call)

    def __next__(self):
        return self._next()

    def next(self):
        return self._next()

    def is_active(self):
        return self._call.is_active()

    def time_remaining(self):
        return self._call.time_remaining()

    def add_abortion_callback(self, abortion_callback):
        """Arranges for ``abortion_callback`` to be called if the RPC aborts.

        Returns:
          None if the callback was registered, or if the RPC already
          completed successfully. Otherwise, when the RPC has already
          aborted, returns the face.Abortion describing it; the callback is
          not called in that case.
        """
        def done_callback():
            # grpc's add_callback fires on every termination, including
            # success. Only report real abortions to the Beta callback.
            if self._call.code() is not grpc.StatusCode.OK:
                abortion_callback(_abortion(self._call))

        registered = self._call.add_callback(done_callback)
        if registered or self._call.code() is grpc.StatusCode.OK:
            return None
        return _abortion(self._call)

    def protocol_context(self):
        return _InvocationProtocolContext()

    def initial_metadata(self):
        return self._call.initial_metadata()

    def terminal_metadata(self):
        # grpc.Call exposes this as trailing_metadata(). It has no
        # terminal_metadata() method, so the old call raised AttributeError.
        return self._call.trailing_metadata()

    def code(self):
        return self._call.code()

    def details(self):
        return self._call.details()
def _blocking_unary_unary(
        channel, group, method, timeout, with_call, protocol_options, metadata,
        metadata_transformer, request, request_serializer, response_deserializer):
    """Performs a unary-unary RPC and blocks until it completes.

    Returns:
      The response, or a ``(response, _Rendezvous)`` pair when ``with_call``
      is truthy.

    Raises:
      The Beta exception produced by _abortion_error for any grpc.RpcError.
    """
    try:
        rpc = channel.unary_unary(
            _fully_qualified_method(group, method),
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)
        call_options = dict(
            timeout=timeout,
            metadata=_effective_metadata(metadata, metadata_transformer),
            credentials=_credentials(protocol_options))
        if not with_call:
            return rpc(request, **call_options)
        response, call = rpc(request, with_call=True, **call_options)
        return response, _Rendezvous(None, None, call)
    except grpc.RpcError as rpc_error_call:
        raise _abortion_error(rpc_error_call)
def _future_unary_unary(
        channel, group, method, timeout, protocol_options, metadata,
        metadata_transformer, request, request_serializer, response_deserializer):
    """Starts a unary-unary RPC asynchronously.

    Returns:
      A _Rendezvous wrapping the gRPC future, which serves as both the
      future and the call.
    """
    rpc = channel.unary_unary(
        _fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    call_future = rpc.future(
        request,
        timeout=timeout,
        metadata=_effective_metadata(metadata, metadata_transformer),
        credentials=_credentials(protocol_options))
    return _Rendezvous(call_future, None, call_future)
def _unary_stream(
        channel, group, method, timeout, protocol_options, metadata,
        metadata_transformer, request, request_serializer, response_deserializer):
    """Starts a unary-stream RPC.

    Returns:
      A _Rendezvous wrapping the gRPC response iterator, which serves as
      both the iterator and the call.
    """
    rpc = channel.unary_stream(
        _fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    responses = rpc(
        request,
        timeout=timeout,
        metadata=_effective_metadata(metadata, metadata_transformer),
        credentials=_credentials(protocol_options))
    return _Rendezvous(None, responses, responses)
def _blocking_stream_unary(
        channel, group, method, timeout, with_call, protocol_options, metadata,
        metadata_transformer, request_iterator, request_serializer,
        response_deserializer):
    """Performs a stream-unary RPC and blocks until it completes.

    Returns:
      The response, or a ``(response, _Rendezvous)`` pair when ``with_call``
      is truthy.

    Raises:
      The Beta exception produced by _abortion_error for any grpc.RpcError.
    """
    try:
        rpc = channel.stream_unary(
            _fully_qualified_method(group, method),
            request_serializer=request_serializer,
            response_deserializer=response_deserializer)
        call_options = dict(
            timeout=timeout,
            metadata=_effective_metadata(metadata, metadata_transformer),
            credentials=_credentials(protocol_options))
        if not with_call:
            return rpc(request_iterator, **call_options)
        response, call = rpc(request_iterator, with_call=True, **call_options)
        return response, _Rendezvous(None, None, call)
    except grpc.RpcError as rpc_error_call:
        raise _abortion_error(rpc_error_call)
def _future_stream_unary(
        channel, group, method, timeout, protocol_options, metadata,
        metadata_transformer, request_iterator, request_serializer,
        response_deserializer):
    """Starts a stream-unary RPC asynchronously.

    Returns:
      A _Rendezvous wrapping the gRPC future, which serves as both the
      future and the call.
    """
    rpc = channel.stream_unary(
        _fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    call_future = rpc.future(
        request_iterator,
        timeout=timeout,
        metadata=_effective_metadata(metadata, metadata_transformer),
        credentials=_credentials(protocol_options))
    return _Rendezvous(call_future, None, call_future)
def _stream_stream(
        channel, group, method, timeout, protocol_options, metadata,
        metadata_transformer, request_iterator, request_serializer,
        response_deserializer):
    """Starts a stream-stream RPC.

    Returns:
      A _Rendezvous wrapping the gRPC response iterator, which serves as
      both the iterator and the call.
    """
    rpc = channel.stream_stream(
        _fully_qualified_method(group, method),
        request_serializer=request_serializer,
        response_deserializer=response_deserializer)
    responses = rpc(
        request_iterator,
        timeout=timeout,
        metadata=_effective_metadata(metadata, metadata_transformer),
        credentials=_credentials(protocol_options))
    return _Rendezvous(None, responses, responses)
class _UnaryUnaryMultiCallable(face.UnaryUnaryMultiCallable):
    """Beta unary-unary multi-callable bound to one method on a channel."""

    def __init__(
            self, channel, group, method, metadata_transformer,
            request_serializer, response_deserializer):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(
            self, request, timeout, metadata=None, with_call=False,
            protocol_options=None):
        return _blocking_unary_unary(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            with_call=with_call,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request=request,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def future(self, request, timeout, metadata=None, protocol_options=None):
        return _future_unary_unary(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request=request,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def event(
            self, request, receiver, abortion_callback, timeout,
            metadata=None, protocol_options=None):
        # The event-driven style is not supported by this adaptation layer.
        raise NotImplementedError()
class _UnaryStreamMultiCallable(face.UnaryStreamMultiCallable):
    """Beta unary-stream multi-callable bound to one method on a channel."""

    def __init__(
            self, channel, group, method, metadata_transformer,
            request_serializer, response_deserializer):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(self, request, timeout, metadata=None, protocol_options=None):
        return _unary_stream(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request=request,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def event(
            self, request, receiver, abortion_callback, timeout,
            metadata=None, protocol_options=None):
        # The event-driven style is not supported by this adaptation layer.
        raise NotImplementedError()
class _StreamUnaryMultiCallable(face.StreamUnaryMultiCallable):
    """Beta stream-unary multi-callable bound to one method on a channel."""

    def __init__(
            self, channel, group, method, metadata_transformer,
            request_serializer, response_deserializer):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(
            self, request_iterator, timeout, metadata=None, with_call=False,
            protocol_options=None):
        return _blocking_stream_unary(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            with_call=with_call,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request_iterator=request_iterator,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def future(
            self, request_iterator, timeout, metadata=None,
            protocol_options=None):
        return _future_stream_unary(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request_iterator=request_iterator,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def event(
            self, receiver, abortion_callback, timeout, metadata=None,
            protocol_options=None):
        # The event-driven style is not supported by this adaptation layer.
        raise NotImplementedError()
class _StreamStreamMultiCallable(face.StreamStreamMultiCallable):
    """Beta stream-stream multi-callable bound to one method on a channel."""

    def __init__(
            self, channel, group, method, metadata_transformer,
            request_serializer, response_deserializer):
        self._channel = channel
        self._group = group
        self._method = method
        self._metadata_transformer = metadata_transformer
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __call__(
            self, request_iterator, timeout, metadata=None,
            protocol_options=None):
        return _stream_stream(
            channel=self._channel,
            group=self._group,
            method=self._method,
            timeout=timeout,
            protocol_options=protocol_options,
            metadata=metadata,
            metadata_transformer=self._metadata_transformer,
            request_iterator=request_iterator,
            request_serializer=self._request_serializer,
            response_deserializer=self._response_deserializer)

    def event(
            self, receiver, abortion_callback, timeout, metadata=None,
            protocol_options=None):
        # The event-driven style is not supported by this adaptation layer.
        raise NotImplementedError()
class _GenericStub(face.GenericStub):
    """Beta generic stub that issues RPCs over a gRPC channel.

    Serializers and deserializers are looked up per ``(group, method)`` key.
    A missing entry yields None, which is passed through to the channel.
    """

    def __init__(
            self, channel, metadata_transformer, request_serializers,
            response_deserializers):
        self._channel = channel
        self._metadata_transformer = metadata_transformer
        self._request_serializers = request_serializers or {}
        self._response_deserializers = response_deserializers or {}

    def _serialization_for(self, group, method):
        """Returns (request_serializer, response_deserializer) for a method."""
        key = (group, method,)
        return (self._request_serializers.get(key),
                self._response_deserializers.get(key))

    def blocking_unary_unary(
            self, group, method, request, timeout, metadata=None,
            with_call=None, protocol_options=None):
        serializer, deserializer = self._serialization_for(group, method)
        return _blocking_unary_unary(
            self._channel, group, method, timeout, with_call, protocol_options,
            metadata, self._metadata_transformer, request, serializer,
            deserializer)

    def future_unary_unary(
            self, group, method, request, timeout, metadata=None,
            protocol_options=None):
        serializer, deserializer = self._serialization_for(group, method)
        return _future_unary_unary(
            self._channel, group, method, timeout, protocol_options, metadata,
            self._metadata_transformer, request, serializer, deserializer)

    def inline_unary_stream(
            self, group, method, request, timeout, metadata=None,
            protocol_options=None):
        serializer, deserializer = self._serialization_for(group, method)
        return _unary_stream(
            self._channel, group, method, timeout, protocol_options, metadata,
            self._metadata_transformer, request, serializer, deserializer)

    def blocking_stream_unary(
            self, group, method, request_iterator, timeout, metadata=None,
            with_call=None, protocol_options=None):
        serializer, deserializer = self._serialization_for(group, method)
        return _blocking_stream_unary(
            self._channel, group, method, timeout, with_call, protocol_options,
            metadata, self._metadata_transformer, request_iterator,
            serializer, deserializer)

    def future_stream_unary(
            self, group, method, request_iterator, timeout, metadata=None,
            protocol_options=None):
        serializer, deserializer = self._serialization_for(group, method)
        return _future_stream_unary(
            self._channel, group, method, timeout, protocol_options, metadata,
            self._metadata_transformer, request_iterator, serializer,
            deserializer)

    def inline_stream_stream(
            self, group, method, request_iterator, timeout, metadata=None,
            protocol_options=None):
        serializer, deserializer = self._serialization_for(group, method)
        return _stream_stream(
            self._channel, group, method, timeout, protocol_options, metadata,
            self._metadata_transformer, request_iterator, serializer,
            deserializer)

    # The event-driven invocation styles are not supported.

    def event_unary_unary(
            self, group, method, request, receiver, abortion_callback, timeout,
            metadata=None, protocol_options=None):
        raise NotImplementedError()

    def event_unary_stream(
            self, group, method, request, receiver, abortion_callback, timeout,
            metadata=None, protocol_options=None):
        raise NotImplementedError()

    def event_stream_unary(
            self, group, method, receiver, abortion_callback, timeout,
            metadata=None, protocol_options=None):
        raise NotImplementedError()

    def event_stream_stream(
            self, group, method, receiver, abortion_callback, timeout,
            metadata=None, protocol_options=None):
        raise NotImplementedError()

    def unary_unary(self, group, method):
        serializer, deserializer = self._serialization_for(group, method)
        return _UnaryUnaryMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            serializer, deserializer)

    def unary_stream(self, group, method):
        serializer, deserializer = self._serialization_for(group, method)
        return _UnaryStreamMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            serializer, deserializer)

    def stream_unary(self, group, method):
        serializer, deserializer = self._serialization_for(group, method)
        return _StreamUnaryMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            serializer, deserializer)

    def stream_stream(self, group, method):
        serializer, deserializer = self._serialization_for(group, method)
        return _StreamStreamMultiCallable(
            self._channel, group, method, self._metadata_transformer,
            serializer, deserializer)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Never suppresses exceptions raised inside the with-block.
        return False
class _DynamicStub(face.DynamicStub):
    """Beta dynamic stub exposing a service's methods as attributes.

    An attribute name is looked up in ``cardinalities``. Its cardinality
    selects which multi-callable factory on the wrapped generic stub builds
    the returned callable.
    """

    def __init__(self, generic_stub, group, cardinalities):
        self._generic_stub = generic_stub
        self._group = group
        self._cardinalities = cardinalities

    def __getattr__(self, attr):
        method_cardinality = self._cardinalities.get(attr)
        dispatch = (
            (cardinality.Cardinality.UNARY_UNARY, 'unary_unary'),
            (cardinality.Cardinality.UNARY_STREAM, 'unary_stream'),
            (cardinality.Cardinality.STREAM_UNARY, 'stream_unary'),
            (cardinality.Cardinality.STREAM_STREAM, 'stream_stream'),
        )
        for kind, factory_name in dispatch:
            if method_cardinality is kind:
                factory = getattr(self._generic_stub, factory_name)
                return factory(self._group, attr)
        raise AttributeError('_DynamicStub object has no attribute "%s"!' % attr)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Never suppresses exceptions raised inside the with-block.
        return False
def generic_stub(
        channel, host, metadata_transformer, request_serializers,
        response_deserializers):
    """Creates a Beta generic stub backed by ``channel``.

    ``host`` is accepted for interface compatibility but is not used here.
    """
    return _GenericStub(
        channel,
        metadata_transformer,
        request_serializers,
        response_deserializers)
def dynamic_stub(
        channel, service, cardinalities, host, metadata_transformer,
        request_serializers, response_deserializers):
    """Creates a Beta dynamic stub for ``service`` backed by ``channel``.

    ``host`` is accepted for interface compatibility but is not used here.
    """
    backing_stub = _GenericStub(
        channel, metadata_transformer, request_serializers,
        response_deserializers)
    return _DynamicStub(backing_stub, service, cardinalities)