# blob: ecd23afda71d2c6d04c239a35cdf8f4f1ca62353 [file] [log] [blame]
# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test a corner-case at the level of the Cython API."""
import threading
import unittest
from grpc._cython import cygrpc
# Flags value passed to create_call and to every batch operation in this test.
_EMPTY_FLAGS = 0
# Empty metadata tuple passed to SendInitialMetadataOperation.
_EMPTY_METADATA = ()
class _ServerDriver(object):
    """Collects events from a completion queue on a background thread.

    The background thread keeps polling until it sees an event whose tag is
    (by identity) the shutdown tag given at construction. Every polled event,
    including the shutdown one, is recorded in arrival order.
    """

    def __init__(self, completion_queue, shutdown_tag):
        """Store the queue to poll and the tag that ends polling.

        Args:
          completion_queue: Object exposing ``poll()``; each returned event
            must expose a ``tag`` attribute.
          shutdown_tag: Tag object compared with ``is`` against each event's
            tag to detect the final event.
        """
        self._condition = threading.Condition()
        self._completion_queue = completion_queue
        self._shutdown_tag = shutdown_tag
        # Events in the order they were polled; guarded by self._condition.
        self._events = []
        # Set once the shutdown-tagged event is recorded; guarded by
        # self._condition.
        self._saw_shutdown_tag = False

    def start(self):
        """Start the polling thread and return immediately."""

        def in_thread():
            while True:
                event = self._completion_queue.poll()
                with self._condition:
                    self._events.append(event)
                    # first_event() and events() wait on this condition with
                    # different predicates, so every waiter must be woken;
                    # a single notify() could wake only the waiter whose
                    # predicate is still false and strand the other one.
                    self._condition.notify_all()
                    if event.tag is self._shutdown_tag:
                        self._saw_shutdown_tag = True
                        break

        thread = threading.Thread(target=in_thread)
        thread.start()

    def done(self):
        """Return True once the shutdown-tagged event has been recorded."""
        with self._condition:
            return self._saw_shutdown_tag

    def first_event(self):
        """Block until at least one event is recorded, then return the first."""
        with self._condition:
            while not self._events:
                self._condition.wait()
            return self._events[0]

    def events(self):
        """Block until the shutdown tag is seen; return all events as a tuple."""
        with self._condition:
            while not self._saw_shutdown_tag:
                self._condition.wait()
            return tuple(self._events)
class _QueueDriver(object):
    """Drains a completion queue until a shared set of due tags is empty.

    The condition and the ``due`` set are supplied by the caller, which lets
    the caller add tags to ``due`` while holding the same condition that the
    polling thread takes before removing a tag.
    """

    def __init__(self, condition, completion_queue, due):
        """Remember the collaborators; no thread is started here.

        Args:
          condition: threading.Condition guarding ``due`` and this driver's
            recorded state.
          completion_queue: Object exposing ``poll()``; each returned event
            must expose a ``tag`` attribute.
          due: Mutable set of tags still expected. The polling thread removes
            each polled event's tag from it.
        """
        self._returned = False
        self._events = []
        self._due = due
        self._completion_queue = completion_queue
        self._condition = condition

    def start(self):
        """Launch the polling thread and return immediately."""

        def drain():
            finished = False
            while not finished:
                polled = self._completion_queue.poll()
                with self._condition:
                    self._events.append(polled)
                    self._due.remove(polled.tag)
                    self._condition.notify_all()
                    # Stop once nothing further is expected; the flag is set
                    # while the lock is still held, after the notification.
                    if not self._due:
                        self._returned = True
                        finished = True

        threading.Thread(target=drain).start()

    def done(self):
        """Report whether the polling thread has finished draining."""
        with self._condition:
            finished = self._returned
        return finished

    def event_with_tag(self, tag):
        """Block until an event whose tag is ``tag`` (by identity) is recorded.

        Returns:
          The earliest recorded event carrying that tag.
        """
        missing = object()
        with self._condition:
            while True:
                found = next(
                    (seen for seen in self._events if seen.tag is tag), missing)
                if found is not missing:
                    return found
                self._condition.wait()

    def events(self):
        """Block until draining has finished; return every event as a tuple."""
        with self._condition:
            while True:
                if self._returned:
                    return tuple(self._events)
                self._condition.wait()
