Fixes to stub and server lifecycle

Context management is implemented.

Stub deletion now cancels all RPCs immediately.
diff --git a/src/python/grpcio/grpc/beta/_server.py b/src/python/grpcio/grpc/beta/_server.py
index 05b954d..4f45443 100644
--- a/src/python/grpcio/grpc/beta/_server.py
+++ b/src/python/grpcio/grpc/beta/_server.py
@@ -44,6 +44,12 @@
 _MAXIMUM_TIMEOUT = 24 * 60 * 60
 
 
+def _set_event():
+  event = threading.Event()
+  event.set()
+  return event
+
+
 class _GRPCServicer(base.Servicer):
 
   def __init__(self, delegate):
@@ -61,86 +67,143 @@
         raise
 
 
-def _disassemble(grpc_link, end_link, pool, event, grace):
-  grpc_link.begin_stop()
-  end_link.stop(grace).wait()
-  grpc_link.end_stop()
-  grpc_link.join_link(utilities.NULL_LINK)
-  end_link.join_link(utilities.NULL_LINK)
-  if pool is not None:
-    pool.shutdown(wait=True)
-  event.set()
+class _Server(interfaces.Server):
 
-
-class Server(interfaces.Server):
-
-  def __init__(self, grpc_link, end_link, pool):
+  def __init__(
+      self, implementations, multi_implementation, pool, pool_size,
+      default_timeout, maximum_timeout, grpc_link):
+    self._lock = threading.Lock()
+    self._implementations = implementations
+    self._multi_implementation = multi_implementation
+    self._customer_pool = pool
+    self._pool_size = pool_size
+    self._default_timeout = default_timeout
+    self._maximum_timeout = maximum_timeout
     self._grpc_link = grpc_link
-    self._end_link = end_link
-    self._pool = pool
 
-  def add_insecure_port(self, address):
-    return self._grpc_link.add_port(address, None)
-
-  def add_secure_port(self, address, server_credentials):
-    return self._grpc_link.add_port(
-        address, server_credentials._intermediary_low_credentials)  # pylint: disable=protected-access
+    self._end_link = None
+    self._stop_events = None
+    self._pool = None
 
   def _start(self):
-    self._grpc_link.join_link(self._end_link)
-    self._end_link.join_link(self._grpc_link)
-    self._grpc_link.start()
-    self._end_link.start()
+    with self._lock:
+      if self._end_link is not None:
+        raise ValueError('Cannot start already-started server!')
 
-  def _stop(self, grace):
-    stop_event = threading.Event()
-    if 0 < grace:
-      disassembly_thread = threading.Thread(
-          target=_disassemble,
-          args=(
-              self._grpc_link, self._end_link, self._pool, stop_event, grace,))
-      disassembly_thread.start()
-      return stop_event
-    else:
-      _disassemble(self._grpc_link, self._end_link, self._pool, stop_event, 0)
-      return stop_event
+      if self._customer_pool is None:
+        self._pool = logging_pool.pool(self._pool_size)
+        assembly_pool = self._pool
+      else:
+        assembly_pool = self._customer_pool
+
+      servicer = _GRPCServicer(
+          _crust_implementations.servicer(
+              self._implementations, self._multi_implementation, assembly_pool))
+
+      self._end_link = _core_implementations.service_end_link(
+          servicer, self._default_timeout, self._maximum_timeout)
+
+      self._grpc_link.join_link(self._end_link)
+      self._end_link.join_link(self._grpc_link)
+      self._grpc_link.start()
+      self._end_link.start()
+
+  def _dissociate_links_and_shut_down_pool(self):
+    self._grpc_link.end_stop()
+    self._grpc_link.join_link(utilities.NULL_LINK)
+    self._end_link.join_link(utilities.NULL_LINK)
+    self._end_link = None
+    if self._pool is not None:
+      self._pool.shutdown(wait=True)
+    self._pool = None
+
+  def _stop_stopping(self):
+    self._dissociate_links_and_shut_down_pool()
+    for stop_event in self._stop_events:
+      stop_event.set()
+    self._stop_events = None
+
+  def _stop_started(self):
+    self._grpc_link.begin_stop()
+    self._end_link.stop(0).wait()
+    self._dissociate_links_and_shut_down_pool()
+
+  def _foreign_thread_stop(self, end_stop_event, stop_events):
+    end_stop_event.wait()
+    with self._lock:
+      if self._stop_events is stop_events:
+        self._stop_stopping()
+
+  def _schedule_stop(self, grace):
+    with self._lock:
+      if self._end_link is None:
+        return _set_event()
+      server_stop_event = threading.Event()
+      if self._stop_events is None:
+        self._stop_events = [server_stop_event]
+        self._grpc_link.begin_stop()
+      else:
+        self._stop_events.append(server_stop_event)
+      end_stop_event = self._end_link.stop(grace)
+      end_stop_thread = threading.Thread(
+          target=self._foreign_thread_stop,
+          args=(end_stop_event, self._stop_events))
+      end_stop_thread.start()
+      return server_stop_event
+
+  def _stop_now(self):
+    with self._lock:
+      if self._end_link is not None:
+        if self._stop_events is None:
+          self._stop_started()
+        else:
+          self._stop_stopping()
+
+  def add_insecure_port(self, address):
+    with self._lock:
+      if self._end_link is None:
+        return self._grpc_link.add_port(address, None)
+      else:
+        raise ValueError('Can\'t add port to serving server!')
+
+  def add_secure_port(self, address, server_credentials):
+    with self._lock:
+      if self._end_link is None:
+        return self._grpc_link.add_port(
+            address, server_credentials._intermediary_low_credentials)  # pylint: disable=protected-access
+      else:
+        raise ValueError('Can\'t add port to serving server!')
 
   def start(self):
     self._start()
 
   def stop(self, grace):
