blob: acf37a038faa1358a292d9346b30bbe9557d7c14 [file] [log] [blame]
Jim Fultond15dc062004-07-14 19:11:50 +00001import unittest
2from doctest import DocTestSuite
Benjamin Petersonee8712c2008-05-20 21:35:26 +00003from test import support
Christian Heimes587c2bf2008-01-19 16:21:02 +00004import weakref
5import gc
6
Antoine Pitrouda991da2010-08-03 18:32:26 +00007# Modules under test
8_thread = support.import_module('_thread')
9threading = support.import_module('threading')
10import _threading_local
11
12
class Weak:
    """Trivial weak-referenceable object stored as a per-thread payload."""
15
def target(local, weaklist):
    """Thread body: attach a fresh Weak to *local* and record a weakref to it.

    The weak reference lets the caller check afterwards whether the value
    stored on the thread-local was released once this thread finished.
    """
    payload = Weak()
    local.weak = payload
    weaklist.append(weakref.ref(payload))
20
Antoine Pitrouda991da2010-08-03 18:32:26 +000021
class BaseLocalTest:
    """Behavioural tests shared by every thread-local implementation.

    Concrete test classes mix this in together with unittest.TestCase and
    set the class attribute ``_local`` to the thread-local type under test.
    """

    def test_local_refs(self):
        self._local_refs(20)
        self._local_refs(50)
        self._local_refs(100)

    def _local_refs(self, n):
        """Check that values stored by finished threads become collectable.

        Runs *n* threads one after another; each stores a new ``Weak`` on a
        shared local (via ``target``) and appends a weak reference to it.
        """
        local = self._local()
        weaklist = []
        for _ in range(n):
            t = threading.Thread(target=target, args=(local, weaklist))
            t.start()
            t.join()
            del t

        gc.collect()
        self.assertEqual(len(weaklist), n)

        # XXX _threading_local keeps the local of the last stopped thread alive.
        deadlist = [weak for weak in weaklist if weak() is None]
        self.assertIn(len(deadlist), (n-1, n))

        # Assignment to the same thread local frees it sometimes (!)
        local.someothervar = None
        gc.collect()
        deadlist = [weak for weak in weaklist if weak() is None]
        self.assertIn(len(deadlist), (n-1, n), (n, len(deadlist)))

    def test_derived(self):
        # Issue 3088: if there is a threads switch inside the __init__
        # of a threading.local derived class, the per-thread dictionary
        # is created but not correctly set on the object.
        # The first member set may be bogus.
        import time
        class Local(self._local):
            def __init__(self):
                time.sleep(0.01)
        local = Local()

        def f(i):
            local.x = i
            # Simply check that the variable is correctly set
            self.assertEqual(local.x, i)

        threads = []
        for i in range(10):
            t = threading.Thread(target=f, args=(i,))
            t.start()
            threads.append(t)

        for t in threads:
            t.join()

    def test_derived_cycle_dealloc(self):
        # http://bugs.python.org/issue6990
        class Local(self._local):
            pass
        # Named new_locals (not "locals") to avoid shadowing the builtin.
        new_locals = None
        passed = False
        e1 = threading.Event()
        e2 = threading.Event()

        def f():
            nonlocal passed
            # 1) Involve Local in a cycle
            cycle = [Local()]
            cycle.append(cycle)
            cycle[0].foo = 'bar'

            # 2) GC the cycle (triggers threadmodule.c::local_clear
            # before local_dealloc)
            del cycle
            gc.collect()
            e1.set()
            e2.wait()

            # 4) New Locals should be empty
            passed = all(not hasattr(local, 'foo') for local in new_locals)

        t = threading.Thread(target=f)
        t.start()
        e1.wait()

        # 3) New Locals should recycle the original's address. Creating
        # them in the thread overwrites the thread state and avoids the
        # bug
        new_locals = [Local() for _ in range(10)]
        e2.set()
        t.join()

        self.assertTrue(passed)

    def test_arguments(self):
        # Issue 1522237
        class MyLocal(self._local):
            def __init__(self, *args, **kwargs):
                pass

        MyLocal(a=1)
        MyLocal(1)
        # The base type itself must reject constructor arguments.
        self.assertRaises(TypeError, self._local, a=1)
        self.assertRaises(TypeError, self._local, 1)

    def _test_one_class(self, c):
        """Check that an attribute set in one thread is invisible in another.

        Thread 1 sets ``x`` (and sets then deletes ``y``) on a shared
        instance of *c* and stays alive until released; thread 2 then reads
        ``x`` and is expected to get AttributeError. The outcome is recorded
        in ``self._failed`` (an empty string means success).
        """
        # Local import (like ``time`` in test_derived): the failure branch
        # in f2 writes to sys.stderr, which previously raised NameError.
        import sys
        self._failed = "No error message set or cleared."
        obj = c()
        e1 = threading.Event()
        e2 = threading.Event()

        def f1():
            obj.x = 'foo'
            obj.y = 'bar'
            del obj.y
            e1.set()
            # Keep this thread (and its per-thread state) alive while f2 runs.
            e2.wait()

        def f2():
            try:
                foo = obj.x
            except AttributeError:
                # This is expected -- we haven't set obj.x in this thread yet!
                self._failed = ""  # passed
            else:
                self._failed = ('Incorrectly got value %r from class %r\n' %
                                (foo, c))
                sys.stderr.write(self._failed)

        t1 = threading.Thread(target=f1)
        t1.start()
        e1.wait()
        t2 = threading.Thread(target=f2)
        t2.start()
        t2.join()
        # The test is done; just let t1 know it can exit, and wait for it.
        e2.set()
        t1.join()

        self.assertFalse(self._failed, self._failed)

    def test_threading_local(self):
        self._test_one_class(self._local)

    def test_threading_local_subclass(self):
        class LocalSubclass(self._local):
            """To test that subclasses behave properly."""
        self._test_one_class(LocalSubclass)

    def _test_dict_attribute(self, cls):
        """``__dict__`` reflects set attributes but can't be replaced/deleted."""
        obj = cls()
        obj.x = 5
        self.assertEqual(obj.__dict__, {'x': 5})
        with self.assertRaises(AttributeError):
            obj.__dict__ = {}
        with self.assertRaises(AttributeError):
            del obj.__dict__

    def test_dict_attribute(self):
        self._test_dict_attribute(self._local)

    def test_dict_attribute_subclass(self):
        class LocalSubclass(self._local):
            """To test that subclasses behave properly."""
        self._test_dict_attribute(LocalSubclass)
