| # Copyright 2017 gRPC authors. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Common utilities for tests of the Cython layer of gRPC Python.""" |
| |
| import collections |
| import threading |
| |
| from grpc._cython import cygrpc |
| |
# Number of times execute_many_times() invokes the behavior it is given.
RPC_COUNT = 4000

# Flags value of zero.
# NOTE(review): presumably passed wherever an operation needs "no flags";
# confirm against the call sites in the tests that import this module.
EMPTY_FLAGS = 0

# Each metadata constant below is a tuple of (key, value) pairs: one pair
# with a text value and one '-bin' keyed pair with a 6000-byte bytes value.
INVOCATION_METADATA = (
    (
        'client-md-key',
        'client-md-key',
    ),
    (
        'client-md-key-bin',
        b'\x00\x01' * 3000,
    ),
)

INITIAL_METADATA = (
    (
        'server-initial-md-key',
        'server-initial-md-value',
    ),
    (
        'server-initial-md-key-bin',
        b'\x00\x02' * 3000,
    ),
)

TRAILING_METADATA = (
    (
        'server-trailing-md-key',
        'server-trailing-md-value',
    ),
    (
        'server-trailing-md-key-bin',
        b'\x00\x03' * 3000,
    ),
)
| |
| |
class QueueDriver(object):
    """Polls a completion queue on a background thread and files events by tag.

    Callers announce the tags they expect with add_due() and later claim the
    matching events with event_with_tag(), which blocks until one is
    available. Events that share a tag are handed out in arrival order.

    Args:
      condition: A threading.Condition-like object guarding this driver's
        state; it is also used to wake threads blocked in event_with_tag().
      completion_queue: An object whose poll() method blocks and returns the
        next event; each event must expose a `tag` attribute.
    """

    def __init__(self, condition, completion_queue):
        self._condition = condition
        self._completion_queue = completion_queue
        # tag -> number of events still expected for that tag.
        self._due = collections.defaultdict(int)
        # tag -> FIFO of events received but not yet claimed. A deque gives
        # O(1) removal from the front, unlike list.pop(0).
        self._events = collections.defaultdict(collections.deque)

    def _poll_until_nothing_due(self):
        """Thread body: file polled events until no tag is due any more."""
        while True:
            # Poll outside the lock so waiters are not blocked while this
            # thread is blocked on the queue.
            event = self._completion_queue.poll()
            with self._condition:
                self._events[event.tag].append(event)
                self._due[event.tag] -= 1
                self._condition.notify_all()
                # Drop exhausted tags so that an empty mapping means
                # "nothing outstanding".
                if self._due[event.tag] <= 0:
                    self._due.pop(event.tag)
                if not self._due:
                    return

    def add_due(self, tags):
        """Registers one more expected event for each tag in `tags`.

        A polling thread is started when nothing was due beforehand (the
        previous thread, if there was one, exits once nothing is due).

        NOTE: this method does not acquire the condition itself; it mutates
        the same state the polling thread updates under the condition, so
        callers are expected to hold the condition while calling it.

        Args:
          tags: An iterable of tags; a tag that appears k times is expected
            k more times.
        """
        if not self._due:
            threading.Thread(target=self._poll_until_nothing_due).start()
        for tag in tags:
            self._due[tag] += 1

    def event_with_tag(self, tag):
        """Blocks until an event with `tag` is available, then returns it.

        Args:
          tag: The tag of the event to claim.

        Returns:
          The oldest unclaimed event that was received with that tag.
        """
        with self._condition:
            while not self._events[tag]:
                self._condition.wait()
            return self._events[tag].popleft()
| |
| |
def execute_many_times(behavior):
    """Calls `behavior` RPC_COUNT times and returns the results as a tuple.

    Args:
      behavior: A callable taking no arguments.

    Returns:
      A tuple of the values returned by `behavior`, in call order.
    """
    outcomes = []
    for _ in range(RPC_COUNT):
        outcomes.append(behavior())
    return tuple(outcomes)
| |
| |
class OperationResult(
        collections.namedtuple(
            'OperationResult',
            ['start_batch_result', 'completion_type', 'success'])):
    """Three-field record describing the outcome of an operation.

    Fields, in order: start_batch_result, completion_type, success.
    """
| |
| |
# The OperationResult with an `ok` start-batch result, an
# `operation_complete` completion type, and success set to True.
SUCCESSFUL_OPERATION_RESULT = OperationResult(
    start_batch_result=cygrpc.CallError.ok,
    completion_type=cygrpc.CompletionType.operation_complete,
    success=True,
)
| |
| |
class RpcTest(object):
    """Fixture that sets up a started cygrpc server, a channel, and drivers.

    NOTE(review): this class derives from `object` and only defines
    setUp/tearDown, so it is presumably combined with unittest.TestCase by
    the concrete test classes; confirm against its users.
    """

    def setUp(self):
        """Starts a server, opens a channel to it, and creates queue drivers."""
        # Server side: a completion queue registered with a server that is
        # created with the grpc.so_reuseport argument set to 0.
        self.server_completion_queue = cygrpc.CompletionQueue()
        self.server = cygrpc.Server([(b'grpc.so_reuseport', 0)])
        self.server.register_completion_queue(self.server_completion_queue)
        # The port returned for the ':0' address is what the channel targets.
        bound_port = self.server.add_http2_port(b'[::]:0')
        self.server.start()
        target = 'localhost:{}'.format(bound_port).encode()
        self.channel = cygrpc.Channel(target, [], None)

        # Register the shutdown tag as due up front, holding the condition
        # while doing so; tearDown() later passes this tag to
        # server.shutdown().
        self._server_shutdown_tag = 'server_shutdown_tag'
        self.server_condition = threading.Condition()
        self.server_driver = QueueDriver(
            self.server_condition, self.server_completion_queue)
        with self.server_condition:
            self.server_driver.add_due({self._server_shutdown_tag})

        # Client side: its own condition, completion queue, and driver.
        self.client_condition = threading.Condition()
        self.client_completion_queue = cygrpc.CompletionQueue()
        self.client_driver = QueueDriver(
            self.client_condition, self.client_completion_queue)

    def tearDown(self):
        """Requests server shutdown and cancels all calls on the server."""
        self.server.shutdown(
            self.server_completion_queue, self._server_shutdown_tag)
        self.server.cancel_all_calls()