blob: a312e28573ea1334a9e384956a688bce544e3543 [file] [log] [blame]
Benjamin Petersondaeb9252014-08-20 14:14:50 -05001import os
2import sys
3import ssl
4import pprint
5import urllib
6import urlparse
7# Rename HTTPServer to _HTTPServer so as to avoid confusion with HTTPSServer.
8from BaseHTTPServer import HTTPServer as _HTTPServer, BaseHTTPRequestHandler
9from SimpleHTTPServer import SimpleHTTPRequestHandler
10
11from test import test_support as support
# threading is obtained through the test-support import helper rather than a
# plain import statement.
threading = support.import_module("threading")

# Default host for the test servers defined in this module.
HOST = support.HOST

# Directory containing this module; the test certificate sits next to it.
here = os.path.dirname(__file__)
CERTFILE = os.path.join(here, 'keycert.pem')
18
19# This one's based on HTTPServer, which is based on SocketServer
20
class HTTPSServer(_HTTPServer):
    """HTTP server that wraps every accepted connection with SSL.

    The *context* given to the constructor is stored on the instance and
    used by get_request() to wrap accepted sockets in server mode.
    """

    def __init__(self, server_address, handler_class, context):
        _HTTPServer.__init__(self, server_address, handler_class)
        self.context = context

    def __str__(self):
        return ('<%s %s:%s>' %
                (self.__class__.__name__,
                 self.server_name,
                 self.server_port))

    def get_request(self):
        """Accept a connection and return ``(ssl_socket, client_address)``.

        Overridden to wrap the accepted socket with SSL.  Any error raised
        while accepting or wrapping is re-raised; in verbose mode it is
        written to stderr first.
        """
        try:
            sock, addr = self.socket.accept()
            sslconn = self.context.wrap_socket(sock, server_side=True)
        except EnvironmentError as e:
            # On Python 2, socket.error (and therefore ssl.SSLError) derives
            # from IOError, not OSError, so catching OSError alone never
            # matched the errors this handler is meant to report.
            # EnvironmentError is the common base class of both, and is an
            # alias of OSError on Python 3.
            #
            # socket errors are silenced by the caller, print them here
            if support.verbose:
                sys.stderr.write("Got an error:\n%s\n" % e)
            raise
        return sslconn, addr
44
class RootedHTTPRequestHandler(SimpleHTTPRequestHandler):
    """Request handler serving files below a fixed root directory.

    translate_path() is overridden to get a known root, instead of using
    os.curdir, since the test could be run from anywhere.
    """

    server_version = "TestHTTPS/1.0"
    # Directory that request paths are resolved against.
    root = here
    # Avoid hanging when a request gets interrupted by the client
    timeout = 5

    def translate_path(self, path):
        """Translate a /-separated PATH to the local filename syntax.

        Components that mean special things to the local file system
        (drive names, directory prefixes, "." and "..") are ignored, so a
        request cannot step out of ``self.root`` with ".." components.

        """
        # abandon query parameters
        path = urlparse.urlparse(path)[2]
        path = os.path.normpath(urllib.unquote(path))
        translated = self.root
        for word in path.split('/'):
            if not word:
                # empty component (leading, trailing or doubled slash)
                continue
            word = os.path.splitdrive(word)[1]
            word = os.path.split(word)[1]
            if word in (os.curdir, os.pardir):
                # normpath() keeps leading ".." components of a path that
                # does not start with "/"; joining them would leave root.
                continue
            translated = os.path.join(translated, word)
        return translated

    def log_message(self, format, *args):
        # we override this to suppress logging unless "verbose"
        if support.verbose:
            sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" %
                             (self.server.server_address,
                              self.server.server_port,
                              self.request.cipher(),
                              self.log_date_time_string(),
                              format%args))
