Merge pull request #2969 from soltanmm/gravity-well

Add cancel_all_calls to Python server.
diff --git a/src/python/grpcio/grpc/_adapter/_c/types.h b/src/python/grpcio/grpc/_adapter/_c/types.h
index f646465..f6ff957 100644
--- a/src/python/grpcio/grpc/_adapter/_c/types.h
+++ b/src/python/grpcio/grpc/_adapter/_c/types.h
@@ -146,6 +146,7 @@
   PyObject_HEAD
   grpc_server *c_serv;
   CompletionQueue *cq;
+  int shutdown_called;  /* set by shutdown(), cleared by new/start */
 } Server;
 Server *pygrpc_Server_new(PyTypeObject *type, PyObject *args, PyObject *kwargs);
 void pygrpc_Server_dealloc(Server *self);
@@ -156,6 +157,7 @@
 PyObject *pygrpc_Server_start(Server *self, PyObject *ignored);
 PyObject *pygrpc_Server_shutdown(
     Server *self, PyObject *args, PyObject *kwargs);
+PyObject *pygrpc_Server_cancel_all_calls(Server *self, PyObject *unused);
 extern PyTypeObject pygrpc_Server_type;
 
 /*=========*/
diff --git a/src/python/grpcio/grpc/_adapter/_c/types/server.c b/src/python/grpcio/grpc/_adapter/_c/types/server.c
index 15c98f2..8feab8a 100644
--- a/src/python/grpcio/grpc/_adapter/_c/types/server.c
+++ b/src/python/grpcio/grpc/_adapter/_c/types/server.c
@@ -45,6 +45,8 @@
      METH_KEYWORDS, ""},
     {"start", (PyCFunction)pygrpc_Server_start, METH_NOARGS, ""},
     {"shutdown", (PyCFunction)pygrpc_Server_shutdown, METH_KEYWORDS, ""},
+    {"cancel_all_calls", (PyCFunction)pygrpc_Server_cancel_all_calls,
+     METH_NOARGS, ""},
     {NULL}
 };
 const char pygrpc_Server_doc[] = "See grpc._adapter._types.Server.";
@@ -109,6 +111,7 @@
   pygrpc_discard_channel_args(c_args);
   self->cq = cq;
   Py_INCREF(self->cq);
+  self->shutdown_called = 0;
   return self;
 }
 
@@ -163,6 +166,7 @@
 
 PyObject *pygrpc_Server_start(Server *self, PyObject *ignored) {
   grpc_server_start(self->c_serv);
+  self->shutdown_called = 0;
   Py_RETURN_NONE;
 }
 
@@ -176,5 +180,17 @@
   }
   tag = pygrpc_produce_server_shutdown_tag(user_tag);
   grpc_server_shutdown_and_notify(self->c_serv, self->cq->c_cq, tag);
+  self->shutdown_called = 1;  /* permits cancel_all_calls() from now on */
+  Py_RETURN_NONE;
+}
+
+PyObject *pygrpc_Server_cancel_all_calls(Server *self, PyObject *unused) {
+  if (!self->shutdown_called) {  /* only allowed after shutdown() */
+    PyErr_SetString(
+        PyExc_RuntimeError,
+        "shutdown must have been called prior to calling cancel_all_calls!");
+    return NULL;
+  }
+  grpc_server_cancel_all_calls(self->c_serv);
   Py_RETURN_NONE;
 }
diff --git a/src/python/grpcio/grpc/_adapter/_low.py b/src/python/grpcio/grpc/_adapter/_low.py
index 147086e..3859ebb 100644
--- a/src/python/grpcio/grpc/_adapter/_low.py
+++ b/src/python/grpcio/grpc/_adapter/_low.py
@@ -124,3 +124,6 @@
 
   def request_call(self, completion_queue, tag):
     return self.server.request_call(completion_queue.completion_queue, tag)
+
+  def cancel_all_calls(self):  # must be called after shutdown()
+    return self.server.cancel_all_calls()
diff --git a/src/python/grpcio_test/grpc_test/_adapter/_low_test.py b/src/python/grpcio_test/grpc_test/_adapter/_low_test.py
index 44fe760..7014912 100644
--- a/src/python/grpcio_test/grpc_test/_adapter/_low_test.py
+++ b/src/python/grpcio_test/grpc_test/_adapter/_low_test.py
@@ -52,7 +52,6 @@
   def set_ith_result(i, completion_queue):
     result = completion_queue.next(deadline)
     with lock:
