Removed cython client-side call tracking
This ensures sync calls get cancelled after
a keyboard interrupt, as well as all calls
getting destroyed before grpc_shutdown()
diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py
index a89b501..29dbc3a 100644
--- a/src/python/grpcio/grpc/_channel.py
+++ b/src/python/grpcio/grpc/_channel.py
@@ -195,7 +195,8 @@
cygrpc.operation_send_message(
serialized_request, _EMPTY_FLAGS),
)
- call.start_batch(cygrpc.Operations(operations), event_handler)
+ call.start_client_batch(cygrpc.Operations(operations),
+ event_handler)
state.due.add(cygrpc.OperationType.send_message)
while True:
state.condition.wait()
@@ -211,7 +212,7 @@
operations = (
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
)
- call.start_batch(cygrpc.Operations(operations), event_handler)
+ call.start_client_batch(cygrpc.Operations(operations), event_handler)
state.due.add(cygrpc.OperationType.send_close_from_client)
def stop_consumption_thread(timeout):
@@ -312,7 +313,7 @@
if self._state.code is None:
event_handler = _event_handler(
self._state, self._call, self._response_deserializer)
- self._call.start_batch(
+ self._call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
event_handler)
@@ -471,7 +472,7 @@
None, 0, completion_queue, self._method, None, deadline_timespec)
if credentials is not None:
call.set_credentials(credentials._credentials)
- call.start_batch(cygrpc.Operations(operations), None)
+ call.start_client_batch(cygrpc.Operations(operations), None)
_handle_event(completion_queue.poll(), state, self._response_deserializer)
return state, deadline
@@ -495,7 +496,7 @@
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
- call.start_batch(cygrpc.Operations(operations), event_handler)
+ call.start_client_batch(cygrpc.Operations(operations), event_handler)
return _Rendezvous(state, call, self._response_deserializer, deadline)
@@ -523,7 +524,7 @@
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
- call.start_batch(
+ call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
event_handler)
@@ -534,7 +535,7 @@
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
- call.start_batch(cygrpc.Operations(operations), event_handler)
+ call.start_client_batch(cygrpc.Operations(operations), event_handler)
return _Rendezvous(state, call, self._response_deserializer, deadline)
@@ -558,7 +559,7 @@
if credentials is not None:
call.set_credentials(credentials._credentials)
with state.condition:
- call.start_batch(
+ call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
None)
@@ -568,7 +569,7 @@
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
- call.start_batch(cygrpc.Operations(operations), None)
+ call.start_client_batch(cygrpc.Operations(operations), None)
_consume_request_iterator(
request_iterator, state, call, self._request_serializer)
while True:
@@ -602,7 +603,7 @@
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
- call.start_batch(
+ call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
event_handler)
@@ -612,7 +613,7 @@
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
- call.start_batch(cygrpc.Operations(operations), event_handler)
+ call.start_client_batch(cygrpc.Operations(operations), event_handler)
_consume_request_iterator(
request_iterator, state, call, self._request_serializer)
return _Rendezvous(state, call, self._response_deserializer, deadline)
@@ -639,7 +640,7 @@
call.set_credentials(credentials._credentials)
event_handler = _event_handler(state, call, self._response_deserializer)
with state.condition:
- call.start_batch(
+ call.start_client_batch(
cygrpc.Operations(
(cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),)),
event_handler)
@@ -648,7 +649,7 @@
_common.cygrpc_metadata(metadata), _EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
- call.start_batch(cygrpc.Operations(operations), event_handler)
+ call.start_client_batch(cygrpc.Operations(operations), event_handler)
_consume_request_iterator(
request_iterator, state, call, self._request_serializer)
return _Rendezvous(state, call, self._response_deserializer, deadline)
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi
index a09bbc7..6570dcd 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/call.pyx.pxi
@@ -37,13 +37,16 @@
self.c_call = NULL
self.references = []
- def start_batch(self, operations, tag):
+ def _start_batch(self, operations, tag, retain_self):
if not self.is_valid:
raise ValueError("invalid call object cannot be used from Python")
cdef grpc_call_error result
cdef Operations cy_operations = Operations(operations)
cdef OperationTag operation_tag = OperationTag(tag)
- operation_tag.operation_call = self
+ if retain_self:
+ operation_tag.operation_call = self
+ else:
+ operation_tag.operation_call = None
operation_tag.batch_operations = cy_operations
cpython.Py_INCREF(operation_tag)
with nogil:
@@ -52,6 +55,14 @@
<cpython.PyObject *>operation_tag, NULL)
return result
+ def start_client_batch(self, operations, tag):
+ # We don't reference this call in the operations tag because
+ # it should be cancelled when it goes out of scope
+ return self._start_batch(operations, tag, False)
+
+ def start_server_batch(self, operations, tag):
+ return self._start_batch(operations, tag, True)
+
def cancel(
self, grpc_status_code error_code=GRPC_STATUS__DO_NOT_USE,
details=None):
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi
index 0474697..96c5b02 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/records.pxd.pxi
@@ -58,14 +58,14 @@
cdef readonly bint success
cdef readonly object tag
- # For operations with calls
- cdef readonly Call operation_call
-
# For Server.request_call
cdef readonly bint is_new_request
cdef readonly CallDetails request_call_details
cdef readonly Metadata request_metadata
+ # For server calls
+ cdef readonly Call operation_call
+
# For Call.start_batch
cdef readonly Operations batch_operations
diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py
index f4c1140..5d805b5 100644
--- a/src/python/grpcio/grpc/_server.py
+++ b/src/python/grpcio/grpc/_server.py
@@ -157,7 +157,7 @@
effective_details, _EMPTY_FLAGS),
)
token = _SEND_STATUS_FROM_SERVER_TOKEN
- call.start_batch(
+ call.start_server_batch(
cygrpc.Operations(operations),
_send_status_from_server(state, token))
state.statused = True
@@ -257,7 +257,7 @@
if self._state.initial_metadata_allowed:
operation = cygrpc.operation_send_initial_metadata(
_common.cygrpc_metadata(initial_metadata), _EMPTY_FLAGS)
- self._rpc_event.operation_call.start_batch(
+ self._rpc_event.operation_call.start_server_batch(
cygrpc.Operations((operation,)),
_send_initial_metadata(self._state))
self._state.initial_metadata_allowed = False
@@ -292,7 +292,7 @@
elif self._state.client is _CLOSED or self._state.statused:
raise StopIteration()
else:
- self._call.start_batch(
+ self._call.start_server_batch(
cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
_receive_message(self._state, self._call, self._request_deserializer))
self._state.due.add(_RECEIVE_MESSAGE_TOKEN)
@@ -333,7 +333,7 @@
if state.client is _CANCELLED or state.statused:
return None
else:
- start_batch_result = rpc_event.operation_call.start_batch(
+ start_server_batch_result = rpc_event.operation_call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
_receive_message(
@@ -417,7 +417,7 @@
cygrpc.operation_send_message(serialized_response, _EMPTY_FLAGS),
)
token = _SEND_MESSAGE_TOKEN
- rpc_event.operation_call.start_batch(
+ rpc_event.operation_call.start_server_batch(
cygrpc.Operations(operations), _send_message(state, token))
state.due.add(token)
while True:
@@ -443,7 +443,7 @@
if serialized_response is not None:
operations.append(cygrpc.operation_send_message(
serialized_response, _EMPTY_FLAGS))
- rpc_event.operation_call.start_batch(
+ rpc_event.operation_call.start_server_batch(
cygrpc.Operations(operations),
_send_status_from_server(state, _SEND_STATUS_FROM_SERVER_TOKEN))
state.statused = True
@@ -550,7 +550,7 @@
b'Method not found!', _EMPTY_FLAGS),
)
rpc_state = _RPCState()
- rpc_event.operation_call.start_batch(
+ rpc_event.operation_call.start_server_batch(
operations, lambda ignored_event: (rpc_state, (),))
return rpc_state
@@ -558,7 +558,7 @@
def _handle_with_method_handler(rpc_event, method_handler, thread_pool):
state = _RPCState()
with state.condition:
- rpc_event.operation_call.start_batch(
+ rpc_event.operation_call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
_receive_close_on_server(state))
diff --git a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
index cac0c8b..cf212c5 100644
--- a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
+++ b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
@@ -81,11 +81,11 @@
self._state.condition.wait()
with self._lock:
- self._call.start_batch(
+ self._call.start_server_batch(
cygrpc.Operations(
(cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
_RECEIVE_CLOSE_ON_SERVER_TAG)
- self._call.start_batch(
+ self._call.start_server_batch(
cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
_RECEIVE_MESSAGE_TAG)
first_event = self._completion_queue.poll()
@@ -101,7 +101,7 @@
_EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!',
_EMPTY_FLAGS),
)
- self._call.start_batch(
+ self._call.start_server_batch(
cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG)
self._completion_queue.poll()
self._completion_queue.poll()
@@ -193,7 +193,7 @@
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
)
tag = 'client_complete_call_{0:04d}_tag'.format(index)
- client_call.start_batch(cygrpc.Operations(operations), tag)
+ client_call.start_client_batch(cygrpc.Operations(operations), tag)
client_due.add(tag)
client_calls.append(client_call)
diff --git a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
index 27fcee0..152d8edde 100644
--- a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
+++ b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
@@ -168,12 +168,12 @@
client_complete_rpc_tag = 'client_complete_rpc_tag'
with client_condition:
client_receive_initial_metadata_start_batch_result = (
- client_call.start_batch(cygrpc.Operations([
+ client_call.start_client_batch(cygrpc.Operations([
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
]), client_receive_initial_metadata_tag))
client_due.add(client_receive_initial_metadata_tag)
client_complete_rpc_start_batch_result = (
- client_call.start_batch(cygrpc.Operations([
+ client_call.start_client_batch(cygrpc.Operations([
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS),
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
@@ -185,30 +185,30 @@
with server_call_condition:
server_send_initial_metadata_start_batch_result = (
- server_rpc_event.operation_call.start_batch(cygrpc.Operations([
+ server_rpc_event.operation_call.start_server_batch([
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS),
- ]), server_send_initial_metadata_tag))
+ ], server_send_initial_metadata_tag))
server_send_first_message_start_batch_result = (
- server_rpc_event.operation_call.start_batch(cygrpc.Operations([
+ server_rpc_event.operation_call.start_server_batch([
cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS),
- ]), server_send_first_message_tag))
+ ], server_send_first_message_tag))
server_send_initial_metadata_event = server_call_driver.event_with_tag(
server_send_initial_metadata_tag)
server_send_first_message_event = server_call_driver.event_with_tag(
server_send_first_message_tag)
with server_call_condition:
server_send_second_message_start_batch_result = (
- server_rpc_event.operation_call.start_batch(cygrpc.Operations([
+ server_rpc_event.operation_call.start_server_batch([
cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS),
- ]), server_send_second_message_tag))
+ ], server_send_second_message_tag))
server_complete_rpc_start_batch_result = (
- server_rpc_event.operation_call.start_batch(cygrpc.Operations([
+ server_rpc_event.operation_call.start_server_batch([
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
cygrpc.Metadata(()), cygrpc.StatusCode.ok, b'test details',
_EMPTY_FLAGS),
- ]), server_complete_rpc_tag))
+ ], server_complete_rpc_tag))
server_send_second_message_event = server_call_driver.event_with_tag(
server_send_second_message_tag)
server_complete_rpc_event = server_call_driver.event_with_tag(
@@ -218,7 +218,7 @@
with client_condition:
client_receive_first_message_tag = 'client_receive_first_message_tag'
client_receive_first_message_start_batch_result = (
- client_call.start_batch(cygrpc.Operations([
+ client_call.start_client_batch(cygrpc.Operations([
cygrpc.operation_receive_message(_EMPTY_FLAGS),
]), client_receive_first_message_tag))
client_due.add(client_receive_first_message_tag)
diff --git a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py
index b740695..9d1dbc1 100644
--- a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py
+++ b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py
@@ -186,7 +186,8 @@
def performer():
tag = object()
try:
- call_result = call.start_batch(cygrpc.Operations(operations), tag)
+ call_result = call.start_client_batch(
+ cygrpc.Operations(operations), tag)
self.assertEqual(cygrpc.CallError.ok, call_result)
event = queue.poll(deadline)
self.assertEqual(cygrpc.CompletionType.operation_complete, event.type)
@@ -231,7 +232,7 @@
cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
CLIENT_METADATA_ASCII_VALUE),
cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)])
- client_start_batch_result = client_call.start_batch(cygrpc.Operations([
+ client_start_batch_result = client_call.start_client_batch([
cygrpc.operation_send_initial_metadata(client_initial_metadata,
_EMPTY_FLAGS),
cygrpc.operation_send_message(REQUEST, _EMPTY_FLAGS),
@@ -239,7 +240,7 @@
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS)
- ]), client_call_tag)
+ ], client_call_tag)
self.assertEqual(cygrpc.CallError.ok, client_start_batch_result)
client_event_future = test_utilities.CompletionQueuePollFuture(
self.client_completion_queue, cygrpc_deadline)
@@ -268,7 +269,7 @@
server_trailing_metadata = cygrpc.Metadata([
cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY,
SERVER_TRAILING_METADATA_VALUE)])
- server_start_batch_result = server_call.start_batch([
+ server_start_batch_result = server_call.start_server_batch([
cygrpc.operation_send_initial_metadata(server_initial_metadata,
_EMPTY_FLAGS),
cygrpc.operation_receive_message(_EMPTY_FLAGS),