blob: 7a7de637c38823c086f1ad4a0c50de5752ff64bf [file] [log] [blame]
Steve Dower9ddc4162019-05-29 08:20:35 -07001"""This script contains the actual auditing tests.
2
3It should not be imported directly, but should be run by the test_audit
4module with arguments identifying each test.
5
6"""
7
8import contextlib
9import sys
10
11
class TestHook:
    """Used in standard hook tests to collect any logged events.

    Should be used in a with block to ensure that it has no impact
    after the test completes. sys.addaudithook() offers no way to remove
    a hook, so leaving the with block only marks the hook as closed;
    __call__ then ignores every further event.

    raise_on_events -- a single event name, or a collection of event
        names, for which the hook raises ``exc_type`` after recording
        the event.
    exc_type -- the exception class raised for those events.
    """

    def __init__(self, raise_on_events=None, exc_type=RuntimeError):
        events = raise_on_events or ()
        # A lone string names exactly one event. Left as a string,
        # ``event in events`` would be a substring test, so e.g. an event
        # called "sys" would match "sys.addaudithook".
        if isinstance(events, str):
            events = (events,)
        self.raise_on_events = events
        self.exc_type = exc_type
        # (event, args) pairs, in the order they were received
        self.seen = []
        self.closed = False

    def __enter__(self, *a):
        sys.addaudithook(self)
        return self

    def __exit__(self, *a):
        self.close()

    def close(self):
        """Stop recording/raising; the hook itself remains installed."""
        self.closed = True

    @property
    def seen_events(self):
        """Names of the recorded events, in order."""
        return [i[0] for i in self.seen]

    def __call__(self, event, args):
        # Once closed, behave as a no-op hook
        if self.closed:
            return
        self.seen.append((event, args))
        if event in self.raise_on_events:
            raise self.exc_type("saw event " + event)
45
46
# Simple helpers, since we are not in unittest here
def assertEqual(x, y):
    """Raise AssertionError when *x* and *y* compare unequal."""
    mismatch = x != y
    if mismatch:
        raise AssertionError(f"{x!r} should equal {y!r}")
51
52
def assertIn(el, series):
    """Raise AssertionError unless *el* is a member of *series*."""
    found = el in series
    if not found:
        raise AssertionError(f"{el!r} should be in {series!r}")
56
57
def assertNotIn(el, series):
    """Raise AssertionError if *el* is a member of *series*."""
    present = el in series
    if present:
        raise AssertionError(f"{el!r} should not be in {series!r}")
61
62
def assertSequenceEqual(x, y):
    """Raise AssertionError unless *x* and *y* have the same length and
    pairwise-equal items (the sequence types themselves may differ)."""
    if len(x) == len(y):
        for left, right in zip(x, y):
            if left != right:
                break
        else:
            # Every pair matched
            return
    raise AssertionError(f"{x!r} should equal {y!r}")
68
69
@contextlib.contextmanager
def assertRaises(ex_type):
    """Context manager that fails unless its body raises exactly *ex_type*.

    The raised type must match exactly; subclasses do not count.
    AssertionErrors coming from the body are re-raised unchanged.

    Failures are raised explicitly instead of via ``assert`` statements.
    ``python -O`` strips those, which would let a missing or wrong
    exception pass silently.
    """
    try:
        yield
    except BaseException as ex:
        # Let assertion failures from the body propagate untouched
        if isinstance(ex, AssertionError):
            raise
        if type(ex) is not ex_type:
            raise AssertionError(f"{ex} should be {ex_type}") from ex
    else:
        raise AssertionError(f"expected {ex_type}")
79
80
def test_basic():
    """An explicit sys.audit() call is delivered to an installed hook."""
    with TestHook() as hook:
        sys.audit("test_event", 1, 2, 3)
        event, args = hook.seen[0]
        assertEqual(event, "test_event")
        assertEqual(args, (1, 2, 3))
86
87
def test_block_add_hook():
    # A hook that raises an ordinary exception on "sys.addaudithook" stops
    # the next hook from being installed. The exception itself is not
    # propagated, so the inner with block still runs.
    with TestHook(raise_on_events="sys.addaudithook") as blocker:
        with TestHook() as blocked:
            sys.audit("test_event")
            assertIn("test_event", blocker.seen_events)
            assertNotIn("test_event", blocked.seen_events)
96
97
def test_block_add_hook_baseexception():
    # When the blocking hook raises a BaseException (not an Exception
    # subclass), the failure propagates out of the attempt to add a hook.
    with assertRaises(BaseException):
        with TestHook(raise_on_events="sys.addaudithook", exc_type=BaseException):
            # Installing this second hook is expected to raise BaseException
            with TestHook():
                pass