-    return self._stop(grace)
+    if 0 < grace:
+      return self._schedule_stop(grace)
+    else:
+      self._stop_now()
+      return _set_event()
 
   def __enter__(self):
     self._start()
     return self
 
   def __exit__(self, exc_type, exc_val, exc_tb):
-    self._stop(0).wait()
+    self._stop_now()
     return False
 
+  def __del__(self):
+    self._stop_now()
+
 
 def server(
     implementations, multi_implementation, request_deserializers,
     response_serializers, thread_pool, thread_pool_size, default_timeout,
     maximum_timeout):
-  if thread_pool is None:
-    service_thread_pool = logging_pool.pool(
-        _DEFAULT_POOL_SIZE if thread_pool_size is None else thread_pool_size)
-    assembly_thread_pool = service_thread_pool
-  else:
-    service_thread_pool = thread_pool
-    assembly_thread_pool = None
-
-  servicer = _GRPCServicer(
-      _crust_implementations.servicer(
-          implementations, multi_implementation, service_thread_pool))
-
   grpc_link = service.service_link(request_deserializers, response_serializers)
-
-  end_link = _core_implementations.service_end_link(
-      servicer,
+  return _Server(
+      implementations, multi_implementation, thread_pool,
+      _DEFAULT_POOL_SIZE if thread_pool_size is None else thread_pool_size,
       _DEFAULT_TIMEOUT if default_timeout is None else default_timeout,
-      _MAXIMUM_TIMEOUT if maximum_timeout is None else maximum_timeout)
-
-  return Server(grpc_link, end_link, assembly_thread_pool)
+      _MAXIMUM_TIMEOUT if maximum_timeout is None else maximum_timeout,
+      grpc_link)
diff --git a/src/python/grpcio/grpc/beta/_stub.py b/src/python/grpcio/grpc/beta/_stub.py
index 11dab88..2af0193 100644
--- a/src/python/grpcio/grpc/beta/_stub.py
+++ b/src/python/grpcio/grpc/beta/_stub.py
@@ -42,76 +42,114 @@
 
 class _AutoIntermediary(object):
 
-  def __init__(self, delegate, on_deletion):
+  def __init__(self, up, down, delegate):
+    self._lock = threading.Lock()
+    self._up = up
+    self._down = down
+    self._in_context = False
     self._delegate = delegate
-    self._on_deletion = on_deletion
 
   def __getattr__(self, attr):
-    return getattr(self._delegate, attr)
+    with self._lock:
+      if self._delegate is None:
+        raise AttributeError('No useful attributes out of context!')
+      else:
+        return getattr(self._delegate, attr)
 
   def __enter__(self):
-    return self
+    with self._lock:
+      if self._in_context:
+        raise ValueError('Already in context!')
+      elif self._delegate is None:
+        self._delegate = self._up()
+      self._in_context = True
+      return self
 
   def __exit__(self, exc_type, exc_val, exc_tb):
