blob: d31d07e01f98ba3c1d9c9d0b948a57dfbc21f8ed [file] [log] [blame]
"""Test the secrets module.

As most of the functions in secrets are thin wrappers around functions
defined elsewhere, we don't need to test them exhaustively.
"""
6
7
8import secrets
9import unittest
10import string
11
12
# === Unit tests ===
14
class Compare_Digest_Tests(unittest.TestCase):
    """Exercise secrets.compare_digest on str and bytes operands."""

    def test_equal(self):
        # Identical long inputs must compare equal, both as text and as bytes.
        for base in ("a", "bcd", "xyz123"):
            left, right = base * 100, base * 100
            self.assertTrue(secrets.compare_digest(left, right))
            left_b = left.encode('utf-8')
            right_b = right.encode('utf-8')
            self.assertTrue(secrets.compare_digest(left_b, right_b))

    def test_unequal(self):
        # Inputs that differ in length, or only in their final character,
        # must compare unequal, both as text and as bytes.
        self.assertFalse(secrets.compare_digest("abc", "abcd"))
        self.assertFalse(secrets.compare_digest(b"abc", b"abcd"))
        for base in ("x", "mn", "a1b2c3"):
            prefix = base * 100
            left, right = prefix + "q", prefix + "k"
            self.assertFalse(secrets.compare_digest(left, right))
            self.assertFalse(secrets.compare_digest(left.encode('utf-8'),
                                                    right.encode('utf-8')))

    def test_bad_types(self):
        # Mixing a str operand with a bytes operand, in either order,
        # must raise TypeError.
        text = 'abcde'
        data = text.encode('utf-8')
        assert isinstance(text, str)
        assert isinstance(data, bytes)
        for first, second in ((text, data), (data, text)):
            self.assertRaises(TypeError, secrets.compare_digest, first, second)

    def test_bool(self):
        # The result is a genuine bool for both matching and mismatched input.
        for other in ("abc", "xyz"):
            self.assertIsInstance(secrets.compare_digest("abc", other), bool)
Steven D'Aprano95702722016-04-15 01:51:31 +100049
50
class Random_Tests(unittest.TestCase):
    """Test wrappers around SystemRandom methods."""

    def test_randbits(self):
        # Every result of randbits(k) must be a non-negative int below 2**k.
        errmsg = "randbits(%d) returned %d"
        for numbits in (3, 12, 30):
            with self.subTest(numbits=numbits):
                for _ in range(6):
                    n = secrets.randbits(numbits)
                    self.assertTrue(0 <= n < 2**numbits, errmsg % (numbits, n))

    def test_choice(self):
        # choice() must always return one of the supplied items.
        items = [1, 2, 4, 8, 16, 32, 64]
        for _ in range(10):
            # assertIn reports the offending value on failure; the previous
            # assertTrue(x in items) only reported "False is not true".
            self.assertIn(secrets.choice(items), items)

    def test_randbelow(self):
        # randbelow(n) returns a value in range(n); non-positive n is rejected.
        for i in range(2, 10):
            self.assertIn(secrets.randbelow(i), range(i))
        # The smallest valid bound has exactly one possible result.
        self.assertEqual(secrets.randbelow(1), 0)
        self.assertRaises(ValueError, secrets.randbelow, 0)
        self.assertRaises(ValueError, secrets.randbelow, -1)
Steven D'Aprano95702722016-04-15 01:51:31 +100074
75
class Token_Tests(unittest.TestCase):
    """Test token functions."""

    @staticmethod
    def _urlsafe_len(nbytes):
        # Expected length of token_urlsafe(nbytes): base64 emits 4 characters
        # per 3 input bytes, and with the '=' padding stripped the final
        # partial group contributes only as many characters as it needs,
        # i.e. ceil(4 * nbytes / 3).
        return (4 * nbytes + 2) // 3

    def test_token_defaults(self):
        # Test that token_* functions handle default size correctly.
        for func in (secrets.token_bytes, secrets.token_hex,
                     secrets.token_urlsafe):
            with self.subTest(func=func):
                name = func.__name__
                try:
                    func()
                except TypeError:
                    self.fail("%s cannot be called with no argument" % name)
                try:
                    func(None)
                except TypeError:
                    self.fail("%s cannot be called with None" % name)
        # With nbytes=None each function must draw DEFAULT_ENTROPY bytes;
        # check that through the length of each encoding.
        size = secrets.DEFAULT_ENTROPY
        self.assertEqual(len(secrets.token_bytes(None)), size)
        self.assertEqual(len(secrets.token_hex(None)), 2*size)
        self.assertEqual(len(secrets.token_urlsafe(None)),
                         self._urlsafe_len(size))

    def test_token_bytes(self):
        # token_bytes(n) returns exactly n random bytes.
        for n in (1, 8, 17, 100):
            with self.subTest(n=n):
                tok = secrets.token_bytes(n)
                self.assertIsInstance(tok, bytes)
                self.assertEqual(len(tok), n)

    def test_token_hex(self):
        # token_hex(n) returns n bytes as a str of 2*n hex digits.
        for n in (1, 12, 25, 90):
            with self.subTest(n=n):
                s = secrets.token_hex(n)
                self.assertIsInstance(s, str)
                self.assertEqual(len(s), 2*n)
                self.assertTrue(all(c in string.hexdigits for c in s))

    def test_token_urlsafe(self):
        # token_urlsafe(n) returns n bytes as URL-safe, unpadded base64 text.
        legal = string.ascii_letters + string.digits + '-_'
        for n in (1, 11, 28, 76):
            with self.subTest(n=n):
                s = secrets.token_urlsafe(n)
                self.assertIsInstance(s, str)
                self.assertEqual(len(s), self._urlsafe_len(n))
                self.assertTrue(all(c in legal for c in s))
Steven D'Aprano95702722016-04-15 01:51:31 +1000121
122
if __name__ == '__main__':
    # Script entry point: discover and run every TestCase in this module.
    unittest.main()