blob: c1de7790145bf4b36324716ce2f5d7a94b9fa6fd [file] [log] [blame]
# Copyright 2016, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test making many calls and immediately cancelling most of them."""
import threading
import unittest
from grpc._cython import cygrpc
from grpc.framework.foundation import logging_pool
from tests.unit.framework.common import test_constants
# Passed as the final argument when client calls are created; presumably a
# deadline that never expires -- TODO confirm against cygrpc.Timespec.
_INFINITE_FUTURE = cygrpc.Timespec(float('+inf'))
# Flags value handed to every operation built in this module.
_EMPTY_FLAGS = 0
# Metadata value used wherever initial or trailing metadata is sent.
_EMPTY_METADATA = cygrpc.Metadata(())
# Tags attached to the batches and requests started below. Events polled from
# a completion queue are matched against these tags.
_SERVER_SHUTDOWN_TAG = 'server_shutdown'
_REQUEST_CALL_TAG = 'request_call'
_RECEIVE_CLOSE_ON_SERVER_TAG = 'receive_close_on_server'
_RECEIVE_MESSAGE_TAG = 'receive_message'
_SERVER_COMPLETE_CALL_TAG = 'server_complete_call'
# Multiplied by test_constants.RPC_CONCURRENCY to get the number of client
# events the test waits for before it cancels every call.
_SUCCESS_CALL_FRACTION = 1.0 / 8.0
class _State(object):
def __init__(self):
self.condition = threading.Condition()
self.handlers_released = False
self.parked_handlers = 0
self.handled_rpcs = 0
def _is_cancellation_event(event):
    """Report whether ``event`` is a receive-close batch that saw a cancel.

    Args:
      event: A completion-queue event exposing ``tag`` and
        ``batch_operations``.

    Returns:
      False when the event's tag is not the receive-close-on-server tag;
      otherwise the ``received_cancelled`` attribute of the event's first
      batch operation.
    """
    # Compare tags by value: an identity check on strings only succeeds when
    # both sides happen to be the very same object, which is an accident of
    # how the tag travelled rather than a property of the tag itself.
    return (
        event.tag == _RECEIVE_CLOSE_ON_SERVER_TAG and
        event.batch_operations[0].received_cancelled)