class ReadSomeButNotAllResponsesTest(unittest.TestCase):
    """Cython-API test: the server sends two messages, the client reads one."""

    def testReadSomeButNotAllResponses(self):
        """Run one RPC in which the client receives one message, then cancels.

        The server side sends initial metadata, two messages and a status.
        The client side starts a single receive-message batch, cancels the
        call, and the server is then shut down. The statement order matters:
        the driver calls between batches block until the corresponding
        completion events have been recorded.
        """
        # Server bound to port 0; the port number returned by
        # add_http2_port is used below to build the client channel target.
        server_completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server(
            cygrpc.ChannelArgs([cygrpc.ChannelArg(b'grpc.so_reuseport', 0)]))
        server.register_completion_queue(server_completion_queue)
        port = server.add_http2_port(b'[::]:0')
        server.start()
        channel = cygrpc.Channel('localhost:{}'.format(port).encode(),
                                 cygrpc.ChannelArgs([]))

        # Background poller for the server-level queue; it runs until it
        # sees server_shutdown_tag.
        server_shutdown_tag = 'server_shutdown_tag'
        server_driver = _ServerDriver(server_completion_queue,
                                      server_shutdown_tag)
        server_driver.start()

        # Background poller for client-side events. client_due starts empty
        # and is filled in below while holding client_condition.
        client_condition = threading.Condition()
        client_due = set()
        client_completion_queue = cygrpc.CompletionQueue()
        client_driver = _QueueDriver(client_condition, client_completion_queue,
                                     client_due)
        client_driver.start()

        # Background poller for the server side of the call. All four
        # expected tags are registered up front, so this driver finishes
        # once each of them has completed.
        server_call_condition = threading.Condition()
        server_send_initial_metadata_tag = 'server_send_initial_metadata_tag'
        server_send_first_message_tag = 'server_send_first_message_tag'
        server_send_second_message_tag = 'server_send_second_message_tag'
        server_complete_rpc_tag = 'server_complete_rpc_tag'
        server_call_due = set((
            server_send_initial_metadata_tag,
            server_send_first_message_tag,
            server_send_second_message_tag,
            server_complete_rpc_tag,
        ))
        server_call_completion_queue = cygrpc.CompletionQueue()
        server_call_driver = _QueueDriver(server_call_condition,
                                          server_call_completion_queue,
                                          server_call_due)
        server_call_driver.start()

        # Ask the server for an incoming call; its arrival is reported on
        # server_completion_queue under server_rpc_tag (asserted at the end).
        server_rpc_tag = 'server_rpc_tag'
        request_call_result = server.request_call(server_call_completion_queue,
                                                  server_completion_queue,
                                                  server_rpc_tag)

        client_call = channel.create_call(None, _EMPTY_FLAGS,
                                          client_completion_queue, b'/twinkies',
                                          None, None)
        client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag'
        client_complete_rpc_tag = 'client_complete_rpc_tag'
        # Each tag is added to client_due under the same condition the driver
        # thread acquires before removing a tag, so the driver cannot handle
        # a completion before its tag has been registered.
        with client_condition:
            client_receive_initial_metadata_start_batch_result = (
                client_call.start_client_batch([
                    cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                ], client_receive_initial_metadata_tag))
            client_due.add(client_receive_initial_metadata_tag)
            client_complete_rpc_start_batch_result = (
                client_call.start_client_batch([
                    cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA,
                                                        _EMPTY_FLAGS),
                    cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                    cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                ], client_complete_rpc_tag))
            client_due.add(client_complete_rpc_tag)

        # Blocks until the server-level queue has produced its first event;
        # the assertions at the end check that it carries server_rpc_tag.
        server_rpc_event = server_driver.first_event()

        with server_call_condition:
            server_send_initial_metadata_start_batch_result = (
                server_rpc_event.call.start_server_batch([
                    cygrpc.SendInitialMetadataOperation(_EMPTY_METADATA,
                                                        _EMPTY_FLAGS),
                ], server_send_initial_metadata_tag))
            server_send_first_message_start_batch_result = (
                server_rpc_event.call.start_server_batch([
                    cygrpc.SendMessageOperation(b'\x07', _EMPTY_FLAGS),
                ], server_send_first_message_tag))
        # The returned events are not inspected; these calls are here because
        # they block until both batches above have completed.
        server_send_initial_metadata_event = server_call_driver.event_with_tag(
            server_send_initial_metadata_tag)
        server_send_first_message_event = server_call_driver.event_with_tag(
            server_send_first_message_tag)
        with server_call_condition:
            server_send_second_message_start_batch_result = (
                server_rpc_event.call.start_server_batch([
                    cygrpc.SendMessageOperation(b'\x07', _EMPTY_FLAGS),
                ], server_send_second_message_tag))
            server_complete_rpc_start_batch_result = (
                server_rpc_event.call.start_server_batch([
                    cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
                    cygrpc.SendStatusFromServerOperation(
                        (), cygrpc.StatusCode.ok, b'test details',
                        _EMPTY_FLAGS),
                ], server_complete_rpc_tag))
        # Again used only for their blocking effect.
        server_send_second_message_event = server_call_driver.event_with_tag(
            server_send_second_message_tag)
        server_complete_rpc_event = server_call_driver.event_with_tag(
            server_complete_rpc_tag)
        # Blocks until every tag in server_call_due has completed.
        server_call_driver.events()

        # The client starts exactly one receive-message batch although the
        # server started two send-message batches.
        with client_condition:
            client_receive_first_message_tag = 'client_receive_first_message_tag'
            client_receive_first_message_start_batch_result = (
                client_call.start_client_batch([
                    cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                ], client_receive_first_message_tag))
            client_due.add(client_receive_first_message_tag)
        client_receive_first_message_event = client_driver.event_with_tag(
            client_receive_first_message_tag)

        client_call_cancel_result = client_call.cancel()
        # Blocks until every tag added to client_due has completed.
        client_driver.events()

        server.shutdown(server_completion_queue, server_shutdown_tag)
        server.cancel_all_calls()
        # Blocks until the server-level poller has seen server_shutdown_tag.
        server_driver.events()

        # NOTE(review): only a subset of the captured results is asserted;
        # the remaining start-batch results and events are bound but unused.
        self.assertEqual(cygrpc.CallError.ok, request_call_result)
        self.assertEqual(cygrpc.CallError.ok,
                         server_send_initial_metadata_start_batch_result)
        self.assertEqual(cygrpc.CallError.ok,
                         client_receive_initial_metadata_start_batch_result)
        self.assertEqual(cygrpc.CallError.ok,
                         client_complete_rpc_start_batch_result)
        self.assertEqual(cygrpc.CallError.ok, client_call_cancel_result)
        self.assertIs(server_rpc_tag, server_rpc_event.tag)
        self.assertEqual(cygrpc.CompletionType.operation_complete,
                         server_rpc_event.completion_type)
        self.assertIsInstance(server_rpc_event.call, cygrpc.Call)
# Run the tests with verbose output when this file is executed as a script.
if __name__ == '__main__':
    unittest.main(verbosity=2)