-      print i, completion_queue, result, time.time() - deadline
       results[i] = result
   for i, completion_queue in enumerate(completion_queues):
     thread = threading.Thread(target=set_ith_result,
@@ -80,10 +79,12 @@
     del self.client_channel
 
     self.client_completion_queue.shutdown()
-    while self.client_completion_queue.next().type != _types.EventType.QUEUE_SHUTDOWN:
+    while (self.client_completion_queue.next().type !=
+           _types.EventType.QUEUE_SHUTDOWN):
       pass
     self.server_completion_queue.shutdown()
-    while self.server_completion_queue.next().type != _types.EventType.QUEUE_SHUTDOWN:
+    while (self.server_completion_queue.next().type !=
+           _types.EventType.QUEUE_SHUTDOWN):
       pass
 
     del self.client_completion_queue
@@ -91,58 +92,68 @@
     del self.server
 
   def testEcho(self):
-    DEADLINE = time.time()+5
-    DEADLINE_TOLERANCE = 0.25
-    CLIENT_METADATA_ASCII_KEY = 'key'
-    CLIENT_METADATA_ASCII_VALUE = 'val'
-    CLIENT_METADATA_BIN_KEY = 'key-bin'
-    CLIENT_METADATA_BIN_VALUE = b'\0'*1000
-    SERVER_INITIAL_METADATA_KEY = 'init_me_me_me'
-    SERVER_INITIAL_METADATA_VALUE = 'whodawha?'
-    SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought'
-    SERVER_TRAILING_METADATA_VALUE = 'zomg it is'
-    SERVER_STATUS_CODE = _types.StatusCode.OK
-    SERVER_STATUS_DETAILS = 'our work is never over'
-    REQUEST = 'in death a member of project mayhem has a name'
-    RESPONSE = 'his name is robert paulson'
-    METHOD = 'twinkies'
-    HOST = 'hostess'
+    deadline = time.time() + 5
+    event_time_tolerance = 2  # seconds to wait for each batch of events
+    deadline_tolerance = 0.25  # max skew of server-side call deadline
+    client_metadata_ascii_key = 'key'
+    client_metadata_ascii_value = 'val'
+    client_metadata_bin_key = 'key-bin'
+    client_metadata_bin_value = b'\0'*1000
+    server_initial_metadata_key = 'init_me_me_me'
+    server_initial_metadata_value = 'whodawha?'
+    server_trailing_metadata_key = 'california_is_in_a_drought'
+    server_trailing_metadata_value = 'zomg it is'
+    server_status_code = _types.StatusCode.OK
+    server_status_details = 'our work is never over'
+    request = 'blarghaflargh'
+    response = 'his name is robert paulson'
+    method = 'twinkies'
+    host = 'hostess'
     server_request_tag = object()
-    request_call_result = self.server.request_call(self.server_completion_queue, server_request_tag)
+    request_call_result = self.server.request_call(self.server_completion_queue,
+                                                   server_request_tag)
 
-    self.assertEquals(_types.CallError.OK, request_call_result)
+    self.assertEqual(_types.CallError.OK, request_call_result)
 
     client_call_tag = object()
-    client_call = self.client_channel.create_call(self.client_completion_queue, METHOD, HOST, DEADLINE)
-    client_initial_metadata = [(CLIENT_METADATA_ASCII_KEY, CLIENT_METADATA_ASCII_VALUE), (CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)]
+    client_call = self.client_channel.create_call(
+        self.client_completion_queue, method, host, deadline)
+    client_initial_metadata = [
+        (client_metadata_ascii_key, client_metadata_ascii_value),
+        (client_metadata_bin_key, client_metadata_bin_value)
+    ]
     client_start_batch_result = client_call.start_batch([
         _types.OpArgs.send_initial_metadata(client_initial_metadata),
-        _types.OpArgs.send_message(REQUEST, 0),
+        _types.OpArgs.send_message(request, 0),
         _types.OpArgs.send_close_from_client(),
         _types.OpArgs.recv_initial_metadata(),
         _types.OpArgs.recv_message(),
         _types.OpArgs.recv_status_on_client()
     ], client_call_tag)
-    self.assertEquals(_types.CallError.OK, client_start_batch_result)
+    self.assertEqual(_types.CallError.OK, client_start_batch_result)
 
-    client_no_event, request_event, = wait_for_events([self.client_completion_queue, self.server_completion_queue], time.time() + 2)
-    self.assertEquals(client_no_event, None)
-    self.assertEquals(_types.EventType.OP_COMPLETE, request_event.type)
+    client_no_event, request_event, = wait_for_events(
+        [self.client_completion_queue, self.server_completion_queue],
+        time.time() + event_time_tolerance)
+    self.assertEqual(client_no_event, None)
+    self.assertEqual(_types.EventType.OP_COMPLETE, request_event.type)
     self.assertIsInstance(request_event.call, _low.Call)
     self.assertIs(server_request_tag, request_event.tag)
-    self.assertEquals(1, len(request_event.results))
+    self.assertEqual(1, len(request_event.results))
     received_initial_metadata = dict(request_event.results[0].initial_metadata)
     # Check that our metadata were transmitted
-    self.assertEquals(
+    self.assertEqual(
         dict(client_initial_metadata),
-        dict((x, received_initial_metadata[x]) for x in zip(*client_initial_metadata)[0]))
+        dict((x, received_initial_metadata[x])
+             for x in zip(*client_initial_metadata)[0]))
     # Check that Python's user agent string is a part of the full user agent
     # string
     self.assertIn('Python-gRPC-{}'.format(_grpcio_metadata.__version__),
                   received_initial_metadata['user-agent'])
-    self.assertEquals(METHOD, request_event.call_details.method)
-    self.assertEquals(HOST, request_event.call_details.host)
-    self.assertLess(abs(DEADLINE - request_event.call_details.deadline), DEADLINE_TOLERANCE)
+    self.assertEqual(method, request_event.call_details.method)
+    self.assertEqual(host, request_event.call_details.host)
+    self.assertLess(abs(deadline - request_event.call_details.deadline),
+                    deadline_tolerance)
 
     # Check that the channel is connected, and that both it and the call have
     # the proper target and peer; do this after the first flurry of messages to
@@ -155,33 +166,43 @@
 
     server_call_tag = object()
     server_call = request_event.call
-    server_initial_metadata = [(SERVER_INITIAL_METADATA_KEY, SERVER_INITIAL_METADATA_VALUE)]
-    server_trailing_metadata = [(SERVER_TRAILING_METADATA_KEY, SERVER_TRAILING_METADATA_VALUE)]
+    server_initial_metadata = [
+        (server_initial_metadata_key, server_initial_metadata_value)
+    ]
+    server_trailing_metadata = [
+        (server_trailing_metadata_key, server_trailing_metadata_value)
+    ]
     server_start_batch_result = server_call.start_batch([
         _types.OpArgs.send_initial_metadata(server_initial_metadata),
         _types.OpArgs.recv_message(),
-        _types.OpArgs.send_message(RESPONSE, 0),
+        _types.OpArgs.send_message(response, 0),
         _types.OpArgs.recv_close_on_server(),
-        _types.OpArgs.send_status_from_server(server_trailing_metadata, SERVER_STATUS_CODE, SERVER_STATUS_DETAILS)
+        _types.OpArgs.send_status_from_server(
+            server_trailing_metadata, server_status_code, server_status_details)
     ], server_call_tag)
-    self.assertEquals(_types.CallError.OK, server_start_batch_result)
+    self.assertEqual(_types.CallError.OK, server_start_batch_result)
 
-    client_event, server_event, = wait_for_events([self.client_completion_queue, self.server_completion_queue], time.time() + 1)
+    client_event, server_event, = wait_for_events(
+        [self.client_completion_queue, self.server_completion_queue],
+        time.time() + event_time_tolerance)
 
-    self.assertEquals(6, len(client_event.results))
+    self.assertEqual(6, len(client_event.results))  # one per client op
     found_client_op_types = set()
     for client_result in client_event.results:
-      self.assertNotIn(client_result.type, found_client_op_types)  # we expect each op type to be unique
+      # we expect each op type to be unique
+      self.assertNotIn(client_result.type, found_client_op_types)
       found_client_op_types.add(client_result.type)
       if client_result.type == _types.OpType.RECV_INITIAL_METADATA:
-        self.assertEquals(dict(server_initial_metadata), dict(client_result.initial_metadata))
+        self.assertEqual(dict(server_initial_metadata),
+                         dict(client_result.initial_metadata))
       elif client_result.type == _types.OpType.RECV_MESSAGE:
-        self.assertEquals(RESPONSE, client_result.message)
+        self.assertEqual(response, client_result.message)
       elif client_result.type == _types.OpType.RECV_STATUS_ON_CLIENT:
-        self.assertEquals(dict(server_trailing_metadata), dict(client_result.trailing_metadata))
-        self.assertEquals(SERVER_STATUS_DETAILS, client_result.status.details)
-        self.assertEquals(SERVER_STATUS_CODE, client_result.status.code)
-    self.assertEquals(set([
+        self.assertEqual(dict(server_trailing_metadata),
+                         dict(client_result.trailing_metadata))
+        self.assertEqual(server_status_details, client_result.status.details)
+        self.assertEqual(server_status_code, client_result.status.code)
+    self.assertEqual(set([
           _types.OpType.SEND_INITIAL_METADATA,
           _types.OpType.SEND_MESSAGE,
           _types.OpType.SEND_CLOSE_FROM_CLIENT,
@@ -190,16 +211,16 @@
           _types.OpType.RECV_STATUS_ON_CLIENT
       ]), found_client_op_types)
 
-    self.assertEquals(5, len(server_event.results))
+    self.assertEqual(5, len(server_event.results))
     found_server_op_types = set()
     for server_result in server_event.results:
-      self.assertNotIn(client_result.type, found_server_op_types)
+      self.assertNotIn(server_result.type, found_server_op_types)
       found_server_op_types.add(server_result.type)
       if server_result.type == _types.OpType.RECV_MESSAGE:
-        self.assertEquals(REQUEST, server_result.message)
+        self.assertEqual(request, server_result.message)
       elif server_result.type == _types.OpType.RECV_CLOSE_ON_SERVER:
         self.assertFalse(server_result.cancelled)
-    self.assertEquals(set([
+    self.assertEqual(set([
           _types.OpType.SEND_INITIAL_METADATA,
           _types.OpType.RECV_MESSAGE,
           _types.OpType.SEND_MESSAGE,
@@ -211,5 +232,88 @@
     del server_call
 
 
+class HangingServerShutdown(unittest.TestCase):
+  """Checks that cancel_all_calls lets shutdown finish despite a hung call."""
+
+  def setUp(self):
+    self.server_completion_queue = _low.CompletionQueue()
+    self.server = _low.Server(self.server_completion_queue, [])
+    self.port = self.server.add_http2_port('[::]:0')
+    self.client_completion_queue = _low.CompletionQueue()
+    self.client_channel = _low.Channel('localhost:%d'%self.port, [])
+
+    self.server.start()
+
+  def tearDown(self):
+    self.server.shutdown()
+    del self.client_channel
+
+    self.client_completion_queue.shutdown()
+    self.server_completion_queue.shutdown()
+    while True:
+      client_event, server_event = wait_for_events(
+          [self.client_completion_queue, self.server_completion_queue],
+          float("+inf"))
+      if (client_event.type == _types.EventType.QUEUE_SHUTDOWN and
+          server_event.type == _types.EventType.QUEUE_SHUTDOWN):
+        break
+
+    del self.client_completion_queue
+    del self.server_completion_queue
+    del self.server
+
+  def testHangingServerCall(self):
+    deadline = time.time() + 5
+    event_time_tolerance = 2
+    cancel_all_calls_time_tolerance = 0.5
+    request = 'blarghaflargh'
+    method = 'twinkies'
+    host = 'hostess'
+    server_request_tag = object()
+    request_call_result = self.server.request_call(self.server_completion_queue,
+                                                   server_request_tag)
+    self.assertEqual(_types.CallError.OK, request_call_result)
+
+    client_call_tag = object()
+    client_call = self.client_channel.create_call(self.client_completion_queue,
+                                                  method, host, deadline)
+    client_start_batch_result = client_call.start_batch([
+        _types.OpArgs.send_initial_metadata([]),
+        _types.OpArgs.send_message(request, 0),
+        _types.OpArgs.send_close_from_client(),
+        _types.OpArgs.recv_initial_metadata(),
+        _types.OpArgs.recv_message(),
+        _types.OpArgs.recv_status_on_client()
+    ], client_call_tag)
+    self.assertEqual(_types.CallError.OK, client_start_batch_result)
+
+    client_no_event, request_event, = wait_for_events(
+        [self.client_completion_queue, self.server_completion_queue],
+        time.time() + event_time_tolerance)
+    # The server has received the call but never answers it, so the client
+    # batch is still pending.
+    self.assertIsNone(client_no_event)
+    self.assertEqual(_types.EventType.OP_COMPLETE, request_event.type)
+    self.assertIs(server_request_tag, request_event.tag)
+
+    # cancel_all_calls must be rejected until shutdown has been requested.
+    with self.assertRaises(RuntimeError):
+      self.server.cancel_all_calls()
+    # Now shut the server down and expect to see the shutdown event almost
+    # immediately after calling cancel_all_calls, despite the hanging call.
+    shutdown_tag = object()
+    self.server.shutdown(shutdown_tag)
+    pre_cancel_timestamp = time.time()
+    self.server.cancel_all_calls()
+    client_call_event, server_shutdown_event = wait_for_events(
+        [self.client_completion_queue, self.server_completion_queue],
+        time.time() + event_time_tolerance)
+    self.assertIs(shutdown_tag, server_shutdown_event.tag)
+    self.assertGreater(pre_cancel_timestamp + cancel_all_calls_time_tolerance,
+                       time.time())
+
+    del client_call
+
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)