# blob: a84e02a79a4a477114688292bdf78bcd06ed1aa5 [file] [log] [blame]
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Examples of Python implementations of the stock.proto Stock service."""
from grpc.framework.common import cardinality
from grpc.framework.foundation import abandonment
from grpc.framework.foundation import stream
from tests.unit.framework.common import test_constants
from tests.unit.framework.interfaces.face import _service
from tests.unit._junkdrawer import stock_pb2
_STOCK_GROUP_NAME = 'Stock'
_SYMBOL_FORMAT = 'test symbol:%03d'
# A test-appropriate security-pricing function. :-P
_price = lambda symbol_name: float(hash(symbol_name) % 4096)
def _get_last_trade_price(stock_request, stock_reply_callback, control, active):
    """A unary-request, unary-response test method.

    Args:
      stock_request: The StockRequest naming the symbol to price.
      stock_reply_callback: Called once with the resulting StockReply.
      control: Test-harness control object. Its control() hook runs before
        any work is done (presumably so tests can inject delays or failures;
        confirm against the face test framework).
      active: Zero-argument predicate reporting whether the RPC is still live.

    Raises:
      abandonment.Abandoned: If active() is false after control.control().
    """
    control.control()
    if not active():
        raise abandonment.Abandoned()
    symbol = stock_request.symbol
    reply = stock_pb2.StockReply(symbol=symbol, price=_price(symbol))
    stock_reply_callback(reply)
def _get_last_trade_price_multiple(stock_reply_consumer, control, active):
    """A stream-request, stream-response test method.

    Args:
      stock_reply_consumer: Receives one StockReply per consumed StockRequest,
        plus the termination signal.
      control: Test-harness control object whose control() hook runs before
        each reply and before termination.
      active: Zero-argument predicate reporting whether the RPC is still live.

    Returns:
      A stream.Consumer that accepts the StockRequest stream.
    """

    def _reply_for(stock_request):
        # Build the reply for one request, abandoning if the RPC is no longer
        # active once the control hook has run.
        control.control()
        if not active():
            raise abandonment.Abandoned()
        symbol = stock_request.symbol
        return stock_pb2.StockReply(symbol=symbol, price=_price(symbol))

    class StockRequestConsumer(stream.Consumer):

        def consume(self, stock_request):
            reply = _reply_for(stock_request)
            stock_reply_consumer.consume(reply)

        def terminate(self):
            # Unlike consume, termination is forwarded without an active() check.
            control.control()
            stock_reply_consumer.terminate()

        def consume_and_terminate(self, stock_request):
            reply = _reply_for(stock_request)
            stock_reply_consumer.consume_and_terminate(reply)

    return StockRequestConsumer()
def _watch_future_trades(stock_request, stock_reply_consumer, control, active):
    """A unary-request, stream-response test method.

    Emits stock_request.num_trades_to_watch replies for the requested symbol.
    Their prices are _price(symbol) + 0, + 1, + 2, ... in order. After the
    last reply it terminates stock_reply_consumer.

    Raises:
      abandonment.Abandoned: If active() is false before any reply is sent.
    """
    symbol = stock_request.symbol
    starting_price = _price(symbol)
    for offset in range(stock_request.num_trades_to_watch):
        control.control()
        if not active():
            raise abandonment.Abandoned()
        stock_reply_consumer.consume(
            stock_pb2.StockReply(symbol=symbol, price=starting_price + offset))
    stock_reply_consumer.terminate()
def _get_highest_trade_price(stock_reply_callback, control, active):
    """A stream-request, unary-response test method.

    Args:
      stock_reply_callback: Called once, on termination, with a StockReply for
        the highest-priced symbol consumed so far.
      control: Test-harness control object whose control() hook runs at the
        start of every consumer operation.
      active: Zero-argument predicate. When it is false, consumer operations
        do nothing beyond the control() call.

    Returns:
      A stream.Consumer that accepts the StockRequest stream.
    """

    class StockRequestConsumer(stream.Consumer):
        """Keeps an ongoing record of the most valuable symbol yet consumed."""

        def __init__(self):
            self._symbol = None
            self._price = None

        def _record(self, stock_request):
            """Folds stock_request into the running maximum."""
            # `_price` here is the module-level pricing function, not the
            # self._price attribute.
            candidate_price = _price(stock_request.symbol)
            # Strict comparison: on a tie the symbol seen first is kept.
            if self._price is None or self._price < candidate_price:
                self._symbol = stock_request.symbol
                self._price = candidate_price

        def _reply_and_reset(self):
            """Sends the recorded maximum to the callback, then clears it."""
            stock_reply_callback(
                stock_pb2.StockReply(symbol=self._symbol, price=self._price))
            self._symbol = None
            self._price = None

        def consume(self, stock_request):
            control.control()
            if active():
                self._record(stock_request)

        def terminate(self):
            control.control()
            if active():
                if self._symbol is None:
                    raise ValueError(
                        'terminate called with no stock request consumed')
                self._reply_and_reset()

        def consume_and_terminate(self, stock_request):
            control.control()
            if active():
                # Same result as consume() followed by a reply, with a single
                # control()/active() round.
                self._record(stock_request)
                self._reply_and_reset()

    return StockRequestConsumer()
class GetLastTradePrice(_service.UnaryUnaryTestMethodImplementation):
    """GetLastTradePrice for use in tests.

    Describes the unary-unary GetLastTradePrice method of the stock group and
    serves it with _get_last_trade_price.
    """

    def group(self):
        """Returns the service group name."""
        return _STOCK_GROUP_NAME

    def name(self):
        """Returns the method name within the group."""
        return 'GetLastTradePrice'

    def cardinality(self):
        """Returns the RPC cardinality of this method."""
        return cardinality.Cardinality.UNARY_UNARY

    def request_class(self):
        """Returns the protobuf class of request messages."""
        return stock_pb2.StockRequest

    def response_class(self):
        """Returns the protobuf class of response messages."""
        return stock_pb2.StockReply

    def serialize_request(self, request):
        return request.SerializeToString()

    def deserialize_request(self, serialized_request):
        message_class = self.request_class()
        return message_class.FromString(serialized_request)

    def serialize_response(self, response):
        return response.SerializeToString()

    def deserialize_response(self, serialized_response):
        message_class = self.response_class()
        return message_class.FromString(serialized_response)

    def service(self, request, response_callback, context, control):
        """Handles one request, delivering the reply via response_callback."""
        _get_last_trade_price(
            request, response_callback, control, active=context.is_active)
class GetLastTradePriceMessages(_service.UnaryUnaryTestMessages):
    """Request/response pairs for use with GetLastTradePrice."""

    def __init__(self):
        # Counter used to give each generated request a distinct symbol.
        self._index = 0

    def request(self):
        """Returns a StockRequest for the next numbered test symbol."""
        current = self._index
        symbol = _SYMBOL_FORMAT % current
        self._index = current + 1
        return stock_pb2.StockRequest(symbol=symbol)

    def verify(self, request, response, test_case):
        """Asserts the reply echoes the symbol and carries its _price."""
        expected_symbol = request.symbol
        test_case.assertEqual(expected_symbol, response.symbol)
        test_case.assertEqual(_price(expected_symbol), response.price)
class GetLastTradePriceMultiple(_service.StreamStreamTestMethodImplementation):
    """GetLastTradePriceMultiple for use in tests.

    Describes the stream-stream GetLastTradePriceMultiple method of the stock
    group and serves it with _get_last_trade_price_multiple.
    """

    def group(self):
        """Returns the service group name."""
        return _STOCK_GROUP_NAME

    def name(self):
        """Returns the method name within the group."""
        return 'GetLastTradePriceMultiple'

    def cardinality(self):
        """Returns the RPC cardinality of this method."""
        return cardinality.Cardinality.STREAM_STREAM

    def request_class(self):
        """Returns the protobuf class of request messages."""
        return stock_pb2.StockRequest

    def response_class(self):
        """Returns the protobuf class of response messages."""
        return stock_pb2.StockReply

    def serialize_request(self, request):
        return request.SerializeToString()

    def deserialize_request(self, serialized_request):
        message_class = self.request_class()
        return message_class.FromString(serialized_request)

    def serialize_response(self, response):
        return response.SerializeToString()

    def deserialize_response(self, serialized_response):
        message_class = self.response_class()
        return message_class.FromString(serialized_response)

    def service(self, response_consumer, context, control):
        """Returns the stream.Consumer that will receive the request stream."""
        return _get_last_trade_price_multiple(
            response_consumer, control, active=context.is_active)
class GetLastTradePriceMultipleMessages(_service.StreamStreamTestMessages):
    """Pairs of message streams for use with GetLastTradePriceMultiple."""

    def __init__(self):
        # Starting symbol number for the next generated request stream.
        self._index = 0

    def requests(self):
        """Returns STREAM_LENGTH requests with consecutively numbered symbols.

        Each call starts one symbol later than the previous call.
        """
        first = self._index
        self._index = first + 1
        stream_requests = []
        for offset in range(test_constants.STREAM_LENGTH):
            stream_requests.append(
                stock_pb2.StockRequest(symbol=_SYMBOL_FORMAT % (first + offset)))
        return stream_requests

    def verify(self, requests, responses, test_case):
        """Asserts one reply per request, each echoing its symbol and price."""
        test_case.assertEqual(len(requests), len(responses))
        for sent, received in zip(requests, responses):
            expected_symbol = sent.symbol
            test_case.assertEqual(expected_symbol, received.symbol)
            test_case.assertEqual(_price(expected_symbol), received.price)
class WatchFutureTrades(_service.UnaryStreamTestMethodImplementation):
    """WatchFutureTrades for use in tests.

    Describes the unary-stream WatchFutureTrades method of the stock group and
    serves it with _watch_future_trades.
    """

    def group(self):
        """Returns the service group name."""
        return _STOCK_GROUP_NAME

    def name(self):
        """Returns the method name within the group."""
        return 'WatchFutureTrades'

    def cardinality(self):
        """Returns the RPC cardinality of this method."""
        return cardinality.Cardinality.UNARY_STREAM

    def request_class(self):
        """Returns the protobuf class of request messages."""
        return stock_pb2.StockRequest

    def response_class(self):
        """Returns the protobuf class of response messages."""
        return stock_pb2.StockReply

    def serialize_request(self, request):
        return request.SerializeToString()

    def deserialize_request(self, serialized_request):
        message_class = self.request_class()
        return message_class.FromString(serialized_request)

    def serialize_response(self, response):
        return response.SerializeToString()

    def deserialize_response(self, serialized_response):
        message_class = self.response_class()
        return message_class.FromString(serialized_response)

    def service(self, request, response_consumer, context, control):
        """Streams replies for request into response_consumer."""
        _watch_future_trades(
            request, response_consumer, control, active=context.is_active)
class WatchFutureTradesMessages(_service.UnaryStreamTestMessages):
    """Pairs of a single request message and a sequence of response messages."""

    def __init__(self):
        # Counter used to give each generated request a distinct symbol.
        self._index = 0

    def request(self):
        """Returns a request asking to watch STREAM_LENGTH trades."""
        symbol = _SYMBOL_FORMAT % self._index
        self._index += 1
        return stock_pb2.StockRequest(
            symbol=symbol, num_trades_to_watch=test_constants.STREAM_LENGTH)

    def verify(self, request, responses, test_case):
        """Asserts the response stream matches what the request asked for.

        Expects exactly request.num_trades_to_watch replies. Each reply must
        echo the requested symbol, and the i-th reply must be priced at
        _price(symbol) + i.
        """
        # Derive the expected count from the request itself rather than the
        # constant used to build it, so any request is verified correctly.
        test_case.assertEqual(request.num_trades_to_watch, len(responses))
        base_price = _price(request.symbol)
        for index, response in enumerate(responses):
            # _watch_future_trades sets every reply's symbol to the request's
            # symbol; previously this was never checked.
            test_case.assertEqual(request.symbol, response.symbol)
            test_case.assertEqual(base_price + index, response.price)
class GetHighestTradePrice(_service.StreamUnaryTestMethodImplementation):
    """GetHighestTradePrice for use in tests.

    Describes the stream-unary GetHighestTradePrice method of the stock group
    and serves it with _get_highest_trade_price.
    """

    def group(self):
        """Returns the service group name."""
        return _STOCK_GROUP_NAME

    def name(self):
        """Returns the method name within the group."""
        return 'GetHighestTradePrice'

    def cardinality(self):
        """Returns the RPC cardinality of this method."""
        return cardinality.Cardinality.STREAM_UNARY

    def request_class(self):
        """Returns the protobuf class of request messages."""
        return stock_pb2.StockRequest

    def response_class(self):
        """Returns the protobuf class of response messages."""
        return stock_pb2.StockReply

    def serialize_request(self, request):
        return request.SerializeToString()

    def deserialize_request(self, serialized_request):
        message_class = self.request_class()
        return message_class.FromString(serialized_request)

    def serialize_response(self, response):
        return response.SerializeToString()

    def deserialize_response(self, serialized_response):
        message_class = self.response_class()
        return message_class.FromString(serialized_response)

    def service(self, response_callback, context, control):
        """Returns the stream.Consumer that will receive the request stream."""
        return _get_highest_trade_price(
            response_callback, control, active=context.is_active)
class GetHighestTradePriceMessages(_service.StreamUnaryTestMessages):
    """A request stream and its expected reply for GetHighestTradePrice."""

    def requests(self):
        """Returns STREAM_LENGTH requests for symbols numbered from zero."""
        stream_requests = []
        for number in range(test_constants.STREAM_LENGTH):
            stream_requests.append(
                stock_pb2.StockRequest(symbol=_SYMBOL_FORMAT % number))
        return stream_requests

    def verify(self, requests, response, test_case):
        """Asserts the reply names the highest-priced requested symbol.

        On a price tie the earliest request wins. With no requests, the
        expected symbol and price are both None.
        """
        best = None  # (price, symbol) of the highest-priced request so far.
        for stock_request in requests:
            candidate = (_price(stock_request.symbol), stock_request.symbol)
            if best is None or best[0] < candidate[0]:
                best = candidate
        if best is None:
            expected_price, expected_symbol = None, None
        else:
            expected_price, expected_symbol = best
        test_case.assertEqual(expected_price, response.price)
        test_case.assertEqual(expected_symbol, response.symbol)
class StockTestService(_service.TestService):
    """A corpus of test data with one method of each RPC cardinality.

    Each *_scenarios method returns a fresh dict keyed by (group, method name),
    mapping to (method implementation, [messages object]).
    """

    def unary_unary_scenarios(self):
        key = (_STOCK_GROUP_NAME, 'GetLastTradePrice')
        return {key: (GetLastTradePrice(), [GetLastTradePriceMessages()])}

    def unary_stream_scenarios(self):
        key = (_STOCK_GROUP_NAME, 'WatchFutureTrades')
        return {key: (WatchFutureTrades(), [WatchFutureTradesMessages()])}

    def stream_unary_scenarios(self):
        key = (_STOCK_GROUP_NAME, 'GetHighestTradePrice')
        return {
            key: (GetHighestTradePrice(), [GetHighestTradePriceMessages()])
        }

    def stream_stream_scenarios(self):
        key = (_STOCK_GROUP_NAME, 'GetLastTradePriceMultiple')
        implementation = GetLastTradePriceMultiple()
        messages = [GetLastTradePriceMultipleMessages()]
        return {key: (implementation, messages)}


STOCK_TEST_SERVICE = StockTestService()