-    return False
+    with self._lock:
+      if not self._in_context:
+        raise ValueError('Not in context!')
+      self._down()
+      self._in_context = False
+      self._delegate = None
+      return False
 
   def __del__(self):
-    self._on_deletion()
+    with self._lock:
+      if self._delegate is not None:
+        self._down()
+        self._delegate = None
+
+
+class _StubAssemblyManager(object):
+
+  def __init__(
+      self, thread_pool, thread_pool_size, end_link, grpc_link, stub_creator):
+    self._thread_pool = thread_pool
+    self._pool_size = thread_pool_size
+    self._end_link = end_link
+    self._grpc_link = grpc_link
+    self._stub_creator = stub_creator
+    self._own_pool = None
+
+  def up(self):
+    if self._thread_pool is None:
+      self._own_pool = logging_pool.pool(
+          _DEFAULT_POOL_SIZE if self._pool_size is None else self._pool_size)
+      assembly_pool = self._own_pool
+    else:
+      assembly_pool = self._thread_pool
+    self._end_link.join_link(self._grpc_link)
+    self._grpc_link.join_link(self._end_link)
+    self._end_link.start()
+    self._grpc_link.start()
+    return self._stub_creator(self._end_link, assembly_pool)
+
+  def down(self):
+    self._end_link.stop(0).wait()
+    self._grpc_link.stop()
+    self._end_link.join_link(utilities.NULL_LINK)
+    self._grpc_link.join_link(utilities.NULL_LINK)
+    if self._own_pool is not None:
+      self._own_pool.shutdown(wait=True)
+      self._own_pool = None
 
 
 def _assemble(
     channel, host, metadata_transformer, request_serializers,
-    response_deserializers, thread_pool, thread_pool_size):
+    response_deserializers, thread_pool, thread_pool_size, stub_creator):
   end_link = _core_implementations.invocation_end_link()
   grpc_link = invocation.invocation_link(
       channel, host, metadata_transformer, request_serializers,
       response_deserializers)
-  if thread_pool is None:
-    invocation_pool = logging_pool.pool(
-        _DEFAULT_POOL_SIZE if thread_pool_size is None else thread_pool_size)
-    assembly_pool = invocation_pool
-  else:
-    invocation_pool = thread_pool
-    assembly_pool = None
-  end_link.join_link(grpc_link)
-  grpc_link.join_link(end_link)
-  end_link.start()
-  grpc_link.start()
-  return end_link, grpc_link, invocation_pool, assembly_pool
+  stub_assembly_manager = _StubAssemblyManager(
+      thread_pool, thread_pool_size, end_link, grpc_link, stub_creator)
+  stub = stub_assembly_manager.up()
+  return _AutoIntermediary(
+      stub_assembly_manager.up, stub_assembly_manager.down, stub)
 
 
-def _disassemble(end_link, grpc_link, pool):
-  end_link.stop(24 * 60 * 60).wait()
-  grpc_link.stop()
-  end_link.join_link(utilities.NULL_LINK)
-  grpc_link.join_link(utilities.NULL_LINK)
-  if pool is not None:
-    pool.shutdown(wait=True)
-
-
-def _wrap_assembly(stub, end_link, grpc_link, assembly_pool):
-  disassembly_thread = threading.Thread(
-      target=_disassemble, args=(end_link, grpc_link, assembly_pool))
-  return _AutoIntermediary(stub, disassembly_thread.start)
+def _dynamic_stub_creator(service, cardinalities):
+  def create_dynamic_stub(end_link, invocation_pool):
+    return _crust_implementations.dynamic_stub(
+        end_link, service, cardinalities, invocation_pool)
+  return create_dynamic_stub
 
 
 def generic_stub(
     channel, host, metadata_transformer, request_serializers,
     response_deserializers, thread_pool, thread_pool_size):
-  end_link, grpc_link, invocation_pool, assembly_pool = _assemble(
+  return _assemble(
       channel, host, metadata_transformer, request_serializers,
-      response_deserializers, thread_pool, thread_pool_size)
-  stub = _crust_implementations.generic_stub(end_link, invocation_pool)
-  return _wrap_assembly(stub, end_link, grpc_link, assembly_pool)
+      response_deserializers, thread_pool, thread_pool_size,
+      _crust_implementations.generic_stub)
 
 
 def dynamic_stub(
     channel, host, service, cardinalities, metadata_transformer,
     request_serializers, response_deserializers, thread_pool,
     thread_pool_size):
