# blob: 9045ff58a0f675dcbbbb24575769950946bfe674 [file] [log] [blame]
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import threading
import unittest
import platform
from grpc._cython import cygrpc
from tests.unit._cython import test_utilities
from tests.unit import test_common
from tests.unit import resources
# Host name handed to the secure test case as its `host_override`; it is used
# as the ssl_target_name_override channel argument and as the expected host.
_SSL_HOST_OVERRIDE = b'foo.test.google.fr'
# Key/value pair that `_metadata_plugin` reports through its callback.
_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key'
_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value'
# Flags value passed to every batch operation constructed in this module.
_EMPTY_FLAGS = 0
def _metadata_plugin(context, callback):
    """Report one fixed metadata pair through `callback`.

    Args:
      context: Not used by this plugin.
      callback: Called exactly once with the metadata tuple,
        `cygrpc.StatusCode.ok`, and an empty bytes string.
    """
    metadata = ((_CALL_CREDENTIALS_METADATA_KEY,
                 _CALL_CREDENTIALS_METADATA_VALUE),)
    callback(metadata, cygrpc.StatusCode.ok, b'')
class TypeSmokeTest(unittest.TestCase):
    """Smoke tests that create and release the basic cygrpc objects."""

    @staticmethod
    def _reuseport_args():
        # Fresh server-argument list setting b'grpc.so_reuseport' to 0.
        return [(b'grpc.so_reuseport', 0)]

    def testCompletionQueueUpDown(self):
        # Construct a completion queue and drop the only reference to it.
        cq = cygrpc.CompletionQueue()
        del cq

    def testServerUpDown(self):
        # The arguments are handed over as a set here, unlike the other tests.
        srv = cygrpc.Server(set(self._reuseport_args()))
        del srv

    def testChannelUpDown(self):
        # Construct a channel without channel arguments and release it.
        chan = cygrpc.Channel(b'[::]:0', None)
        del chan

    def test_metadata_plugin_call_credentials_up_down(self):
        # Only construction is exercised; the result is discarded right away.
        cygrpc.MetadataPluginCallCredentials(_metadata_plugin,
                                             b'test plugin name!')

    def testServerStartNoExplicitShutdown(self):
        # Start a server and release it without calling shutdown().
        srv = cygrpc.Server(self._reuseport_args())
        cq = cygrpc.CompletionQueue()
        srv.register_completion_queue(cq)
        bound_port = srv.add_http2_port(b'[::]:0')
        self.assertIsInstance(bound_port, int)
        srv.start()
        del srv

    def testServerStartShutdown(self):
        # Start a server, shut it down, and check the shutdown event that the
        # completion queue delivers.
        cq = cygrpc.CompletionQueue()
        srv = cygrpc.Server(self._reuseport_args())
        srv.add_http2_port(b'[::]:0')
        srv.register_completion_queue(cq)
        srv.start()
        tag = object()
        srv.shutdown(cq, tag)
        shutdown_event = cq.poll()
        self.assertEqual(cygrpc.CompletionType.operation_complete,
                         shutdown_event.completion_type)
        self.assertIs(tag, shutdown_event.tag)
        del srv
        del cq
