Port issue #7347 to py3k.

Add CreateKeyEx and DeleteKeyEx to winreg, along with test improvements.
diff --git a/Lib/test/test_winreg.py b/Lib/test/test_winreg.py
index fd28852..e29a727 100644
--- a/Lib/test/test_winreg.py
+++ b/Lib/test/test_winreg.py
@@ -5,13 +5,32 @@
 import os, sys
 import unittest
 from test import support
+from platform import machine
 
 # Do this first so test will be skipped if module doesn't exist
 support.import_module('winreg')
 # Now import everything
 from winreg import *
 
+try:
+    REMOTE_NAME = sys.argv[sys.argv.index("--remote") + 1]
+except (IndexError, ValueError):
+    REMOTE_NAME = None
+
+# tuple of (major, minor)
+WIN_VER = sys.getwindowsversion()[:2]
+# Some tests should only run on 64-bit architectures where WOW64 is present.
+WIN64_MACHINE = machine() == "AMD64"
+
+# Starting with Windows 7 and Windows Server 2008 R2, WOW64 no longer uses
+# registry reflection and formerly reflected keys are shared instead.
+# Windows 7 and Windows Server 2008 R2 are version 6.1. Due to this, some
+# tests are only valid on versions strictly older than 6.1.
+HAS_REFLECTION = WIN_VER < (6, 1)
+
 test_key_name = "SOFTWARE\\Python Registry Test Key - Delete Me"
+# On OSes that support reflection we should test with a reflected key
+test_reflect_key_name = "SOFTWARE\\Classes\\Python Test Key - Delete Me"
 
 test_data = [
     ("Int Value",     45,                                      REG_DWORD),
@@ -25,8 +44,7 @@
     ("Japanese 日本", "日本語", REG_SZ),
 ]
 
-class WinregTests(unittest.TestCase):
-    remote_name = None
+class BaseWinregTests(unittest.TestCase):
 
     def setUp(self):
         # Make sure that the test key is absent when the test
@@ -49,7 +67,8 @@
         CloseKey(hkey)
         DeleteKey(root, subkey)
 
-    def WriteTestData(self, root_key, subkeystr="sub_key"):
+    def _write_test_data(self, root_key, subkeystr="sub_key",
+                         CreateKey=CreateKey):
         # Set the default value for this key.
         SetValue(root_key, test_key_name, REG_SZ, "Default value")
         key = CreateKey(root_key, test_key_name)
@@ -90,7 +109,7 @@
         except EnvironmentError:
             pass
 
