| """ |
| Test suite for SocketServer.py. |
| """ |
| |
| import os |
| import socket |
| import errno |
| import imp |
| import select |
| import time |
| import threading |
| from functools import wraps |
| import unittest |
| import SocketServer |
| |
| import test.test_support |
| from test.test_support import reap_children, verbose, TestSkipped |
| from test.test_support import TESTFN as TEST_FILE |
| |
| test.test_support.requires("network") |
| |
# Number of requests each server handles and each client loop sends.
NREQ = 3
# Base pause, in seconds, used by the handlers and the test fixture.
DELAY = 0.5
# Payload the clients send and expect to be echoed back unchanged.
TEST_STR = b"hello world\n"
# Host part of the AF_INET addresses the tests bind to.
HOST = "localhost"

# Platform capabilities that decide which server flavours get exercised.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
HAVE_FORKING = os.name != "os2" and hasattr(os, "fork")
| |
| |
class MyMixinHandler:
    """Request-handler mixin that echoes one line back to the client."""

    def handle(self):
        """Read a single line from rfile and write it back to wfile.

        Sleeps for DELAY seconds before the read and again before the
        write.
        """
        time.sleep(DELAY)
        request_line = self.rfile.readline()
        time.sleep(DELAY)
        self.wfile.write(request_line)
| |
| |
def receive(sock, n, timeout=20):
    """Return up to n bytes from sock, waiting at most timeout seconds.

    Raises RuntimeError if sock does not become readable in time.
    """
    readable, _, _ = select.select([sock], [], [], timeout)
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
| |
| |
class MyStreamHandler(MyMixinHandler,
                      SocketServer.StreamRequestHandler):
    """Stream request handler using MyMixinHandler's handle()."""
| |
class MyDatagramHandler(MyMixinHandler, SocketServer.DatagramRequestHandler):
    """Datagram request handler using MyMixinHandler's handle()."""
| |
# Forking variants of the Unix-domain servers; only defined when the
# socket module exposes AF_UNIX on this platform.
if HAVE_UNIX_SOCKETS:
    class ForkingUnixStreamServer(SocketServer.ForkingMixIn,
                                  SocketServer.UnixStreamServer):
        """Unix-domain stream server combined with ForkingMixIn."""

    class ForkingUnixDatagramServer(SocketServer.ForkingMixIn,
                                    SocketServer.UnixDatagramServer):
        """Unix-domain datagram server combined with ForkingMixIn."""
| |
| |
class MyMixinServer:
    """Server mixin that serves a fixed number of requests and fails loudly."""

    def serve_a_few(self):
        """Handle exactly NREQ requests, one after another."""
        for _ in range(NREQ):
            self.handle_request()

    def handle_error(self, request, client_address):
        """Close the request and the server, then re-raise the active error.

        Relies on being called while an exception is being handled, so
        the bare ``raise`` has something to propagate.
        """
        self.close_request(request)
        self.server_close()
        raise
| |
# NOTE: an identical receive() is defined earlier in this module; this
# later definition is the one left bound once the module has loaded.
def receive(sock, n, timeout=20):
    """Wait up to timeout seconds for sock to be readable, then recv(n).

    Raises RuntimeError when the wait expires without data.
    """
    ready_to_read = select.select([sock], [], [], timeout)[0]
    if sock in ready_to_read:
        return sock.recv(n)
    raise RuntimeError("timed out on %r" % (sock,))
| |
def testdgram(proto, addr):
    """Send TEST_STR to the datagram server at addr and check the echo.

    proto is the socket address family and addr the server address.
    Raises test.test_support.TestFailed if the reply differs from
    TEST_STR, and RuntimeError (from receive()) if no reply arrives.
    """
    s = socket.socket(proto, socket.SOCK_DGRAM)
    try:
        s.sendto(TEST_STR, addr)
        buf = data = receive(s, 100)
        # Keep reading until the newline that terminates the echoed line
        # shows up or the peer stops sending.
        while data and b'\n' not in buf:
            data = receive(s, 100)
            buf += data
        if buf != TEST_STR:
            raise test.test_support.TestFailed(
                "expected %r, got %r" % (TEST_STR, buf))
    finally:
        # Release the socket even when the exchange fails.
        s.close()