186
Antoine Pitrouda991da2010-08-03 18:32:26 +0000187
class ThreadLocalTest(unittest.TestCase, BaseLocalTest):
    """Shared thread-local tests run against the C ``_thread._local`` type."""

    _local = _thread._local

    # Fails for the pure Python implementation
    def test_cycle_collection(self):
        class Holder:
            pass

        holder = Holder()
        holder.local = self._local()
        # Close a reference cycle that passes through the thread-local.
        holder.local.x = holder
        ref = weakref.ref(holder)
        del holder
        # The cycle must be collectable once the last strong ref is gone.
        gc.collect()
        self.assertIs(ref(), None)
203
class PyThreadingLocalTest(unittest.TestCase, BaseLocalTest):
    """Shared thread-local tests run against ``_threading_local.local``."""

    _local = _threading_local.local
Jack Diederich561d5aa2010-02-22 19:55:46 +0000206
Benjamin Peterson8a250ae2008-06-30 23:30:24 +0000207
def test_main():
    """Build and run the suite: the unit tests plus _threading_local doctests.

    The module's doctests are added twice: once as-is, and once with
    ``_threading_local.local`` temporarily swapped for ``_thread._local``
    (restored in tearDown) so the same examples exercise both types.
    """
    loader = unittest.defaultTestLoader
    suite = unittest.TestSuite()
    suite.addTest(DocTestSuite('_threading_local'))
    # unittest.makeSuite() is deprecated since 3.11 and removed in 3.13;
    # TestLoader.loadTestsFromTestCase is the equivalent it delegated to.
    suite.addTest(loader.loadTestsFromTestCase(ThreadLocalTest))
    suite.addTest(loader.loadTestsFromTestCase(PyThreadingLocalTest))

    local_orig = _threading_local.local

    def setUp(test):
        _threading_local.local = _thread._local

    def tearDown(test):
        _threading_local.local = local_orig

    suite.addTest(DocTestSuite('_threading_local',
                               setUp=setUp, tearDown=tearDown))

    support.run_unittest(suite)
Jim Fultond15dc062004-07-14 19:11:50 +0000224
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    test_main()