blob: 914a943da56a0edbbc300160384ce2fdebb73787 [file] [log] [blame]
Antoine Pitrou803e6d62010-10-13 10:36:15 +00001import os
2import sys
3import ssl
Antoine Pitrouf26f87e2010-10-13 11:27:09 +00004import pprint
Antoine Pitrou84fa4312010-10-13 11:51:05 +00005import socket
Antoine Pitrou803e6d62010-10-13 10:36:15 +00006import threading
7import urllib.parse
8# Rename HTTPServer to _HTTPServer so as to avoid confusion with HTTPSServer.
Antoine Pitrouf26f87e2010-10-13 11:27:09 +00009from http.server import (HTTPServer as _HTTPServer,
10 SimpleHTTPRequestHandler, BaseHTTPRequestHandler)
Antoine Pitrou803e6d62010-10-13 10:36:15 +000011
12from test import support
13
# Directory containing this module; CERTFILE is looked up here, and it is
# also the default document root of RootedHTTPRequestHandler.
here = os.path.split(__file__)[0]

# Host the test servers bind to by default.
HOST = support.HOST
# PEM file handed to load_cert_chain() (private key plus certificate).
CERTFILE = os.path.join(here, 'keycert.pem')
18
# This one's based on HTTPServer, which is based on SocketServer

class HTTPSServer(_HTTPServer):
    """An HTTPServer whose accepted connections are TLS-wrapped.

    Every socket returned by ``accept()`` is wrapped, server side, with the
    SSL context given to the constructor (stored as ``self.context``).
    """

    def __init__(self, server_address, handler_class, context):
        _HTTPServer.__init__(self, server_address, handler_class)
        self.context = context

    def __str__(self):
        name = type(self).__name__
        return '<%s %s:%s>' % (name, self.server_name, self.server_port)

    def get_request(self):
        """Accept one client and return ``(ssl_socket, client_address)``.

        Socket errors raised while accepting or wrapping are echoed to
        stderr in verbose mode (the caller silences them) and then
        re-raised unchanged.
        """
        try:
            raw_sock, client_addr = self.socket.accept()
            tls_sock = self.context.wrap_socket(raw_sock, server_side=True)
        except socket.error as err:
            if support.verbose:
                sys.stderr.write("Got an error:\n%s\n" % err)
            raise
        return tls_sock, client_addr
44
class RootedHTTPRequestHandler(SimpleHTTPRequestHandler):
    """SimpleHTTPRequestHandler that serves files below a fixed ``root``.

    translate_path is overridden to get a known root instead of using
    os.curdir, since the test could be run from anywhere.
    """

    server_version = "TestHTTPS/1.0"
    root = here
    # Avoid hanging when a request gets interrupted by the client
    timeout = 5

    def translate_path(self, path):
        """Translate a /-separated URL PATH to a filename under ``self.root``.

        Query parameters are dropped and percent-escapes decoded.  The path
        is normalized with URL (POSIX) rules, so the result does not depend
        on the host's path separator.  Components that mean special things
        to the local file system (drive letters, embedded separators,
        ``.`` and ``..``) are ignored, so the result never leaves ``root``.
        """
        import posixpath
        # abandon query parameters
        path = urllib.parse.urlparse(path)[2]
        # URLs always use '/': normalize with posixpath.  os.path.normpath
        # would turn '/' into '\\' on Windows, and the split below would then
        # see a single component.
        path = posixpath.normpath(urllib.parse.unquote(path))
        result = self.root
        for word in filter(None, path.split('/')):
            _, word = os.path.splitdrive(word)
            _, word = os.path.split(word)
            if word in (os.curdir, os.pardir):
                # normpath keeps leading '..' on relative paths; never let
                # such a component climb out of root.
                continue
            result = os.path.join(result, word)
        return result

    def log_message(self, format, *args):
        """Write a log line to stdout, but only when running verbose."""
        # we override this to suppress logging unless "verbose"
        if support.verbose:
            # server_name (not the server_address tuple) is the host part,
            # matching how HTTPSServer.__str__ renders "host:port".
            sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" %
                             (self.server.server_name,
                              self.server.server_port,
                              self.request.cipher(),
                              self.log_date_time_string(),
                              format%args))
