blob: d0736b12e8ad5345a1be0915dd7b80e5c2bf2eba [file] [log] [blame]
import os
import posixpath
import ssl
import sys
import threading
import urllib.parse
# Rename HTTPServer to _HTTPServer so as to avoid confusion with HTTPSServer.
from http.server import HTTPServer as _HTTPServer, SimpleHTTPRequestHandler

from test import support
10
# Directory holding this module; data files used by the servers below
# (and the default document root of RootedHTTPRequestHandler) live here.
here = os.path.split(__file__)[0]

# Default address the test servers bind to.
HOST = support.HOST
# PEM file expected to hold both the server certificate and its private key.
CERTFILE = os.path.join(here, 'keycert.pem')
15
# HTTPSServer builds on http.server.HTTPServer (itself a socketserver server);
# the only difference is that every accepted connection is wrapped in TLS.

class HTTPSServer(_HTTPServer):
    """HTTP server that wraps each accepted connection with an SSL context.

    *context* is the SSL context used for the server side of every
    connection handed to *handler_class*.
    """

    def __init__(self, server_address, handler_class, context):
        _HTTPServer.__init__(self, server_address, handler_class)
        # Stored after the base class has bound the listening socket; it is
        # only consulted later, from get_request().
        self.context = context

    def __str__(self):
        cls_name = self.__class__.__name__
        return '<%s %s:%s>' % (cls_name, self.server_name, self.server_port)

    def get_request(self):
        """Accept one connection and return (TLS-wrapped socket, address)."""
        raw_sock, client_addr = self.socket.accept()
        wrapped = self.context.wrap_socket(raw_sock, server_side=True)
        return wrapped, client_addr
35
class RootedHTTPRequestHandler(SimpleHTTPRequestHandler):
    """Static-file handler that serves files from a fixed directory.

    translate_path() is overridden to resolve request paths against the
    class attribute ``root`` rather than os.curdir, since the test could be
    run from anywhere.
    """

    server_version = "TestHTTPS/1.0"
    # Document root for served files (this module's directory by default).
    root = here
    # Avoid hanging when a request gets interrupted by the client
    timeout = 5

    def translate_path(self, path):
        """Translate a /-separated URL PATH to a filesystem path under root.

        The query string is dropped and percent-escapes are decoded.
        Components that mean special things to the local file system
        (drive letters, embedded directory separators, '.' and '..') are
        ignored, so '..' components cannot walk above ``root``.
        (XXX They should probably be diagnosed.)
        """
        # Abandon query parameters; keep only the path component.
        path = urllib.parse.urlparse(path)[2]
        # URL paths are always '/'-separated, so normalize with posixpath:
        # os.path.normpath would produce '\\' separators on Windows, and the
        # split on '/' below would then see a single component.
        path = posixpath.normpath(urllib.parse.unquote(path))
        result = self.root
        for word in filter(None, path.split('/')):
            _drive, word = os.path.splitdrive(word)
            _head, word = os.path.split(word)
            # normpath keeps leading '..' for paths without a leading '/'
            # (e.g. a request line of "../x"); never let them climb out of
            # root, and '.' adds nothing.
            if word in (os.curdir, os.pardir):
                continue
            result = os.path.join(result, word)
        return result

    def log_message(self, format, *args):
        # we override this to suppress logging unless "verbose"
        if support.verbose:
            # server_address is a (host, port) tuple, so format the host name
            # (as HTTPSServer.__str__ does) rather than the whole tuple.
            sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" %
                             (self.server.server_name,
                              self.server.server_port,
                              self.request.cipher(),
                              self.log_date_time_string(),
                              format % args))
75
class HTTPSServerThread(threading.Thread):
    """Daemon thread running an HTTPSServer on an OS-chosen port.

    The bound port is available as ``port`` as soon as the object is
    constructed, before the thread is started.
    """

    def __init__(self, context, host=HOST, handler_class=None):
        # Event to set once run() begins; supplied later through start().
        self.flag = None
        # Port 0 asks the OS for a free port; it is read back just below.
        self.server = HTTPSServer((host, 0),
                                  handler_class or RootedHTTPRequestHandler,
                                  context)
        self.port = self.server.server_port
        threading.Thread.__init__(self)
        self.daemon = True

    def __str__(self):
        return "<%s %s>" % (self.__class__.__name__, self.server)

    def start(self, flag=None):
        """Start the thread; *flag* (e.g. a threading.Event) is set in run()."""
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        if self.flag:
            self.flag.set()
        try:
            # Poll for a shutdown request every 0.05 seconds.
            self.server.serve_forever(0.05)
        finally:
            # shutdown() only stops the serve loop; also release the
            # listening socket so it does not leak once the server stops.
            self.server.server_close()

    def stop(self):
        """Ask serve_forever() to exit (socketserver's shutdown() waits for it)."""
        self.server.shutdown()
101
102
def make_https_server(case, certfile=CERTFILE, host=HOST, handler_class=None,
                      context=None):
    """Start a background HTTPS server that is torn down with *case*.

    *case* must provide addCleanup() (presumably a unittest.TestCase).
    *certfile* is loaded into the SSL context; it is assumed to contain both
    the private key and the certificate. *host* and *handler_class* are
    forwarded to HTTPSServerThread. *context* is an optional ssl.SSLContext.
    When omitted, a TLSv1 context is created as before. Pass one explicitly
    where TLS 1.0 is unavailable or another protocol is wanted. The
    certificate chain is loaded into whichever context is used.

    Returns the running HTTPSServerThread; its ``port`` attribute holds the
    bound port.
    """
    if context is None:
        # Historical default, kept for backward compatibility.
        context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
    # we assume the certfile contains both private key and certificate
    context.load_cert_chain(certfile)
    server = HTTPSServerThread(context, host, handler_class)
    flag = threading.Event()
    server.start(flag)
    # Block until the server thread has entered run().
    flag.wait()

    def cleanup():
        if support.verbose:
            sys.stdout.write('stopping HTTPS server\n')
        server.stop()
        if support.verbose:
            sys.stdout.write('joining HTTPS thread\n')
        server.join()

    case.addCleanup(cleanup)
    return server