84
85
class StatsRequestHandler(BaseHTTPRequestHandler):
    """Example HTTP request handler which returns SSL statistics on GET
    requests.
    """

    server_version = "StatsHTTPS/1.0"

    def do_GET(self, send_body=True):
        """Serve a GET request.

        The response is a pretty-printed dict holding the SSL context's
        session cache statistics plus the connection's cipher and
        compression.  When *send_body* is false only the status line and
        headers are sent.
        """
        # self.request is the connection handed over by the server, i.e. the
        # SSL socket returned by HTTPSServer.get_request().  Reaching for
        # self.rfile.raw._sock instead fails on Python 2, where the file
        # object returned by makefile() has no "raw" attribute.
        sock = self.request
        context = sock.context
        stats = {
            'session_cache': context.session_stats(),
            'cipher': sock.cipher(),
            'compression': sock.compression(),
        }
        body = pprint.pformat(stats)
        body = body.encode('utf-8')
        self.send_response(200)
        self.send_header("Content-type", "text/plain; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        if send_body:
            self.wfile.write(body)

    def do_HEAD(self):
        """Serve a HEAD request."""
        self.do_GET(send_body=False)

    def log_request(self, format, *args):
        # Log requests only in verbose mode; the arguments are passed to
        # the base implementation unchanged.
        if support.verbose:
            BaseHTTPRequestHandler.log_request(self, format, *args)
118
119
class HTTPSServerThread(threading.Thread):
    """Daemon thread that runs an HTTPSServer until stop() is called.

    The server is created in the constructor with port 0 and the port it
    actually uses is exposed as ``port``.
    """

    def __init__(self, context, host=HOST, handler_class=None):
        if not handler_class:
            handler_class = RootedHTTPRequestHandler
        self.flag = None
        self.server = HTTPSServer((host, 0), handler_class, context)
        self.port = self.server.server_port
        threading.Thread.__init__(self)
        self.daemon = True

    def __str__(self):
        return "<%s %s>" % (type(self).__name__, self.server)

    def start(self, flag=None):
        """Start the thread; *flag*, if given, is set once run() begins."""
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        """Signal the start flag, then serve until shutdown is requested."""
        started = self.flag
        if started:
            started.set()
        try:
            self.server.serve_forever(0.05)
        finally:
            # Always release the listening socket when the loop ends.
            self.server.server_close()

    def stop(self):
        """Ask the serving loop to terminate."""
        self.server.shutdown()
148
149
def make_https_server(case, context=None, certfile=CERTFILE,
                      host=HOST, handler_class=None):
    """Start an HTTPS server thread for test *case* and return it.

    A default server-side SSL context is created when *context* is None.
    *certfile* is loaded into the context in either case.  A cleanup that
    stops and joins the thread is registered on *case*.
    """
    ssl_context = context
    if ssl_context is None:
        ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    # We assume the certfile contains both private key and certificate
    ssl_context.load_cert_chain(certfile)

    started = threading.Event()
    server = HTTPSServerThread(ssl_context, host, handler_class)
    server.start(started)
    # Block until the server thread has begun running.
    started.wait()

    def cleanup():
        steps = (
            ('stopping HTTPS server\n', server.stop),
            ('joining HTTPS thread\n', server.join),
        )
        for notice, step in steps:
            if support.verbose:
                sys.stdout.write(notice)
            step()

    case.addCleanup(cleanup)
    return server
169
170
if __name__ == "__main__":
    # Command-line entry point: run a standalone HTTPS server using the
    # test certificate, serving either files from the current directory or
    # the SSL statistics page.
    import argparse
    parser = argparse.ArgumentParser(
        description='Run a test HTTPS server. '
                    'By default, the current directory is served.')
    parser.add_argument('-p', '--port', type=int, default=4433,
                        help='port to listen on (default: %(default)s)')
    parser.add_argument('-q', '--quiet', dest='verbose', default=True,
                        action='store_false', help='be less verbose')
    parser.add_argument('-s', '--stats', dest='use_stats_handler', default=False,
                        action='store_true', help='always return stats page')
    parser.add_argument('--curve-name', dest='curve_name', type=str,
                        action='store',
                        help='curve name for EC-based Diffie-Hellman')
    parser.add_argument('--ciphers', dest='ciphers', type=str,
                        help='allowed cipher list')
    parser.add_argument('--dh', dest='dh_file', type=str, action='store',
                        help='PEM file containing DH parameters')
    args = parser.parse_args()

    # The handlers consult support.verbose to decide whether to log.
    support.verbose = args.verbose
    if args.use_stats_handler:
        handler_class = StatsRequestHandler
    else:
        handler_class = RootedHTTPRequestHandler
        # Serve the directory the script was started from.
        handler_class.root = os.getcwd()
    context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    context.load_cert_chain(CERTFILE)
    if args.curve_name:
        context.set_ecdh_curve(args.curve_name)
    if args.dh_file:
        context.load_dh_params(args.dh_file)
    if args.ciphers:
        context.set_ciphers(args.ciphers)

    server = HTTPSServer(("", args.port), handler_class, context)
    if args.verbose:
        print("Listening on https://localhost:{0.port}".format(args))
    try:
        server.serve_forever(0.1)
    finally:
        # Release the listening socket even when the loop is interrupted
        # (e.g. by Ctrl-C), as HTTPSServerThread.run() does.
        server.server_close()
209 server.serve_forever(0.1)