| import asyncio |
| import unittest |
| |
| from test.test_asyncio import functional as func_tests |
| |
| |
| def tearDownModule(): |
| asyncio.set_event_loop_policy(None) |
| |
| |
| class ReceiveStuffProto(asyncio.BufferedProtocol): |
| def __init__(self, cb, con_lost_fut): |
| self.cb = cb |
| self.con_lost_fut = con_lost_fut |
| |
| def get_buffer(self, sizehint): |
| self.buffer = bytearray(100) |
| return self.buffer |
| |
| def buffer_updated(self, nbytes): |
| self.cb(self.buffer[:nbytes]) |
| |
| def connection_lost(self, exc): |
| if exc is None: |
| self.con_lost_fut.set_result(None) |
| else: |
| self.con_lost_fut.set_exception(exc) |
| |
| |
| class BaseTestBufferedProtocol(func_tests.FunctionalTestCaseMixin): |
| |
| def new_loop(self): |
| raise NotImplementedError |
| |
| def test_buffered_proto_create_connection(self): |
| |
| NOISE = b'12345678+' * 1024 |
| |
| async def client(addr): |
| data = b'' |
| |
| def on_buf(buf): |
| nonlocal data |
| data += buf |
| if data == NOISE: |
| tr.write(b'1') |
| |
| conn_lost_fut = self.loop.create_future() |
| |
| tr, pr = await self.loop.create_connection( |
| lambda: ReceiveStuffProto(on_buf, conn_lost_fut), *addr) |
| |
| await conn_lost_fut |
| |
| async def on_server_client(reader, writer): |
| writer.write(NOISE) |
| await reader.readexactly(1) |
| writer.close() |
| await writer.wait_closed() |
| |
| srv = self.loop.run_until_complete( |
| asyncio.start_server( |
| on_server_client, '127.0.0.1', 0)) |
| |
| addr = srv.sockets[0].getsockname() |
| self.loop.run_until_complete( |
| asyncio.wait_for(client(addr), 5)) |
| |
| srv.close() |
| self.loop.run_until_complete(srv.wait_closed()) |
| |
| |
| class BufferedProtocolSelectorTests(BaseTestBufferedProtocol, |
| unittest.TestCase): |
| |
| def new_loop(self): |
| return asyncio.SelectorEventLoop() |
| |
| |
| @unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only') |
| class BufferedProtocolProactorTests(BaseTestBufferedProtocol, |
| unittest.TestCase): |
| |
| def new_loop(self): |
| return asyncio.ProactorEventLoop() |
| |
| |
| if __name__ == '__main__': |
| unittest.main() |