class _Handler(object):
def __init__(self, state, completion_queue, rpc_event):
self._state = state
self._lock = threading.Lock()
self._completion_queue = completion_queue
self._call = rpc_event.operation_call
def __call__(self):
with self._state.condition:
self._state.parked_handlers += 1
if self._state.parked_handlers == test_constants.THREAD_CONCURRENCY:
self._state.condition.notify_all()
while not self._state.handlers_released:
self._state.condition.wait()
with self._lock:
self._call.start_batch(
cygrpc.Operations(
(cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
_RECEIVE_CLOSE_ON_SERVER_TAG)
self._call.start_batch(
cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
_RECEIVE_MESSAGE_TAG)
first_event = self._completion_queue.poll()
if _is_cancellation_event(first_event):
self._completion_queue.poll()
else:
with self._lock:
operations = (
cygrpc.operation_send_initial_metadata(
_EMPTY_METADATA, _EMPTY_FLAGS),
cygrpc.operation_send_message(b'\x79\x57', _EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
_EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!',
_EMPTY_FLAGS),
)
self._call.start_batch(
cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG)
self._completion_queue.poll()
self._completion_queue.poll()
def _serve(state, server, server_completion_queue, thread_pool):
    """Accept the expected number of RPCs and dispatch each to a handler.

    Args:
      state: The shared _State; ``handled_rpcs`` is incremented under its
        condition after each dispatch.
      server: The server on which calls are requested.
      server_completion_queue: The queue polled for incoming-call events and
        for one further event after all RPCs have been dispatched.
      thread_pool: Pool whose ``submit`` runs each _Handler.
    """
    for _ in range(test_constants.RPC_CONCURRENCY):
        # Every call gets its own completion queue, owned by its handler.
        per_call_queue = cygrpc.CompletionQueue()
        server.request_call(
            per_call_queue, server_completion_queue, _REQUEST_CALL_TAG)
        request_event = server_completion_queue.poll()
        handler = _Handler(state, per_call_queue, request_event)
        thread_pool.submit(handler)
        with state.condition:
            state.handled_rpcs += 1
            # Notify waiters once the last expected RPC has been handed off.
            if state.handled_rpcs >= test_constants.RPC_CONCURRENCY:
                state.condition.notify_all()
    # One more event is awaited on the server queue -- presumably the
    # shutdown notification requested by the test; TODO confirm.
    server_completion_queue.poll()
class _QueueDriver(object):
def __init__(self, condition, completion_queue, due):
self._condition = condition
self._completion_queue = completion_queue
self._due = due
self._events = []
self._returned = False
def start(self):
def in_thread():
while True:
event = self._completion_queue.poll()
with self._condition:
self._events.append(event)
self._due.remove(event.tag)
self._condition.notify_all()
if not self._due:
self._returned = True
return
thread = threading.Thread(target=in_thread)
thread.start()
def events(self, at_least):
with self._condition:
while len(self._events) < at_least:
self._condition.wait()
return tuple(self._events)
class CancelManyCallsTest(unittest.TestCase):
    """Starts many calls, lets a fraction complete, then cancels them all."""

    def testCancelManyCalls(self):
        """Drive a client and server through a mass-cancellation scenario."""
        server_thread_pool = logging_pool.pool(
            test_constants.THREAD_CONCURRENCY)

        # Server side: one completion queue, one port chosen by the system.
        server_completion_queue = cygrpc.CompletionQueue()
        server = cygrpc.Server()
        server.register_completion_queue(server_completion_queue)
        port = server.add_http2_port('[::]:0')
        server.start()
        channel = cygrpc.Channel('localhost:{}'.format(port))

        state = _State()

        # Accept and dispatch incoming calls on a dedicated thread.
        server_thread = threading.Thread(
            target=_serve,
            args=(state, server, server_completion_queue, server_thread_pool))
        server_thread.start()

        # Client side: a driver thread collects completed-call events.
        client_condition = threading.Condition()
        client_due = set()
        client_completion_queue = cygrpc.CompletionQueue()
        client_driver = _QueueDriver(
            client_condition, client_completion_queue, client_due)
        client_driver.start()

        # Holding the client condition keeps the driver from processing any
        # event until every tag has been registered as due.
        with client_condition:
            client_calls = []
            for call_number in range(test_constants.RPC_CONCURRENCY):
                new_call = channel.create_call(
                    None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies',
                    None, _INFINITE_FUTURE)
                whole_call = (
                    cygrpc.operation_send_initial_metadata(
                        _EMPTY_METADATA, _EMPTY_FLAGS),
                    cygrpc.operation_send_message(b'\x45\x56', _EMPTY_FLAGS),
                    cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
                    cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
                    cygrpc.operation_receive_message(_EMPTY_FLAGS),
                    cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
                )
                call_tag = 'client_complete_call_{0:04d}_tag'.format(
                    call_number)
                new_call.start_batch(cygrpc.Operations(whole_call), call_tag)
                client_due.add(call_tag)
                client_calls.append(new_call)

        # Wait until enough handlers are parked and every RPC has been
        # dispatched, then release the handlers.
        with state.condition:
            while (state.parked_handlers < test_constants.THREAD_CONCURRENCY
                   or state.handled_rpcs < test_constants.RPC_CONCURRENCY):
                state.condition.wait()
            state.handlers_released = True
            state.condition.notify_all()

        # Let a fraction of the calls finish before cancelling all of them.
        client_driver.events(
            test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION)
        with client_condition:
            for pending_call in client_calls:
                pending_call.cancel()

        with state.condition:
            server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG)
# When executed directly as a script, run this module's tests with
# verbosity level 2.
if __name__ == '__main__':
    unittest.main(verbosity=2)