Close #13062: Add inspect.getclosurevars to simplify testing stateful closures
diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py
index 53c947f..8327721 100644
--- a/Lib/test/test_inspect.py
+++ b/Lib/test/test_inspect.py
@@ -665,6 +665,105 @@
         self.assertIn(('f', b.f), inspect.getmembers(b, inspect.ismethod))
 
 
+_global_ref = object()  # sentinel the closures below resolve as a global
+class TestGetClosureVars(unittest.TestCase):
+
+    def test_name_resolution(self):
+        # Basic test of the 4 categories: nonlocal, global, builtin, unbound
+        def f(nonlocal_ref):
+            def g(local_ref):
+                print(local_ref, nonlocal_ref, _global_ref, unbound_ref)
+            return g
+        _arg = object()  # unique sentinel captured as nonlocal_ref
+        nonlocal_vars = {"nonlocal_ref": _arg}
+        global_vars = {"_global_ref": _global_ref}
+        builtin_vars = {"print": print}
+        unbound_names = {"unbound_ref"}
+        expected = inspect.ClosureVars(nonlocal_vars, global_vars,
+                                       builtin_vars, unbound_names)
+        self.assertEqual(inspect.getclosurevars(f(_arg)), expected)
+
+    def test_generator_closure(self):
+        def f(nonlocal_ref):
+            def g(local_ref):
+                print(local_ref, nonlocal_ref, _global_ref, unbound_ref)
+                yield  # turns g into a generator function
+            return g
+        _arg = object()
+        nonlocal_vars = {"nonlocal_ref": _arg}
+        global_vars = {"_global_ref": _global_ref}
+        builtin_vars = {"print": print}
+        unbound_names = {"unbound_ref"}
+        expected = inspect.ClosureVars(nonlocal_vars, global_vars,
+                                       builtin_vars, unbound_names)
+        self.assertEqual(inspect.getclosurevars(f(_arg)), expected)
+
+    def test_method_closure(self):
+        class C:
+            def f(self, nonlocal_ref):
+                def g(local_ref):
+                    print(local_ref, nonlocal_ref, _global_ref, unbound_ref)
+                return g  # plain function built inside a method
+        _arg = object()
+        nonlocal_vars = {"nonlocal_ref": _arg}
+        global_vars = {"_global_ref": _global_ref}
+        builtin_vars = {"print": print}
+        unbound_names = {"unbound_ref"}
+        expected = inspect.ClosureVars(nonlocal_vars, global_vars,
+                                       builtin_vars, unbound_names)
+        self.assertEqual(inspect.getclosurevars(C().f(_arg)), expected)
+
+    def test_nonlocal_vars(self):
+        # More complex tests of nonlocal resolution
+        def _nonlocal_vars(f):
+            return inspect.getclosurevars(f).nonlocals
+
+        def make_adder(x):
+            def add(y):
+                return x + y
+            return add
+
+        def curry(func, arg1):
+            return lambda arg2: func(arg1, arg2)
+
+        def less_than(a, b):
+            return a < b
+
+        # The infamous Y combinator.
+        def Y(le):
+            def g(f):
+                return le(lambda x: f(f)(x))
+            Y.g_ref = g  # expose g so the check can compare against it
+            return g(g)
+
+        def check_y_combinator(func):
+            self.assertEqual(_nonlocal_vars(func), {'f': Y.g_ref})
+
+        inc = make_adder(1)
+        add_two = make_adder(2)
+        greater_than_five = curry(less_than, 5)
+
+        self.assertEqual(_nonlocal_vars(inc), {'x': 1})
+        self.assertEqual(_nonlocal_vars(add_two), {'x': 2})
+        self.assertEqual(_nonlocal_vars(greater_than_five),
+                         {'arg1': 5, 'func': less_than})
+        self.assertEqual(_nonlocal_vars((lambda x: lambda y: x + y)(3)),
+                         {'x': 3})
+        Y(check_y_combinator)  # runs the assertion on the inner lambda
+
+    def test_getclosurevars_empty(self):
+        def foo(): pass
+        _empty = inspect.ClosureVars({}, {}, {}, set())
+        self.assertEqual(inspect.getclosurevars(lambda: True), _empty)
+        self.assertEqual(inspect.getclosurevars(foo), _empty)
+
+    def test_getclosurevars_error(self):
+        # Non-function arguments (int, builtin type, dict) raise TypeError
+        self.assertRaises(TypeError, inspect.getclosurevars, 1)
+        self.assertRaises(TypeError, inspect.getclosurevars, list)
+        self.assertRaises(TypeError, inspect.getclosurevars, {})
+
+
 class TestGetcallargsFunctions(unittest.TestCase):
 
     def assertEqualCallArgs(self, func, call_params_string, locs=None):
@@ -2100,7 +2199,7 @@
         TestGetcallargsFunctions, TestGetcallargsMethods,
         TestGetcallargsUnboundMethods, TestGetattrStatic, TestGetGeneratorState,
         TestNoEOL, TestSignatureObject, TestSignatureBind, TestParameterObject,
-        TestBoundArguments
+        TestBoundArguments, TestGetClosureVars
     )
 
 if __name__ == "__main__":