| |
def teststream(proto, addr):
    """Connect to the stream server at addr, send TEST_STR, check the echo.

    proto is the socket address family and addr the server address.
    Raises test.test_support.TestFailed if the reply differs from
    TEST_STR, and RuntimeError (from receive()) if no reply arrives.
    """
    s = socket.socket(proto, socket.SOCK_STREAM)
    try:
        s.connect(addr)
        s.sendall(TEST_STR)
        buf = data = receive(s, 100)
        # Keep reading until the newline that terminates the echoed line
        # shows up or the peer closes the connection.
        while data and b'\n' not in buf:
            data = receive(s, 100)
            buf += data
        if buf != TEST_STR:
            raise test.test_support.TestFailed(
                "expected %r, got %r" % (TEST_STR, buf))
    finally:
        # Release the socket even when the exchange fails.
        s.close()
| |
class ServerThread(threading.Thread):
    """Thread that builds a server, signals readiness, and serves NREQ requests."""

    def __init__(self, addr, svrcls, hdlrcls):
        """Remember the address, server class and handler class to use.

        The server itself is only created once the thread runs.
        """
        threading.Thread.__init__(self)
        # Set by run() once the server has been created and checked.
        self.ready = threading.Event()
        self.__addr = addr
        self.__svrcls = svrcls
        self.__hdlrcls = hdlrcls

    def run(self):
        """Create the server, verify its address, then serve a few requests.

        Raises RuntimeError if the server's reported address does not
        match the name of its bound socket.
        """
        # Graft the test mixin onto whichever server class was requested.
        class svrcls(MyMixinServer, self.__svrcls):
            pass

        if verbose:
            print("thread: creating server")
        server = svrcls(self.__addr, self.__hdlrcls)

        # pull the address out of the server in case it changed
        # this can happen if another process is using the port
        reported = server.server_address
        if reported:
            self.__addr = reported
        if self.__addr != server.socket.getsockname():
            raise RuntimeError('server_address was %s, expected %s' %
                               (self.__addr, server.socket.getsockname()))

        self.ready.set()
        if verbose:
            print("thread: serving three times")
        server.serve_a_few()
        if verbose:
            print("thread: done")
| |
| |
class ForgivingTCPServer(SocketServer.TCPServer):
    """TCPServer that tries fallback ports when the requested one is busy."""

    # prevent errors if another process is using the port we want
    def server_bind(self):
        """Bind to the requested port or to the first free fallback port.

        The candidates are the port from server_address followed by a
        fixed list of alternates. A warning is written to
        sys.__stderr__ for each port that is already in use.

        Raises socket.error immediately for any bind failure other than
        EADDRINUSE, and re-raises the last EADDRINUSE error when every
        candidate port is busy.
        """
        import sys
        host, default_port = self.server_address
        # this code shamelessly stolen from test.test_support
        # the ports were changed to protect the innocent
        last_error = None
        for port in [default_port, 3434, 8798, 23833]:
            try:
                self.server_address = host, port
                SocketServer.TCPServer.server_bind(self)
                break
            except socket.error as e:
                # Exception objects cannot be tuple-unpacked on Python 3,
                # so take the error code from the errno attribute.
                if e.errno != errno.EADDRINUSE:
                    raise
                # Keep a reference: the name bound by "as" goes away when
                # the except clause ends.
                last_error = e
                print(' WARNING: failed to listen on port %d, trying another' % port, file=sys.__stderr__)
        else:
            # Every candidate was busy; report it instead of carrying on
            # with a socket that was never bound.
            raise last_error