-    def ReadTestData(self, root_key, subkeystr="sub_key"):
+    def _read_test_data(self, root_key, subkeystr="sub_key", OpenKey=OpenKey):
         # Check we can get default value for this key.
         val = QueryValue(root_key, test_key_name)
         self.assertEquals(val, "Default value",
@@ -130,7 +149,7 @@
 
         key.Close()
 
-    def DeleteTestData(self, root_key, subkeystr="sub_key"):
+    def _delete_test_data(self, root_key, subkeystr="sub_key"):
         key = OpenKey(root_key, test_key_name, 0, KEY_ALL_ACCESS)
         sub_key = OpenKey(key, subkeystr, 0, KEY_ALL_ACCESS)
         # It is not necessary to delete the values before deleting
@@ -160,39 +179,179 @@
         except WindowsError: # Use this error name this time
             pass
 
-    def TestAll(self, root_key, subkeystr="sub_key"):
-        self.WriteTestData(root_key, subkeystr)
-        self.ReadTestData(root_key, subkeystr)
-        self.DeleteTestData(root_key, subkeystr)
+    def _test_all(self, root_key, subkeystr="sub_key"):
+        for step in (self._write_test_data, self._read_test_data,
+                     self._delete_test_data):
+            step(root_key, subkeystr)
 
-    def testLocalMachineRegistryWorks(self):
-        self.TestAll(HKEY_CURRENT_USER)
-        self.TestAll(HKEY_CURRENT_USER, "日本-subkey")
+class LocalWinregTests(BaseWinregTests):
 
-    def testConnectRegistryToLocalMachineWorks(self):
+    def test_registry_works(self):
+        self._test_all(HKEY_CURRENT_USER)
+        self._test_all(HKEY_CURRENT_USER, "日本-subkey")
+
+    def test_registry_works_extended_functions(self):
+        # Substitute the regular CreateKey and OpenKey calls with their
+        # extended counterparts.
+        # Note: DeleteKeyEx is not used here because it is platform dependent
+        cke = lambda key, sub_key: CreateKeyEx(key, sub_key, 0, KEY_ALL_ACCESS)
+        self._write_test_data(HKEY_CURRENT_USER, CreateKey=cke)
+
+        oke = lambda key, sub_key: OpenKeyEx(key, sub_key, 0, KEY_READ)
+        self._read_test_data(HKEY_CURRENT_USER, OpenKey=oke)
+
+        self._delete_test_data(HKEY_CURRENT_USER)
+
+    def test_connect_registry_to_local_machine_works(self):
         # perform minimal ConnectRegistry test which just invokes it
         h = ConnectRegistry(None, HKEY_LOCAL_MACHINE)
+        self.assertNotEqual(h.handle, 0)
         h.Close()
+        self.assertEqual(h.handle, 0)
 
-    def testRemoteMachineRegistryWorks(self):
-        if not self.remote_name:
-            return # remote machine name not specified
-        remote_key = ConnectRegistry(self.remote_name, HKEY_CURRENT_USER)
-        self.TestAll(remote_key)
+    def test_nonexistent_remote_registry(self):
+        self.assertRaises(WindowsError, ConnectRegistry,
+                          "abcdefghijkl", HKEY_CURRENT_USER)
 
     def testExpandEnvironmentStrings(self):
         r = ExpandEnvironmentStrings("%windir%\\test")
         self.assertEqual(type(r), str)
         self.assertEqual(r, os.environ["windir"] + "\\test")
+
+    def test_context_manager(self):
+        # ensure that the handle is closed even when an exception escapes
+        h = ConnectRegistry(None, HKEY_LOCAL_MACHINE)
+        with self.assertRaises(WindowsError):
+            with h:
+                self.assertNotEqual(h.handle, 0)
+                raise WindowsError
+        self.assertEqual(h.handle, 0)
+
+    # Reflection requires XP x64/Vista at a minimum. Earlier versions lack it
+    # and DeleteKeyEx, so make sure their use raises NotImplementedError.
+    @unittest.skipUnless(WIN_VER < (5, 2), "Requires Windows XP")
+    def test_reflection_unsupported(self):
+        try:
+            with CreateKey(HKEY_CURRENT_USER, test_key_name) as ck:
+                self.assertNotEqual(ck.handle, 0)
+
+            with OpenKey(HKEY_CURRENT_USER, test_key_name) as key:
+                self.assertNotEqual(key.handle, 0)
+
+                with self.assertRaises(NotImplementedError):
+                    DisableReflectionKey(key)
+                with self.assertRaises(NotImplementedError):
+                    EnableReflectionKey(key)
+                with self.assertRaises(NotImplementedError):
+                    QueryReflectionKey(key)
+            with self.assertRaises(NotImplementedError):
+                DeleteKeyEx(HKEY_CURRENT_USER, test_key_name)
+        finally:
+            DeleteKey(HKEY_CURRENT_USER, test_key_name)
+
+
+@unittest.skipUnless(REMOTE_NAME, "Skipping remote registry tests")
+class RemoteWinregTests(BaseWinregTests):
+
+    def test_remote_registry_works(self):
+        with ConnectRegistry(REMOTE_NAME, HKEY_CURRENT_USER) as remote_key:
+            self._test_all(remote_key)
+
+
+@unittest.skipUnless(WIN64_MACHINE, "x64 specific registry tests")
+class Win64WinregTests(BaseWinregTests):
+
+    def test_reflection_functions(self):
+        # Test that we can call the query, enable, and disable functions
+        # on a key which isn't on the reflection list with no consequences.
+        with OpenKey(HKEY_LOCAL_MACHINE, "Software") as key:
+            # HKLM\Software is redirected but not reflected in all OSes
+            self.assertTrue(QueryReflectionKey(key))
+            self.assertIsNone(EnableReflectionKey(key))
+            self.assertIsNone(DisableReflectionKey(key))
+            self.assertTrue(QueryReflectionKey(key))
+
+    @unittest.skipUnless(HAS_REFLECTION, "OS doesn't support reflection")
+    def test_reflection(self):
+        # Test that we can create, open, and delete keys in the 32-bit
+        # area. Because we are doing this in a key which gets reflected,
+        # test the differences of 32 and 64-bit keys before and after the
+        # reflection occurs (i.e. when the created key is closed).
+        try:
+            with CreateKeyEx(HKEY_CURRENT_USER, test_reflect_key_name, 0,
+                             KEY_ALL_ACCESS | KEY_WOW64_32KEY) as created_key:
+                self.assertNotEqual(created_key.handle, 0)
+
+                # The key should now be available in the 32-bit area
+                with OpenKey(HKEY_CURRENT_USER, test_reflect_key_name, 0,
+                             KEY_ALL_ACCESS | KEY_WOW64_32KEY) as key:
+                    self.assertNotEqual(key.handle, 0)
+
+                # Write a value to what currently is only in the 32-bit area
+                SetValueEx(created_key, "", 0, REG_SZ, "32KEY")
+
+                # The key is not reflected until created_key is closed.
+                # The 64-bit version of the key should not be available yet.
+                open_fail = lambda: OpenKey(HKEY_CURRENT_USER,
+                                            test_reflect_key_name, 0,
+                                            KEY_READ | KEY_WOW64_64KEY)
+                self.assertRaises(WindowsError, open_fail)
+
+            # Now explicitly open the 64-bit version of the key
+            with OpenKey(HKEY_CURRENT_USER, test_reflect_key_name, 0,
+                         KEY_ALL_ACCESS | KEY_WOW64_64KEY) as key:
+                self.assertNotEqual(key.handle, 0)
+                # Make sure the original value we set is there
+                self.assertEqual("32KEY", QueryValue(key, ""))
+                # Set a new value, which will get reflected to 32-bit
+                SetValueEx(key, "", 0, REG_SZ, "64KEY")
+
+            # Reflection uses a "last-writer wins" policy, so the value we set
+            # on the 64-bit key should be the same on 32-bit
+            with OpenKey(HKEY_CURRENT_USER, test_reflect_key_name, 0,
+                         KEY_READ | KEY_WOW64_32KEY) as key:
+                self.assertEqual("64KEY", QueryValue(key, ""))
+        finally:
+            DeleteKeyEx(HKEY_CURRENT_USER, test_reflect_key_name,
+                        KEY_WOW64_32KEY, 0)
+
+    @unittest.skipUnless(HAS_REFLECTION, "OS doesn't support reflection")
+    def test_disable_reflection(self):
+        # Make use of a key which gets redirected and reflected
+        try:
+            with CreateKeyEx(HKEY_CURRENT_USER, test_reflect_key_name, 0,
+                             KEY_ALL_ACCESS | KEY_WOW64_32KEY) as created_key:
+                # QueryReflectionKey returns whether or not the key is disabled
+                disabled = QueryReflectionKey(created_key)
+                self.assertIsInstance(disabled, bool)
+                # HKCU\Software\Classes is reflected by default
+                self.assertFalse(disabled)
+
+                DisableReflectionKey(created_key)
+                self.assertTrue(QueryReflectionKey(created_key))
+
+            # The key is now closed and would normally be reflected to the
+            # 64-bit area, but let's make sure that didn't happen.
+            open_fail = lambda: OpenKeyEx(HKEY_CURRENT_USER,
+                                          test_reflect_key_name, 0,
+                                          KEY_READ | KEY_WOW64_64KEY)
+            self.assertRaises(WindowsError, open_fail)
+
+            # Make sure the 32-bit key is actually there
+            with OpenKeyEx(HKEY_CURRENT_USER, test_reflect_key_name, 0,
+                           KEY_READ | KEY_WOW64_32KEY) as key:
+                self.assertNotEqual(key.handle, 0)
+        finally:
+            DeleteKeyEx(HKEY_CURRENT_USER, test_reflect_key_name,
+                        KEY_WOW64_32KEY, 0)
+
+
 def test_main():
-    support.run_unittest(WinregTests)
+    test_classes = (LocalWinregTests, RemoteWinregTests, Win64WinregTests)
+    support.run_unittest(*test_classes)
 
 if __name__ == "__main__":
-    try:
-        WinregTests.remote_name = sys.argv[sys.argv.index("--remote")+1]
-    except (IndexError, ValueError):
+    if not REMOTE_NAME:
         print("Remote registry calls can be tested using",
               "'test_winreg.py --remote \\\\machine_name'")
-        WinregTests.remote_name = None
     test_main()