84
Antoine Pitrouf26f87e2010-10-13 11:27:09 +000085
class StatsRequestHandler(BaseHTTPRequestHandler):
    """Example HTTP request handler which returns SSL statistics on GET
    requests.
    """

    server_version = "StatsHTTPS/1.0"

    def do_GET(self, send_body=True):
        """Serve a GET request.

        Responds 200 with a text/plain UTF-8 body holding the pretty-printed
        ``session_stats()`` of the SSL context of the current connection.
        With ``send_body=False`` only the status line and headers are sent;
        do_HEAD relies on this.
        """
        # self.connection is the socket the server handed to this handler
        # (an SSL socket when served by HTTPSServer).  Use it instead of
        # digging it out of the private ``rfile.raw._sock`` attribute.
        context = self.connection.context
        body = pprint.pformat(context.session_stats()).encode('utf-8')
        self.send_response(200)
        self.send_header("Content-type", "text/plain; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        if send_body:
            self.wfile.write(body)

    def do_HEAD(self):
        """Serve a HEAD request: the GET headers without a body."""
        self.do_GET(send_body=False)

    def log_request(self, format, *args):
        """Delegate to the base-class request logging only when verbose."""
        # NOTE(review): BaseHTTPRequestHandler.log_request is declared as
        # (code='-', size='-'), so despite its name ``format`` receives the
        # status code.  Arguments are forwarded positionally, unchanged.
        if support.verbose:
            BaseHTTPRequestHandler.log_request(self, format, *args)
113
114
class HTTPSServerThread(threading.Thread):
    """Daemon thread running an HTTPSServer bound to (host, 0).

    Port 0 lets the OS pick a free port; the chosen one is exposed as
    ``self.port``.
    """

    def __init__(self, context, host=HOST, handler_class=None):
        self.flag = None
        self.server = HTTPSServer((host, 0),
                                  handler_class or RootedHTTPRequestHandler,
                                  context)
        self.port = self.server.server_port
        threading.Thread.__init__(self)
        self.daemon = True

    def __str__(self):
        return "<%s %s>" % (self.__class__.__name__, self.server)

    def start(self, flag=None):
        """Start the thread; *flag*, if given, is set once it is running."""
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        if self.flag:
            self.flag.set()
        try:
            self.server.serve_forever(0.05)
        finally:
            # shutdown() only stops the loop; close the listening socket
            # too, otherwise it stays open until garbage collection.
            self.server.server_close()

    def stop(self):
        """Ask serve_forever() to return; run() then closes the socket."""
        self.server.shutdown()
140
141
def make_https_server(case, certfile=CERTFILE, host=HOST, handler_class=None,
                      context=None):
    """Start an HTTPS server thread for test *case* and return it.

    If *context* is None, a TLSv1 server context is built and loaded with
    *certfile*.  Otherwise the given SSL context is used as-is and *certfile*
    is ignored.  The call blocks until the server thread has started.  A
    cleanup that stops and joins the thread is registered through
    ``case.addCleanup``.  The returned HTTPSServerThread exposes the bound
    port as ``.port``.
    """
    if context is None:
        # we assume the certfile contains both private key and certificate
        context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        context.load_cert_chain(certfile)
    server = HTTPSServerThread(context, host, handler_class)
    flag = threading.Event()
    server.start(flag)
    # run() sets the flag just before entering the serve loop
    flag.wait()

    def cleanup():
        if support.verbose:
            sys.stdout.write('stopping HTTPS server\n')
        server.stop()
        if support.verbose:
            sys.stdout.write('joining HTTPS thread\n')
        server.join()

    case.addCleanup(cleanup)
    return server
Antoine Pitrouf26f87e2010-10-13 11:27:09 +0000159
160
if __name__ == "__main__":
    # Stand-alone mode: run a test HTTPS server in the foreground.
    import argparse
    parser = argparse.ArgumentParser(
        description='Run a test HTTPS server. '
                    'By default, the current directory is served.')
    parser.add_argument('-p', '--port', type=int, default=4433,
                        help='port to listen on (default: %(default)s)')
    parser.add_argument('-q', '--quiet', dest='verbose', default=True,
                        action='store_false', help='be less verbose')
    parser.add_argument('-s', '--stats', dest='use_stats_handler', default=False,
                        action='store_true', help='always return stats page')
    args = parser.parse_args()

    support.verbose = args.verbose
    if args.use_stats_handler:
        handler_class = StatsRequestHandler
    else:
        handler_class = RootedHTTPRequestHandler
        # Only the file-serving handler has a document root: serve the
        # current directory rather than this module's directory.
        handler_class.root = os.getcwd()
    # CERTFILE holds both the private key and the certificate.
    context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
    context.load_cert_chain(CERTFILE)

    server = HTTPSServer(("", args.port), handler_class, context)
    try:
        server.serve_forever(0.1)
    finally:
        # Release the listening socket even when interrupted (e.g. Ctrl-C).
        server.server_close()
184 server.serve_forever(0.1)