107
108
def test_pickle():
    import pickle

    class PicklePrint:
        # Reduces to the call str("Pwned!"), so unpickling must resolve
        # the global ``str``.
        def __reduce_ex__(self, protocol):
            return str, ("Pwned!",)

    malicious = pickle.dumps(PicklePrint())
    globals_free = pickle.dumps(("a", "b", "c", 1, 2, 3))

    # Sanity check with no hook installed: the payload really executes
    assertEqual("Pwned!", pickle.loads(malicious))

    with TestHook(raise_on_events="pickle.find_class"):
        # The hook vetoes global lookups, so this load must fail
        with assertRaises(RuntimeError):
            pickle.loads(malicious)
        # A pickle that references no globals still loads fine
        pickle.loads(globals_free)
128
129
def test_monkeypatch():
    class A:
        pass

    class B:
        pass

    class C(A):
        pass

    instance = A()

    with TestHook() as hook:
        # Catch name changes
        C.__name__ = "X"
        # Catch type changes
        C.__bases__ = (B,)
        # Ensure bypassing __setattr__ is still caught
        type.__dict__["__bases__"].__set__(C, (B,))
        # Attribute replacement and addition on the class; these are not
        # among the object.__setattr__ events expected below
        C.__init__ = B.__init__
        C.new_attr = 123
        # Catch class changes
        instance.__class__ = B

    setattr_calls = [
        (args[0], args[1])
        for event, args in hook.seen
        if event == "object.__setattr__"
    ]
    expected = [
        (C, "__name__"),
        (C, "__bases__"),
        (C, "__bases__"),
        (instance, "__class__"),
    ]
    assertSequenceEqual(expected, setattr_calls)
160
161
def test_open():
    # SSLContext.load_dh_params uses _Py_fopen_obj rather than normal open()
    try:
        import ssl

        load_dh_params = ssl.create_default_context().load_dh_params
    except ImportError:
        load_dh_params = None

    # Try a range of "open" functions. The hook vetoes every "open" event,
    # so each call must fail with the hook's RuntimeError.
    with TestHook(raise_on_events={"open"}) as hook:
        attempts = [
            (open, sys.argv[2], "r"),
            (open, sys.executable, "rb"),
            (open, 3, "wb"),
            (open, sys.argv[2], "w", -1, None, None, None, False, lambda *a: 1),
            (load_dh_params, sys.argv[2]),
        ]
        for func, *call_args in attempts:
            # load_dh_params is None when ssl is unavailable
            if func:
                with assertRaises(RuntimeError):
                    func(*call_args)

    open_args = [args for event, args in hook.seen if event == "open"]
    actual_mode = [(args[0], args[1]) for args in open_args if args[1]]
    actual_flag = [(args[0], args[2]) for args in open_args if not args[1]]

    expected_mode = [
        (sys.argv[2], "r"),
        (sys.executable, "r"),
        (3, "w"),
        (sys.argv[2], "w"),
    ]
    if load_dh_params:
        expected_mode.append((sys.argv[2], "rb"))

    assertSequenceEqual(expected_mode, actual_mode)
    # Every "open" event here is expected to carry a non-empty mode
    assertSequenceEqual([], actual_flag)
203
204
def test_cantrace():
    traced = []

    def trace(frame, event, *args):
        # Only record events from frames executing TestHook.__call__
        if frame.f_code == TestHook.__call__.__code__:
            traced.append(event)

    # sys.settrace() returns None, so its result cannot be used to restore
    # the previous trace function; fetch that with sys.gettrace() instead.
    old = sys.gettrace()
    sys.settrace(trace)
    try:
        with TestHook() as hook:
            # No traced call
            eval("1")

            # No traced call
            hook.__cantrace__ = False
            eval("2")

            # One traced call
            hook.__cantrace__ = True
            eval("3")

            # Two traced calls (writing to private member, eval)
            hook.__cantrace__ = 1
            eval("4")

            # One traced call (writing to private member)
            hook.__cantrace__ = 0
    finally:
        # Reinstate whatever tracer (debugger, coverage, ...) was active
        sys.settrace(old)

    assertSequenceEqual(["call"] * 4, traced)
236
237
def test_mmap():
    import mmap

    with TestHook() as hook:
        # Using the mapping as a context manager unmaps it deterministically
        # instead of leaving that to garbage collection.
        with mmap.mmap(-1, 8):
            pass
        # The first recorded event comes from creating the mapping; its
        # leading args must echo the (-1, 8) passed above.
        assertEqual(hook.seen[0][1][:2], (-1, 8))