| |
| class SocketServerTest(unittest.TestCase): |
| """Test all socket servers.""" |
| |
| def setUp(self): |
| self.port_seed = 0 |
| self.test_files = [] |
| |
| def tearDown(self): |
| time.sleep(DELAY) |
| reap_children() |
| |
| for fn in self.test_files: |
| try: |
| os.remove(fn) |
| except os.error: |
| pass |
| self.test_files[:] = [] |
| |
| def pickport(self): |
| self.port_seed += 1 |
| return 10000 + (os.getpid() % 1000)*10 + self.port_seed |
| |
| def pickaddr(self, proto): |
| if proto == socket.AF_INET: |
| return (HOST, self.pickport()) |
| else: |
| fn = TEST_FILE + str(self.pickport()) |
| if os.name == 'os2': |
| # AF_UNIX socket names on OS/2 require a specific prefix |
| # which can't include a drive letter and must also use |
| # backslashes as directory separators |
| if fn[1] == ':': |
| fn = fn[2:] |
| if fn[0] in (os.sep, os.altsep): |
| fn = fn[1:] |
| fn = os.path.join('\socket', fn) |
| if os.sep == '/': |
| fn = fn.replace(os.sep, os.altsep) |
| else: |
| fn = fn.replace(os.altsep, os.sep) |
| self.test_files.append(fn) |
| return fn |
| |
| def run_servers(self, proto, servers, hdlrcls, testfunc): |
| for svrcls in servers: |
| addr = self.pickaddr(proto) |
| if verbose: |
| print("ADDR =", addr) |
| print("CLASS =", svrcls) |
| t = ServerThread(addr, svrcls, hdlrcls) |
| if verbose: print("server created") |
| t.start() |
| if verbose: print("server running") |
| for i in range(NREQ): |
| t.ready.wait(10*DELAY) |
| self.assert_(t.ready.isSet(), |
| "Server not ready within a reasonable time") |
| if verbose: print("test client", i) |
| testfunc(proto, addr) |
| if verbose: print("waiting for server") |
| t.join() |
| if verbose: print("done") |
| |
| def stream_examine(self, proto, addr): |
| s = socket.socket(proto, socket.SOCK_STREAM) |
| s.connect(addr) |
| s.sendall(TEST_STR) |
| buf = data = receive(s, 100) |
| while data and b'\n' not in buf: |
| data = receive(s, 100) |
| buf += data |
| self.assertEquals(buf, TEST_STR) |
| s.close() |
| |
| def dgram_examine(self, proto, addr): |
| s = socket.socket(proto, socket.SOCK_DGRAM) |
| s.sendto(TEST_STR, addr) |
| buf = data = receive(s, 100) |
| while data and b'\n' not in buf: |
| data = receive(s, 100) |
| buf += data |
| self.assertEquals(buf, TEST_STR) |
| s.close() |
| |
| def test_TCPServers(self): |
| # Test SocketServer.TCPServer |
| servers = [ForgivingTCPServer, SocketServer.ThreadingTCPServer] |
| if HAVE_FORKING: |
| servers.append(SocketServer.ForkingTCPServer) |
| self.run_servers(socket.AF_INET, servers, |
| MyStreamHandler, self.stream_examine) |
| |
| def test_UDPServers(self): |
| # Test SocketServer.UDPServer |
| servers = [SocketServer.UDPServer, |
| SocketServer.ThreadingUDPServer] |
| if HAVE_FORKING: |
| servers.append(SocketServer.ForkingUDPServer) |
| self.run_servers(socket.AF_INET, servers, MyDatagramHandler, |
| self.dgram_examine) |
| |
| def test_stream_servers(self): |
| # Test SocketServer's stream servers |
| if not HAVE_UNIX_SOCKETS: |
| return |
| servers = [SocketServer.UnixStreamServer, |
| SocketServer.ThreadingUnixStreamServer] |
| if HAVE_FORKING: |
| servers.append(ForkingUnixStreamServer) |
| self.run_servers(socket.AF_UNIX, servers, MyStreamHandler, |
| self.stream_examine) |
| |
| # Alas, on Linux (at least) recvfrom() doesn't return a meaningful |
| # client address so this cannot work: |
| |
| # def test_dgram_servers(self): |
| # # Test SocketServer.UnixDatagramServer |
| # if not HAVE_UNIX_SOCKETS: |
| # return |
| # servers = [SocketServer.UnixDatagramServer, |
| # SocketServer.ThreadingUnixDatagramServer] |
| # if HAVE_FORKING: |
| # servers.append(ForkingUnixDatagramServer) |
| # self.run_servers(socket.AF_UNIX, servers, MyDatagramHandler, |
| # self.dgram_examine) |
| |
| |
def test_main():
    """Run SocketServerTest, unless the import lock is currently held.

    Raises TestSkipped when imp.lock_held() reports the lock as held.
    """
    # If the import lock is held, the threads will hang
    import_lock_held = imp.lock_held()
    if import_lock_held:
        raise TestSkipped("can't run when import lock is held")

    test.test_support.run_unittest(SocketServerTest)


if __name__ == "__main__":
    test_main()