Fixes to stub and server lifecycle
Context management is implemented.
Stub deletion now cancels all RPCs immediately.
diff --git a/src/python/grpcio/grpc/beta/_server.py b/src/python/grpcio/grpc/beta/_server.py
index 05b954d..4f45443 100644
--- a/src/python/grpcio/grpc/beta/_server.py
+++ b/src/python/grpcio/grpc/beta/_server.py
@@ -44,6 +44,12 @@
_MAXIMUM_TIMEOUT = 24 * 60 * 60
+def _set_event():
+ event = threading.Event()
+ event.set()
+ return event
+
+
class _GRPCServicer(base.Servicer):
def __init__(self, delegate):
@@ -61,86 +67,143 @@
raise
-def _disassemble(grpc_link, end_link, pool, event, grace):
- grpc_link.begin_stop()
- end_link.stop(grace).wait()
- grpc_link.end_stop()
- grpc_link.join_link(utilities.NULL_LINK)
- end_link.join_link(utilities.NULL_LINK)
- if pool is not None:
- pool.shutdown(wait=True)
- event.set()
+class _Server(interfaces.Server):
-
-class Server(interfaces.Server):
-
- def __init__(self, grpc_link, end_link, pool):
+ def __init__(
+ self, implementations, multi_implementation, pool, pool_size,
+ default_timeout, maximum_timeout, grpc_link):
+ self._lock = threading.Lock()
+ self._implementations = implementations
+ self._multi_implementation = multi_implementation
+ self._customer_pool = pool
+ self._pool_size = pool_size
+ self._default_timeout = default_timeout
+ self._maximum_timeout = maximum_timeout
self._grpc_link = grpc_link
- self._end_link = end_link
- self._pool = pool
- def add_insecure_port(self, address):
- return self._grpc_link.add_port(address, None)
-
- def add_secure_port(self, address, server_credentials):
- return self._grpc_link.add_port(
- address, server_credentials._intermediary_low_credentials) # pylint: disable=protected-access
+ self._end_link = None
+ self._stop_events = None
+ self._pool = None
def _start(self):
- self._grpc_link.join_link(self._end_link)
- self._end_link.join_link(self._grpc_link)
- self._grpc_link.start()
- self._end_link.start()
+ with self._lock:
+ if self._end_link is not None:
+ raise ValueError('Cannot start already-started server!')
- def _stop(self, grace):
- stop_event = threading.Event()
- if 0 < grace:
- disassembly_thread = threading.Thread(
- target=_disassemble,
- args=(
- self._grpc_link, self._end_link, self._pool, stop_event, grace,))
- disassembly_thread.start()
- return stop_event
- else:
- _disassemble(self._grpc_link, self._end_link, self._pool, stop_event, 0)
- return stop_event
+ if self._customer_pool is None:
+ self._pool = logging_pool.pool(self._pool_size)
+ assembly_pool = self._pool
+ else:
+ assembly_pool = self._customer_pool
+
+ servicer = _GRPCServicer(
+ _crust_implementations.servicer(
+ self._implementations, self._multi_implementation, assembly_pool))
+
+ self._end_link = _core_implementations.service_end_link(
+ servicer, self._default_timeout, self._maximum_timeout)
+
+ self._grpc_link.join_link(self._end_link)
+ self._end_link.join_link(self._grpc_link)
+ self._grpc_link.start()
+ self._end_link.start()
+
+ def _dissociate_links_and_shut_down_pool(self):
+ self._grpc_link.end_stop()
+ self._grpc_link.join_link(utilities.NULL_LINK)
+ self._end_link.join_link(utilities.NULL_LINK)
+ self._end_link = None
+ if self._pool is not None:
+ self._pool.shutdown(wait=True)
+ self._pool = None
+
+ def _stop_stopping(self):
+ self._dissociate_links_and_shut_down_pool()
+ for stop_event in self._stop_events:
+ stop_event.set()
+ self._stop_events = None
+
+ def _stop_started(self):
+ self._grpc_link.begin_stop()
+ self._end_link.stop(0).wait()
+ self._dissociate_links_and_shut_down_pool()
+
+ def _foreign_thread_stop(self, end_stop_event, stop_events):
+ end_stop_event.wait()
+ with self._lock:
+ if self._stop_events is stop_events:
+ self._stop_stopping()
+
+ def _schedule_stop(self, grace):
+ with self._lock:
+ if self._end_link is None:
+ return _set_event()
+ server_stop_event = threading.Event()
+ if self._stop_events is None:
+ self._stop_events = [server_stop_event]
+ self._grpc_link.begin_stop()
+ else:
+ self._stop_events.append(server_stop_event)
+ end_stop_event = self._end_link.stop(grace)
+ end_stop_thread = threading.Thread(
+ target=self._foreign_thread_stop,
+ args=(end_stop_event, self._stop_events))
+ end_stop_thread.start()
+ return server_stop_event
+
+ def _stop_now(self):
+ with self._lock:
+ if self._end_link is not None:
+ if self._stop_events is None:
+ self._stop_started()
+ else:
+ self._stop_stopping()
+
+ def add_insecure_port(self, address):
+ with self._lock:
+ if self._end_link is None:
+ return self._grpc_link.add_port(address, None)
+ else:
+ raise ValueError('Can\'t add port to serving server!')
+
+ def add_secure_port(self, address, server_credentials):
+ with self._lock:
+ if self._end_link is None:
+ return self._grpc_link.add_port(
+ address, server_credentials._intermediary_low_credentials) # pylint: disable=protected-access
+ else:
+ raise ValueError('Can\'t add port to serving server!')
def start(self):
self._start()
def stop(self, grace):
- return self._stop(grace)
+ if 0 < grace:
+ return self._schedule_stop(grace)
+ else:
+ self._stop_now()
+ return _set_event()
def __enter__(self):
self._start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
- self._stop(0).wait()
+ self._stop_now()
return False
+ def __del__(self):
+ self._stop_now()
+
def server(
implementations, multi_implementation, request_deserializers,
response_serializers, thread_pool, thread_pool_size, default_timeout,
maximum_timeout):
- if thread_pool is None:
- service_thread_pool = logging_pool.pool(
- _DEFAULT_POOL_SIZE if thread_pool_size is None else thread_pool_size)
- assembly_thread_pool = service_thread_pool
- else:
- service_thread_pool = thread_pool
- assembly_thread_pool = None
-
- servicer = _GRPCServicer(
- _crust_implementations.servicer(
- implementations, multi_implementation, service_thread_pool))
-
grpc_link = service.service_link(request_deserializers, response_serializers)
-
- end_link = _core_implementations.service_end_link(
- servicer,
+ return _Server(
+ implementations, multi_implementation, thread_pool,
+ _DEFAULT_POOL_SIZE if thread_pool_size is None else thread_pool_size,
_DEFAULT_TIMEOUT if default_timeout is None else default_timeout,
- _MAXIMUM_TIMEOUT if maximum_timeout is None else maximum_timeout)
-
- return Server(grpc_link, end_link, assembly_thread_pool)
+ _MAXIMUM_TIMEOUT if maximum_timeout is None else maximum_timeout,
+ grpc_link)
diff --git a/src/python/grpcio/grpc/beta/_stub.py b/src/python/grpcio/grpc/beta/_stub.py
index 11dab88..2af0193 100644
--- a/src/python/grpcio/grpc/beta/_stub.py
+++ b/src/python/grpcio/grpc/beta/_stub.py
@@ -42,76 +42,114 @@
class _AutoIntermediary(object):
- def __init__(self, delegate, on_deletion):
+ def __init__(self, up, down, delegate):
+ self._lock = threading.Lock()
+ self._up = up
+ self._down = down
+ self._in_context = False
self._delegate = delegate
- self._on_deletion = on_deletion
def __getattr__(self, attr):
- return getattr(self._delegate, attr)
+ with self._lock:
+ if self._delegate is None:
+ raise AttributeError('No useful attributes out of context!')
+ else:
+ return getattr(self._delegate, attr)
def __enter__(self):
- return self
+ with self._lock:
+ if self._in_context:
+ raise ValueError('Already in context!')
+ elif self._delegate is None:
+ self._delegate = self._up()
+ self._in_context = True
+ return self
def __exit__(self, exc_type, exc_val, exc_tb):
- return False
+ with self._lock:
+ if not self._in_context:
+ raise ValueError('Not in context!')
+ self._down()
+ self._in_context = False
+ self._delegate = None
+ return False
def __del__(self):
- self._on_deletion()
+ with self._lock:
+ if self._delegate is not None:
+ self._down()
+ self._delegate = None
+
+
+class _StubAssemblyManager(object):
+
+ def __init__(
+ self, thread_pool, thread_pool_size, end_link, grpc_link, stub_creator):
+ self._thread_pool = thread_pool
+ self._pool_size = thread_pool_size
+ self._end_link = end_link
+ self._grpc_link = grpc_link
+ self._stub_creator = stub_creator
+ self._own_pool = None
+
+ def up(self):
+ if self._thread_pool is None:
+ self._own_pool = logging_pool.pool(
+ _DEFAULT_POOL_SIZE if self._pool_size is None else self._pool_size)
+ assembly_pool = self._own_pool
+ else:
+ assembly_pool = self._thread_pool
+ self._end_link.join_link(self._grpc_link)
+ self._grpc_link.join_link(self._end_link)
+ self._end_link.start()
+ self._grpc_link.start()
+ return self._stub_creator(self._end_link, assembly_pool)
+
+ def down(self):
+ self._end_link.stop(0).wait()
+ self._grpc_link.stop()
+ self._end_link.join_link(utilities.NULL_LINK)
+ self._grpc_link.join_link(utilities.NULL_LINK)
+ if self._own_pool is not None:
+ self._own_pool.shutdown(wait=True)
+ self._own_pool = None
def _assemble(
channel, host, metadata_transformer, request_serializers,
- response_deserializers, thread_pool, thread_pool_size):
+ response_deserializers, thread_pool, thread_pool_size, stub_creator):
end_link = _core_implementations.invocation_end_link()
grpc_link = invocation.invocation_link(
channel, host, metadata_transformer, request_serializers,
response_deserializers)
- if thread_pool is None:
- invocation_pool = logging_pool.pool(
- _DEFAULT_POOL_SIZE if thread_pool_size is None else thread_pool_size)
- assembly_pool = invocation_pool
- else:
- invocation_pool = thread_pool
- assembly_pool = None
- end_link.join_link(grpc_link)
- grpc_link.join_link(end_link)
- end_link.start()
- grpc_link.start()
- return end_link, grpc_link, invocation_pool, assembly_pool
+ stub_assembly_manager = _StubAssemblyManager(
+ thread_pool, thread_pool_size, end_link, grpc_link, stub_creator)
+ stub = stub_assembly_manager.up()
+ return _AutoIntermediary(
+ stub_assembly_manager.up, stub_assembly_manager.down, stub)
-def _disassemble(end_link, grpc_link, pool):
- end_link.stop(24 * 60 * 60).wait()
- grpc_link.stop()
- end_link.join_link(utilities.NULL_LINK)
- grpc_link.join_link(utilities.NULL_LINK)
- if pool is not None:
- pool.shutdown(wait=True)
-
-
-def _wrap_assembly(stub, end_link, grpc_link, assembly_pool):
- disassembly_thread = threading.Thread(
- target=_disassemble, args=(end_link, grpc_link, assembly_pool))
- return _AutoIntermediary(stub, disassembly_thread.start)
+def _dynamic_stub_creator(service, cardinalities):
+ def create_dynamic_stub(end_link, invocation_pool):
+ return _crust_implementations.dynamic_stub(
+ end_link, service, cardinalities, invocation_pool)
+ return create_dynamic_stub
def generic_stub(
channel, host, metadata_transformer, request_serializers,
response_deserializers, thread_pool, thread_pool_size):
- end_link, grpc_link, invocation_pool, assembly_pool = _assemble(
+ return _assemble(
channel, host, metadata_transformer, request_serializers,
- response_deserializers, thread_pool, thread_pool_size)
- stub = _crust_implementations.generic_stub(end_link, invocation_pool)
- return _wrap_assembly(stub, end_link, grpc_link, assembly_pool)
+ response_deserializers, thread_pool, thread_pool_size,
+ _crust_implementations.generic_stub)
def dynamic_stub(
channel, host, service, cardinalities, metadata_transformer,
request_serializers, response_deserializers, thread_pool,
thread_pool_size):
- end_link, grpc_link, invocation_pool, assembly_pool = _assemble(
+ return _assemble(
channel, host, metadata_transformer, request_serializers,
- response_deserializers, thread_pool, thread_pool_size)
- stub = _crust_implementations.dynamic_stub(
- end_link, service, cardinalities, invocation_pool)
- return _wrap_assembly(stub, end_link, grpc_link, assembly_pool)
+ response_deserializers, thread_pool, thread_pool_size,
+ _dynamic_stub_creator(service, cardinalities))
diff --git a/src/python/grpcio_test/grpc_test/beta/_beta_features_test.py b/src/python/grpcio_test/grpc_test/beta/_beta_features_test.py
index 5916a9e..6b5090f 100644
--- a/src/python/grpcio_test/grpc_test/beta/_beta_features_test.py
+++ b/src/python/grpcio_test/grpc_test/beta/_beta_features_test.py
@@ -224,5 +224,78 @@
self.assertEqual(_RESPONSE, response)
+class ContextManagementAndLifecycleTest(unittest.TestCase):
+
+ def setUp(self):
+ self._servicer = _Servicer()
+ self._method_implementations = {
+ (_GROUP, _UNARY_UNARY):
+ utilities.unary_unary_inline(self._servicer.unary_unary),
+ (_GROUP, _UNARY_STREAM):
+ utilities.unary_stream_inline(self._servicer.unary_stream),
+ (_GROUP, _STREAM_UNARY):
+ utilities.stream_unary_inline(self._servicer.stream_unary),
+ (_GROUP, _STREAM_STREAM):
+ utilities.stream_stream_inline(self._servicer.stream_stream),
+ }
+
+ self._cardinalities = {
+ _UNARY_UNARY: cardinality.Cardinality.UNARY_UNARY,
+ _UNARY_STREAM: cardinality.Cardinality.UNARY_STREAM,
+ _STREAM_UNARY: cardinality.Cardinality.STREAM_UNARY,
+ _STREAM_STREAM: cardinality.Cardinality.STREAM_STREAM,
+ }
+
+ self._server_options = implementations.server_options(
+ thread_pool_size=test_constants.POOL_SIZE)
+ self._server_credentials = implementations.ssl_server_credentials(
+ [(resources.private_key(), resources.certificate_chain(),),])
+ self._client_credentials = implementations.ssl_client_credentials(
+ resources.test_root_certificates(), None, None)
+ self._stub_options = implementations.stub_options(
+ thread_pool_size=test_constants.POOL_SIZE)
+
+ def test_stub_context(self):
+ server = implementations.server(
+ self._method_implementations, options=self._server_options)
+ port = server.add_secure_port('[::]:0', self._server_credentials)
+ server.start()
+
+ channel = test_utilities.not_really_secure_channel(
+ 'localhost', port, self._client_credentials, _SERVER_HOST_OVERRIDE)
+ dynamic_stub = implementations.dynamic_stub(
+ channel, _GROUP, self._cardinalities, options=self._stub_options)
+ for _ in range(100):
+ with dynamic_stub:
+ pass
+ for _ in range(10):
+ with dynamic_stub:
+ call_options = interfaces.grpc_call_options(
+ disable_compression=True)
+ response = getattr(dynamic_stub, _UNARY_UNARY)(
+ _REQUEST, test_constants.LONG_TIMEOUT,
+ protocol_options=call_options)
+ self.assertEqual(_RESPONSE, response)
+ self.assertIsNotNone(self._servicer.peer())
+
+ server.stop(test_constants.SHORT_TIMEOUT).wait()
+
+ def test_server_lifecycle(self):
+ for _ in range(100):
+ server = implementations.server(
+ self._method_implementations, options=self._server_options)
+ port = server.add_secure_port('[::]:0', self._server_credentials)
+ server.start()
+ server.stop(test_constants.SHORT_TIMEOUT).wait()
+ for _ in range(100):
+ server = implementations.server(
+ self._method_implementations, options=self._server_options)
+ server.add_secure_port('[::]:0', self._server_credentials)
+ server.add_insecure_port('[::]:0')
+ with server:
+ server.stop(test_constants.SHORT_TIMEOUT)
+ server.stop(test_constants.SHORT_TIMEOUT)
+
+
if __name__ == '__main__':
unittest.main(verbosity=2)