-  end_link, grpc_link, invocation_pool, assembly_pool = _assemble(
+  return _assemble(
       channel, host, metadata_transformer, request_serializers,
-      response_deserializers, thread_pool, thread_pool_size)
-  stub = _crust_implementations.dynamic_stub(
-      end_link, service, cardinalities, invocation_pool)
-  return _wrap_assembly(stub, end_link, grpc_link, assembly_pool)
+      response_deserializers, thread_pool, thread_pool_size,
+      _dynamic_stub_creator(service, cardinalities))
diff --git a/src/python/grpcio_test/grpc_test/beta/_beta_features_test.py b/src/python/grpcio_test/grpc_test/beta/_beta_features_test.py
index 5916a9e..6b5090f 100644
--- a/src/python/grpcio_test/grpc_test/beta/_beta_features_test.py
+++ b/src/python/grpcio_test/grpc_test/beta/_beta_features_test.py
@@ -224,5 +224,78 @@
     self.assertEqual(_RESPONSE, response)
 
 
+class ContextManagementAndLifecycleTest(unittest.TestCase):
+
+  def setUp(self):
+    self._servicer = _Servicer()
+    self._method_implementations = {
+        (_GROUP, _UNARY_UNARY):
+            utilities.unary_unary_inline(self._servicer.unary_unary),
+        (_GROUP, _UNARY_STREAM):
+            utilities.unary_stream_inline(self._servicer.unary_stream),
+        (_GROUP, _STREAM_UNARY):
+            utilities.stream_unary_inline(self._servicer.stream_unary),
+        (_GROUP, _STREAM_STREAM):
+            utilities.stream_stream_inline(self._servicer.stream_stream),
+    }
+
+    self._cardinalities = {
+        _UNARY_UNARY: cardinality.Cardinality.UNARY_UNARY,
+        _UNARY_STREAM: cardinality.Cardinality.UNARY_STREAM,
+        _STREAM_UNARY: cardinality.Cardinality.STREAM_UNARY,
+        _STREAM_STREAM: cardinality.Cardinality.STREAM_STREAM,
+    }
+
+    self._server_options = implementations.server_options(
+        thread_pool_size=test_constants.POOL_SIZE)
+    self._server_credentials = implementations.ssl_server_credentials(
+        [(resources.private_key(), resources.certificate_chain(),),])
+    self._client_credentials = implementations.ssl_client_credentials(
+        resources.test_root_certificates(), None, None)
+    self._stub_options = implementations.stub_options(
+        thread_pool_size=test_constants.POOL_SIZE)
+
+  def test_stub_context(self):
+    server = implementations.server(
+        self._method_implementations, options=self._server_options)
+    port = server.add_secure_port('[::]:0', self._server_credentials)
+    server.start()
+
+    channel = test_utilities.not_really_secure_channel(
+        'localhost', port, self._client_credentials, _SERVER_HOST_OVERRIDE)
+    dynamic_stub = implementations.dynamic_stub(
+        channel, _GROUP, self._cardinalities, options=self._stub_options)
+    for _ in range(100):
+      with dynamic_stub:
+        pass
+    for _ in range(10):
+      with dynamic_stub:
+        call_options = interfaces.grpc_call_options(
+            disable_compression=True)
+        response = getattr(dynamic_stub, _UNARY_UNARY)(
+            _REQUEST, test_constants.LONG_TIMEOUT,
+            protocol_options=call_options)
+        self.assertEqual(_RESPONSE, response)
+        self.assertIsNotNone(self._servicer.peer())
+
+    server.stop(test_constants.SHORT_TIMEOUT).wait()
+
+  def test_server_lifecycle(self):
+    for _ in range(100):
+      server = implementations.server(
+          self._method_implementations, options=self._server_options)
+      port = server.add_secure_port('[::]:0', self._server_credentials)
+      server.start()
+      server.stop(test_constants.SHORT_TIMEOUT).wait()
+    for _ in range(100):
+      server = implementations.server(
+          self._method_implementations, options=self._server_options)
+      server.add_secure_port('[::]:0', self._server_credentials)
+      server.add_insecure_port('[::]:0')
+      with server:
+        server.stop(test_constants.SHORT_TIMEOUT)
+      server.stop(test_constants.SHORT_TIMEOUT)
+
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)