244
245
def test_excepthook():
    def excepthook(exc_type, exc_value, exc_tb):
        # The RuntimeError raised below is expected and silenced; any
        # other exception goes to the default hook.
        if exc_type is RuntimeError:
            return
        sys.__excepthook__(exc_type, exc_value, exc_tb)

    def hook(event, args):
        if event != "sys.excepthook":
            return
        # Checked below: args[0] is the installed excepthook, args[2] an
        # instance of the type in args[1].
        if not isinstance(args[2], args[1]):
            raise TypeError(f"Expected isinstance({args[2]!r}, {args[1]!r})")
        if args[0] != excepthook:
            raise ValueError(f"Expected {args[0]} == {excepthook}")
        print(event, repr(args[2]))

    sys.addaudithook(hook)
    sys.excepthook = excepthook
    raise RuntimeError("fatal-error")
262
263
def test_unraisablehook():
    from _testcapi import write_unraisable_exc

    def unraisablehook(hookargs):
        # Swallow the report; only the audit event matters for this test
        pass

    def hook(event, args):
        if event != "sys.unraisablehook":
            return
        if args[0] != unraisablehook:
            raise ValueError(f"Expected {args[0]} == {unraisablehook}")
        details = args[1]
        print(event, repr(details.exc_value), details.err_msg)

    sys.addaudithook(hook)
    sys.unraisablehook = unraisablehook
    write_unraisable_exc(RuntimeError("nonfatal-error"), "for audit hook test", None)
279
280
def test_winreg():
    from winreg import OpenKey, EnumKey, CloseKey, HKEY_LOCAL_MACHINE

    def hook(event, args):
        # Echo only registry events, flattening their arguments
        if event.startswith("winreg."):
            print(event, *args)

    sys.addaudithook(hook)

    key = OpenKey(HKEY_LOCAL_MACHINE, "Software")
    EnumKey(key, 0)
    # Index 10000 is assumed to be past the last subkey, so this must fail
    try:
        EnumKey(key, 10000)
    except OSError:
        pass
    else:
        raise RuntimeError("Expected EnumKey(HKLM, 10000) to fail")

    raw_handle = key.Detach()
    CloseKey(raw_handle)
302
303
def test_socket():
    import socket

    def hook(event, args):
        if not event.startswith("socket."):
            return
        print(event, *args)

    sys.addaudithook(hook)

    socket.gethostname()

    # The socket is closed when the with block exits, however bind() ends
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            # Don't care if this fails, we just want the audit message
            sock.bind(('127.0.0.1', 8080))
        except Exception:
            pass
324
325
def test_gc():
    import gc

    def hook(event, args):
        if not event.startswith("gc."):
            return
        print(event, *args)

    sys.addaudithook(hook)

    gc.get_objects(generation=1)

    # A referent and a container referring to it, for the two lookups below
    target = object()
    container = [target]

    gc.get_referrers(target)
    gc.get_referents(container)
342
343
def test_http_client():
    import http.client

    def hook(event, args):
        if not event.startswith("http.client."):
            return
        # args[0] is left out of the output -- presumably the connection
        # object, whose repr varies between runs; confirm against the
        # http.client audit event table.
        print(event, *args[1:])

    sys.addaudithook(hook)

    with contextlib.closing(http.client.HTTPConnection('www.python.org')) as conn:
        try:
            conn.request('GET', '/')
        except OSError:
            # Offline: print a stand-in line for the send that could not happen
            print('http.client.send', '[cannot send]')
360
361
def test_sqlite3():
    import sqlite3

    def hook(event, *args):
        # Audit hooks are called as hook(event, args). Collecting with *args
        # wraps that tuple in another tuple, so unlike the other hooks in
        # this file the whole args tuple is printed as a single item.
        if event.startswith("sqlite3."):
            print(event, *args)

    sys.addaudithook(hook)
    # contextlib.closing() is required: a Connection's own context manager
    # only commits/rolls back the transaction and does not close it.
    with contextlib.closing(sqlite3.connect(":memory:")) as cx1:
        # Also construct the class directly, in addition to connect()
        with contextlib.closing(sqlite3.Connection(":memory:")):
            # enable_load_extension is absent when Python was configured
            # without --enable-loadable-sqlite-extensions
            if hasattr(sqlite3.Connection, "enable_load_extension"):
                cx1.enable_load_extension(False)
                # With loading disabled, load_extension() must fail
                try:
                    cx1.load_extension("test")
                except sqlite3.OperationalError:
                    pass
                else:
                    raise RuntimeError("Expected sqlite3.load_extension to fail")
382
383
if __name__ == "__main__":
    from test.support import suppress_msvcrt_asserts

    suppress_msvcrt_asserts()

    # argv[1] names the test function in this module to run
    test = sys.argv[1]
    test_function = globals()[test]
    test_function()