class ServerClientMixin(object):
    """Shared fixture and test cases for one cygrpc server/client pair.

    Concrete test classes combine this mixin with `unittest.TestCase` and call
    `setUpMixin` / `tearDownMixin` from their own `setUp` / `tearDown`.
    """

    def setUpMixin(self, server_credentials, client_credentials, host_override):
        """Start a server on an ephemeral port and open a channel to it.

        Args:
          server_credentials: Credentials passed to `add_http2_port`; when
            falsy the port is added without credentials.
          client_credentials: Credentials passed to `cygrpc.Channel`; when
            falsy the channel is created without credentials.
          host_override: Value of the ssl_target_name_override channel argument
            (only applied when `client_credentials` is truthy). When truthy it
            is also the host the tests expect the server to observe.
        """
        self.server_completion_queue = cygrpc.CompletionQueue()
        self.server = cygrpc.Server([
            (
                b'grpc.so_reuseport',
                0,
            ),
        ])
        self.server.register_completion_queue(self.server_completion_queue)
        if server_credentials:
            self.port = self.server.add_http2_port(b'[::]:0',
                                                   server_credentials)
        else:
            self.port = self.server.add_http2_port(b'[::]:0')
        self.server.start()
        self.client_completion_queue = cygrpc.CompletionQueue()
        if client_credentials:
            client_channel_arguments = ((
                cygrpc.ChannelArgKey.ssl_target_name_override,
                host_override,
            ),)
            self.client_channel = cygrpc.Channel('localhost:{}'.format(
                self.port).encode(), client_channel_arguments,
                                                 client_credentials)
        else:
            self.client_channel = cygrpc.Channel('localhost:{}'.format(
                self.port).encode(), set())
        if host_override:
            self.host_argument = None  # default host
            self.expected_host = host_override
        else:
            # arbitrary host name necessitating no further identification
            self.host_argument = b'hostess'
            self.expected_host = self.host_argument

    def tearDownMixin(self):
        """Drop the references to the server and both completion queues."""
        del self.server
        del self.client_completion_queue
        del self.server_completion_queue

    def _perform_operations(self, operations, call, queue, deadline,
                            description):
        """Perform the list of operations with given call, queue, and deadline.

        Invocation errors are reported as an exception with `description` in
        the message. Performs the operations asynchronously, returning a future.

        Args:
          operations: Batch operations handed to `call.start_client_batch`.
          call: The call on which the batch is started.
          queue: Completion queue polled for the batch's completion event.
          deadline: Deadline passed to `queue.poll`.
          description: Label included in the message of any raised exception.

        Returns:
          A `test_utilities.SimpleFuture` wrapping the work; the wrapped
          callable returns the completion event.
        """

        def performer():
            tag = object()
            try:
                call_result = call.start_client_batch(operations, tag)
                self.assertEqual(cygrpc.CallError.ok, call_result)
                event = queue.poll(deadline=deadline)
                self.assertEqual(cygrpc.CompletionType.operation_complete,
                                 event.completion_type)
                self.assertTrue(event.success)
                self.assertIs(tag, event.tag)
            except Exception as error:
                # Format the exception itself: `error.message` does not exist
                # on Python 3 and would mask the real failure with an
                # AttributeError.
                raise Exception("Error in '{}': {}".format(description, error))
            return event

        return test_utilities.SimpleFuture(performer)

    def test_echo(self):
        # One unary exchange: the client sends metadata and a request, the
        # server answers with metadata, a response and a status; both batch
        # results are then checked operation by operation.
        DEADLINE = time.time() + 5
        DEADLINE_TOLERANCE = 0.25
        CLIENT_METADATA_ASCII_KEY = 'key'
        CLIENT_METADATA_ASCII_VALUE = 'val'
        CLIENT_METADATA_BIN_KEY = 'key-bin'
        CLIENT_METADATA_BIN_VALUE = b'\0' * 1000
        SERVER_INITIAL_METADATA_KEY = 'init_me_me_me'
        SERVER_INITIAL_METADATA_VALUE = 'whodawha?'
        SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought'
        SERVER_TRAILING_METADATA_VALUE = 'zomg it is'
        SERVER_STATUS_CODE = cygrpc.StatusCode.ok
        SERVER_STATUS_DETAILS = 'our work is never over'
        REQUEST = b'in death a member of project mayhem has a name'
        RESPONSE = b'his name is robert paulson'
        METHOD = b'twinkies'

        server_request_tag = object()
        request_call_result = self.server.request_call(
            self.server_completion_queue, self.server_completion_queue,
            server_request_tag)
        self.assertEqual(cygrpc.CallError.ok, request_call_result)

        client_call_tag = object()
        client_call = self.client_channel.create_call(
            None, 0, self.client_completion_queue, METHOD, self.host_argument,
            DEADLINE)
        client_initial_metadata = (
            (
                CLIENT_METADATA_ASCII_KEY,
                CLIENT_METADATA_ASCII_VALUE,
            ),
            (
                CLIENT_METADATA_BIN_KEY,
                CLIENT_METADATA_BIN_VALUE,
            ),
        )
        client_start_batch_result = client_call.start_client_batch([
            cygrpc.SendInitialMetadataOperation(client_initial_metadata,
                                                _EMPTY_FLAGS),
            cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS),
            cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
        ], client_call_tag)
        self.assertEqual(cygrpc.CallError.ok, client_start_batch_result)
        client_event_future = test_utilities.CompletionQueuePollFuture(
            self.client_completion_queue, DEADLINE)

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        self.assertEqual(cygrpc.CompletionType.operation_complete,
                         request_event.completion_type)
        self.assertIsInstance(request_event.call, cygrpc.Call)
        self.assertIs(server_request_tag, request_event.tag)
        self.assertTrue(
            test_common.metadata_transmitted(client_initial_metadata,
                                             request_event.invocation_metadata))
        self.assertEqual(METHOD, request_event.call_details.method)
        self.assertEqual(self.expected_host, request_event.call_details.host)
        self.assertLess(
            abs(DEADLINE - request_event.call_details.deadline),
            DEADLINE_TOLERANCE)

        server_call_tag = object()
        server_call = request_event.call
        server_initial_metadata = ((
            SERVER_INITIAL_METADATA_KEY,
            SERVER_INITIAL_METADATA_VALUE,
        ),)
        server_trailing_metadata = ((
            SERVER_TRAILING_METADATA_KEY,
            SERVER_TRAILING_METADATA_VALUE,
        ),)
        server_start_batch_result = server_call.start_server_batch([
            cygrpc.SendInitialMetadataOperation(server_initial_metadata,
                                                _EMPTY_FLAGS),
            cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            cygrpc.SendMessageOperation(RESPONSE, _EMPTY_FLAGS),
            cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
            cygrpc.SendStatusFromServerOperation(
                server_trailing_metadata, SERVER_STATUS_CODE,
                SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
        ], server_call_tag)
        self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)

        server_event = self.server_completion_queue.poll(deadline=DEADLINE)
        client_event = client_event_future.result()

        self.assertEqual(6, len(client_event.batch_operations))
        found_client_op_types = set()
        for client_result in client_event.batch_operations:
            # we expect each op type to be unique
            self.assertNotIn(client_result.type(), found_client_op_types)
            found_client_op_types.add(client_result.type())
            if client_result.type(
            ) == cygrpc.OperationType.receive_initial_metadata:
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_initial_metadata,
                        client_result.initial_metadata()))
            elif client_result.type() == cygrpc.OperationType.receive_message:
                self.assertEqual(RESPONSE, client_result.message())
            elif client_result.type(
            ) == cygrpc.OperationType.receive_status_on_client:
                self.assertTrue(
                    test_common.metadata_transmitted(
                        server_trailing_metadata,
                        client_result.trailing_metadata()))
                self.assertEqual(SERVER_STATUS_DETAILS, client_result.details())
                self.assertEqual(SERVER_STATUS_CODE, client_result.code())
        self.assertEqual({
            cygrpc.OperationType.send_initial_metadata,
            cygrpc.OperationType.send_message,
            cygrpc.OperationType.send_close_from_client,
            cygrpc.OperationType.receive_initial_metadata,
            cygrpc.OperationType.receive_message,
            cygrpc.OperationType.receive_status_on_client,
        }, found_client_op_types)

        self.assertEqual(5, len(server_event.batch_operations))
        found_server_op_types = set()
        for server_result in server_event.batch_operations:
            # Each server op type must be unique as well; check the server
            # result itself rather than the leftover client loop variable.
            self.assertNotIn(server_result.type(), found_server_op_types)
            found_server_op_types.add(server_result.type())
            if server_result.type() == cygrpc.OperationType.receive_message:
                self.assertEqual(REQUEST, server_result.message())
            elif server_result.type(
            ) == cygrpc.OperationType.receive_close_on_server:
                self.assertFalse(server_result.cancelled())
        self.assertEqual({
            cygrpc.OperationType.send_initial_metadata,
            cygrpc.OperationType.receive_message,
            cygrpc.OperationType.send_message,
            cygrpc.OperationType.receive_close_on_server,
            cygrpc.OperationType.send_status_from_server,
        }, found_server_op_types)

        del client_call
        del server_call

    def test6522(self):
        # Exchanges initial metadata, then ten rounds of empty messages in
        # both directions, then close/status, each step as its own batch.
        DEADLINE = time.time() + 5
        METHOD = b'twinkies'

        empty_metadata = ()

        server_request_tag = object()
        self.server.request_call(self.server_completion_queue,
                                 self.server_completion_queue,
                                 server_request_tag)
        client_call = self.client_channel.create_call(
            None, 0, self.client_completion_queue, METHOD, self.host_argument,
            DEADLINE)

        # Prologue
        def perform_client_operations(operations, description):
            return self._perform_operations(operations, client_call,
                                            self.client_completion_queue,
                                            DEADLINE, description)

        client_event_future = perform_client_operations([
            cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS),
            cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
        ], "Client prologue")

        request_event = self.server_completion_queue.poll(deadline=DEADLINE)
        server_call = request_event.call

        def perform_server_operations(operations, description):
            return self._perform_operations(operations, server_call,
                                            self.server_completion_queue,
                                            DEADLINE, description)

        server_event_future = perform_server_operations([
            cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS),
        ], "Server prologue")

        client_event_future.result()  # force completion
        server_event_future.result()

        # Messaging
        for _ in range(10):
            client_event_future = perform_client_operations([
                cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            ], "Client message")
            server_event_future = perform_server_operations([
                cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
                cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
            ], "Server receive")

            client_event_future.result()  # force completion
            server_event_future.result()

        # Epilogue
        client_event_future = perform_client_operations([
            cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
            cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS)
        ], "Client epilogue")

        server_event_future = perform_server_operations([
            cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
            cygrpc.SendStatusFromServerOperation(
                empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS)
        ], "Server epilogue")

        client_event_future.result()  # force completion
        server_event_future.result()
class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the ServerClientMixin tests with no credentials on either side."""

    def setUp(self):
        # No server credentials, no client credentials, no host override.
        self.setUpMixin(server_credentials=None,
                        client_credentials=None,
                        host_override=None)

    def tearDown(self):
        self.tearDownMixin()
class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
    """Runs the ServerClientMixin tests with SSL credentials on both sides."""

    def setUp(self):
        # Server side: one key/cert pair built from the test resources.
        key_cert_pair = cygrpc.SslPemKeyCertPair(resources.private_key(),
                                                 resources.certificate_chain())
        server_creds = cygrpc.server_credentials_ssl(None, [key_cert_pair],
                                                     False)
        # Client side: channel credentials built from the test root
        # certificates only.
        channel_creds = cygrpc.SSLChannelCredentials(
            resources.test_root_certificates(), None, None)
        self.setUpMixin(server_creds, channel_creds, _SSL_HOST_OVERRIDE)

    def tearDown(self):
        self.tearDownMixin()
# When executed as a script, run this module's tests with verbose output.
if __name__ == '__main__':
    unittest.main(verbosity=2)