closes bpo-31650: PEP 552 (Deterministic pycs) implementation (#4575)

Python now supports checking bytecode cache up-to-dateness with a hash of the
source contents rather than volatile source metadata. See the PEP for details.

While the idea is fairly straightforward, quite a lot of code had to be
modified due to the pervasiveness of pyc implementation details in the
codebase. Changes in this commit include:

- The core changes to importlib to understand how to read, validate, and
  regenerate hash-based pycs.

- Support for generating hash-based pycs in py_compile and compileall.

- Modifications to our siphash implementation to support passing a custom
  key. We then expose it to importlib through _imp.

- Updates to all places in the interpreter, standard library, and tests that
  manually generate or parse pyc files to grok the new format.

- Support in the interpreter command line code for long options like
  --check-hash-based-pycs.

- Tests and documentation for all of the above.
diff --git a/Lib/test/test_compileall.py b/Lib/test/test_compileall.py
index 2356efc..38d7b99 100644
--- a/Lib/test/test_compileall.py
+++ b/Lib/test/test_compileall.py
@@ -48,9 +48,9 @@
 
     def data(self):
         with open(self.bc_path, 'rb') as file:
-            data = file.read(8)
+            data = file.read(12)
         mtime = int(os.stat(self.source_path).st_mtime)
-        compare = struct.pack('<4sl', importlib.util.MAGIC_NUMBER, mtime)
+        compare = struct.pack('<4sll', importlib.util.MAGIC_NUMBER, 0, mtime)
         return data, compare
 
     @unittest.skipUnless(hasattr(os, 'stat'), 'test needs os.stat()')
@@ -70,8 +70,8 @@
 
     def test_mtime(self):
         # Test a change in mtime leads to a new .pyc.
-        self.recreation_check(struct.pack('<4sl', importlib.util.MAGIC_NUMBER,
-                                          1))
+        self.recreation_check(struct.pack('<4sll', importlib.util.MAGIC_NUMBER,
+                                          0, 1))
 
     def test_magic_number(self):
         # Test a change in mtime leads to a new .pyc.
@@ -519,6 +519,19 @@
         out = self.assertRunOK('badfilename')
         self.assertRegex(out, b"Can't list 'badfilename'")
 
+    def test_pyc_invalidation_mode(self):
+        script_helper.make_script(self.pkgdir, 'f1', '')
+        pyc = importlib.util.cache_from_source(
+            os.path.join(self.pkgdir, 'f1.py'))
+        self.assertRunOK('--invalidation-mode=checked-hash', self.pkgdir)
+        with open(pyc, 'rb') as fp:
+            data = fp.read()
+        self.assertEqual(int.from_bytes(data[4:8], 'little'), 0b11)
+        self.assertRunOK('--invalidation-mode=unchecked-hash', self.pkgdir)
+        with open(pyc, 'rb') as fp:
+            data = fp.read()
+        self.assertEqual(int.from_bytes(data[4:8], 'little'), 0b01)
+
     @skipUnless(_have_multiprocessing, "requires multiprocessing")
     def test_workers(self):
         bar2fn = script_helper.make_script(self.directory, 'bar2', '')
diff --git a/Lib/test/test_imp.py b/Lib/test/test_imp.py
index b70ec7c..a115e60 100644
--- a/Lib/test/test_imp.py
+++ b/Lib/test/test_imp.py
@@ -4,11 +4,13 @@
 import os.path
 import sys
 from test import support
+from test.support import script_helper
 import unittest
 import warnings
 with warnings.catch_warnings():
     warnings.simplefilter('ignore', DeprecationWarning)
     import imp
+import _imp
 
 
 def requires_load_dynamic(meth):
@@ -329,6 +331,25 @@
         with self.assertRaises(TypeError):
             create_dynamic(BadSpec())
 
+    def test_source_hash(self):
+        self.assertEqual(_imp.source_hash(42, b'hi'), b'\xc6\xe7Z\r\x03:}\xab')
+        self.assertEqual(_imp.source_hash(43, b'hi'), b'\x85\x9765\xf8\x9a\x8b9')
+
+    def test_pyc_invalidation_mode_from_cmdline(self):
+        cases = [
+            ([], "default"),
+            (["--check-hash-based-pycs", "default"], "default"),
+            (["--check-hash-based-pycs", "always"], "always"),
+            (["--check-hash-based-pycs", "never"], "never"),
+        ]
+        for interp_args, expected in cases:
+            args = interp_args + [
+                "-c",
+                "import _imp; print(_imp.check_hash_based_pycs)",
+            ]
+            res = script_helper.assert_python_ok(*args)
+            self.assertEqual(res.out.strip().decode('utf-8'), expected)
+
 
 class ReloadTests(unittest.TestCase):
 
diff --git a/Lib/test/test_import/__init__.py b/Lib/test/test_import/__init__.py
index 5a610ba..ceea79f 100644
--- a/Lib/test/test_import/__init__.py
+++ b/Lib/test/test_import/__init__.py
@@ -598,7 +598,7 @@
     def test_foreign_code(self):
         py_compile.compile(self.file_name)
         with open(self.compiled_name, "rb") as f:
-            header = f.read(12)
+            header = f.read(16)
             code = marshal.load(f)
         constants = list(code.co_consts)
         foreign_code = importlib.import_module.__code__
diff --git a/Lib/test/test_importlib/source/test_file_loader.py b/Lib/test/test_importlib/source/test_file_loader.py
index a151149..643a02c 100644
--- a/Lib/test/test_importlib/source/test_file_loader.py
+++ b/Lib/test/test_importlib/source/test_file_loader.py
@@ -235,6 +235,123 @@
                 warnings.simplefilter('ignore', DeprecationWarning)
                 loader.load_module('bad name')
 
+    @util.writes_bytecode_files
+    def test_checked_hash_based_pyc(self):
+        with util.create_modules('_temp') as mapping:
+            source = mapping['_temp']
+            pyc = self.util.cache_from_source(source)
+            with open(source, 'wb') as fp:
+                fp.write(b'state = "old"')
+            os.utime(source, (50, 50))
+            py_compile.compile(
+                source,
+                invalidation_mode=py_compile.PycInvalidationMode.CHECKED_HASH,
+            )
+            loader = self.machinery.SourceFileLoader('_temp', source)
+            mod = types.ModuleType('_temp')
+            mod.__spec__ = self.util.spec_from_loader('_temp', loader)
+            loader.exec_module(mod)
+            self.assertEqual(mod.state, 'old')
+            # Write a new source with the same mtime and size as before.
+            with open(source, 'wb') as fp:
+                fp.write(b'state = "new"')
+            os.utime(source, (50, 50))
+            loader.exec_module(mod)
+            self.assertEqual(mod.state, 'new')
+            with open(pyc, 'rb') as fp:
+                data = fp.read()
+            self.assertEqual(int.from_bytes(data[4:8], 'little'), 0b11)
+            self.assertEqual(
+                self.util.source_hash(b'state = "new"'),
+                data[8:16],
+            )
+
+    @util.writes_bytecode_files
+    def test_overriden_checked_hash_based_pyc(self):
+        with util.create_modules('_temp') as mapping, \
+             unittest.mock.patch('_imp.check_hash_based_pycs', 'never'):
+            source = mapping['_temp']
+            pyc = self.util.cache_from_source(source)
+            with open(source, 'wb') as fp:
+                fp.write(b'state = "old"')
+            os.utime(source, (50, 50))
+            py_compile.compile(
+                source,
+                invalidation_mode=py_compile.PycInvalidationMode.CHECKED_HASH,
+            )
+            loader = self.machinery.SourceFileLoader('_temp', source)
+            mod = types.ModuleType('_temp')
+            mod.__spec__ = self.util.spec_from_loader('_temp', loader)
+            loader.exec_module(mod)
+            self.assertEqual(mod.state, 'old')
+            # Write a new source with the same mtime and size as before.
+            with open(source, 'wb') as fp:
+                fp.write(b'state = "new"')
+            os.utime(source, (50, 50))
+            loader.exec_module(mod)
+            self.assertEqual(mod.state, 'old')
+
+    @util.writes_bytecode_files
+    def test_unchecked_hash_based_pyc(self):
+        with util.create_modules('_temp') as mapping:
+            source = mapping['_temp']
+            pyc = self.util.cache_from_source(source)
+            with open(source, 'wb') as fp:
+                fp.write(b'state = "old"')
+            os.utime(source, (50, 50))
+            py_compile.compile(
+                source,
+                invalidation_mode=py_compile.PycInvalidationMode.UNCHECKED_HASH,
+            )
+            loader = self.machinery.SourceFileLoader('_temp', source)
+            mod = types.ModuleType('_temp')
+            mod.__spec__ = self.util.spec_from_loader('_temp', loader)
+            loader.exec_module(mod)
+            self.assertEqual(mod.state, 'old')
+            # Update the source file, which should be ignored.
+            with open(source, 'wb') as fp:
+                fp.write(b'state = "new"')
+            loader.exec_module(mod)
+            self.assertEqual(mod.state, 'old')
+            with open(pyc, 'rb') as fp:
+                data = fp.read()
+            self.assertEqual(int.from_bytes(data[4:8], 'little'), 0b1)
+            self.assertEqual(
+                self.util.source_hash(b'state = "old"'),
+                data[8:16],
+            )
+
+    @util.writes_bytecode_files
+    def test_overiden_unchecked_hash_based_pyc(self):
+        with util.create_modules('_temp') as mapping, \
+             unittest.mock.patch('_imp.check_hash_based_pycs', 'always'):
+            source = mapping['_temp']
+            pyc = self.util.cache_from_source(source)
+            with open(source, 'wb') as fp:
+                fp.write(b'state = "old"')
+            os.utime(source, (50, 50))
+            py_compile.compile(
+                source,
+                invalidation_mode=py_compile.PycInvalidationMode.UNCHECKED_HASH,
+            )
+            loader = self.machinery.SourceFileLoader('_temp', source)
+            mod = types.ModuleType('_temp')
+            mod.__spec__ = self.util.spec_from_loader('_temp', loader)
+            loader.exec_module(mod)
+            self.assertEqual(mod.state, 'old')
+            # Update the source file, which should be detected and recompiled.
+            with open(source, 'wb') as fp:
+                fp.write(b'state = "new"')
+            loader.exec_module(mod)
+            self.assertEqual(mod.state, 'new')
+            with open(pyc, 'rb') as fp:
+                data = fp.read()
+            self.assertEqual(int.from_bytes(data[4:8], 'little'), 0b1)
+            self.assertEqual(
+                self.util.source_hash(b'state = "new"'),
+                data[8:16],
+            )
+
 
 (Frozen_SimpleTest,
  Source_SimpleTest
@@ -247,15 +364,17 @@
     def import_(self, file, module_name):
         raise NotImplementedError
 
-    def manipulate_bytecode(self, name, mapping, manipulator, *,
-                            del_source=False):
+    def manipulate_bytecode(self,
+                            name, mapping, manipulator, *,
+                            del_source=False,
+                            invalidation_mode=py_compile.PycInvalidationMode.TIMESTAMP):
         """Manipulate the bytecode of a module by passing it into a callable
         that returns what to use as the new bytecode."""
         try:
             del sys.modules['_temp']
         except KeyError:
             pass
-        py_compile.compile(mapping[name])
+        py_compile.compile(mapping[name], invalidation_mode=invalidation_mode)
         if not del_source:
             bytecode_path = self.util.cache_from_source(mapping[name])
         else:
@@ -294,24 +413,51 @@
                                                 del_source=del_source)
             test('_temp', mapping, bc_path)
 
-    def _test_partial_timestamp(self, test, *, del_source=False):
+    def _test_partial_flags(self, test, *, del_source=False):
         with util.create_modules('_temp') as mapping:
             bc_path = self.manipulate_bytecode('_temp', mapping,
-                                                lambda bc: bc[:7],
-                                                del_source=del_source)
+                                               lambda bc: bc[:7],
+                                               del_source=del_source)
             test('_temp', mapping, bc_path)
 
-    def _test_partial_size(self, test, *, del_source=False):
+    def _test_partial_hash(self, test, *, del_source=False):
+        with util.create_modules('_temp') as mapping:
+            bc_path = self.manipulate_bytecode(
+                '_temp',
+                mapping,
+                lambda bc: bc[:13],
+                del_source=del_source,
+                invalidation_mode=py_compile.PycInvalidationMode.CHECKED_HASH,
+            )
+            test('_temp', mapping, bc_path)
+        with util.create_modules('_temp') as mapping:
+            bc_path = self.manipulate_bytecode(
+                '_temp',
+                mapping,
+                lambda bc: bc[:13],
+                del_source=del_source,
+                invalidation_mode=py_compile.PycInvalidationMode.UNCHECKED_HASH,
+            )
+            test('_temp', mapping, bc_path)
+
+    def _test_partial_timestamp(self, test, *, del_source=False):
         with util.create_modules('_temp') as mapping:
             bc_path = self.manipulate_bytecode('_temp', mapping,
                                                 lambda bc: bc[:11],
                                                 del_source=del_source)
             test('_temp', mapping, bc_path)
 
+    def _test_partial_size(self, test, *, del_source=False):
+        with util.create_modules('_temp') as mapping:
+            bc_path = self.manipulate_bytecode('_temp', mapping,
+                                                lambda bc: bc[:15],
+                                                del_source=del_source)
+            test('_temp', mapping, bc_path)
+
     def _test_no_marshal(self, *, del_source=False):
         with util.create_modules('_temp') as mapping:
             bc_path = self.manipulate_bytecode('_temp', mapping,
-                                                lambda bc: bc[:12],
+                                                lambda bc: bc[:16],
                                                 del_source=del_source)
             file_path = mapping['_temp'] if not del_source else bc_path
             with self.assertRaises(EOFError):
@@ -320,7 +466,7 @@
     def _test_non_code_marshal(self, *, del_source=False):
         with util.create_modules('_temp') as mapping:
             bytecode_path = self.manipulate_bytecode('_temp', mapping,
-                                    lambda bc: bc[:12] + marshal.dumps(b'abcd'),
+                                    lambda bc: bc[:16] + marshal.dumps(b'abcd'),
                                     del_source=del_source)
             file_path = mapping['_temp'] if not del_source else bytecode_path
             with self.assertRaises(ImportError) as cm:
@@ -331,7 +477,7 @@
     def _test_bad_marshal(self, *, del_source=False):
         with util.create_modules('_temp') as mapping:
             bytecode_path = self.manipulate_bytecode('_temp', mapping,
-                                                lambda bc: bc[:12] + b'<test>',
+                                                lambda bc: bc[:16] + b'<test>',
                                                 del_source=del_source)
             file_path = mapping['_temp'] if not del_source else bytecode_path
             with self.assertRaises(EOFError):
@@ -376,7 +522,7 @@
         def test(name, mapping, bytecode_path):
             self.import_(mapping[name], name)
             with open(bytecode_path, 'rb') as file:
-                self.assertGreater(len(file.read()), 12)
+                self.assertGreater(len(file.read()), 16)
 
         self._test_empty_file(test)
 
@@ -384,7 +530,7 @@
         def test(name, mapping, bytecode_path):
             self.import_(mapping[name], name)
             with open(bytecode_path, 'rb') as file:
-                self.assertGreater(len(file.read()), 12)
+                self.assertGreater(len(file.read()), 16)
 
         self._test_partial_magic(test)
 
@@ -395,7 +541,7 @@
         def test(name, mapping, bytecode_path):
             self.import_(mapping[name], name)
             with open(bytecode_path, 'rb') as file:
-                self.assertGreater(len(file.read()), 12)
+                self.assertGreater(len(file.read()), 16)
 
         self._test_magic_only(test)
 
@@ -418,18 +564,38 @@
         def test(name, mapping, bc_path):
             self.import_(mapping[name], name)
             with open(bc_path, 'rb') as file:
-                self.assertGreater(len(file.read()), 12)
+                self.assertGreater(len(file.read()), 16)
 
         self._test_partial_timestamp(test)
 
     @util.writes_bytecode_files
+    def test_partial_flags(self):
+        # When the flags are partial, regenerate the .pyc, else raise EOFError.
+        def test(name, mapping, bc_path):
+            self.import_(mapping[name], name)
+            with open(bc_path, 'rb') as file:
+                self.assertGreater(len(file.read()), 16)
+
+        self._test_partial_flags(test)
+
+    @util.writes_bytecode_files
+    def test_partial_hash(self):
+        # When the hash is partial, regenerate the .pyc, else raise EOFError.
+        def test(name, mapping, bc_path):
+            self.import_(mapping[name], name)
+            with open(bc_path, 'rb') as file:
+                self.assertGreater(len(file.read()), 16)
+
+        self._test_partial_hash(test)
+
+    @util.writes_bytecode_files
     def test_partial_size(self):
         # When the size is partial, regenerate the .pyc, else
         # raise EOFError.
         def test(name, mapping, bc_path):
             self.import_(mapping[name], name)
             with open(bc_path, 'rb') as file:
-                self.assertGreater(len(file.read()), 12)
+                self.assertGreater(len(file.read()), 16)
 
         self._test_partial_size(test)
 
@@ -459,13 +625,13 @@
             py_compile.compile(mapping['_temp'])
             bytecode_path = self.util.cache_from_source(mapping['_temp'])
             with open(bytecode_path, 'r+b') as bytecode_file:
-                bytecode_file.seek(4)
+                bytecode_file.seek(8)
                 bytecode_file.write(zeros)
             self.import_(mapping['_temp'], '_temp')
             source_mtime = os.path.getmtime(mapping['_temp'])
             source_timestamp = self.importlib._w_long(source_mtime)
             with open(bytecode_path, 'rb') as bytecode_file:
-                bytecode_file.seek(4)
+                bytecode_file.seek(8)
                 self.assertEqual(bytecode_file.read(4), source_timestamp)
 
     # [bytecode read-only]
@@ -560,6 +726,20 @@
 
         self._test_partial_timestamp(test, del_source=True)
 
+    def test_partial_flags(self):
+        def test(name, mapping, bytecode_path):
+            with self.assertRaises(EOFError):
+                self.import_(bytecode_path, name)
+
+        self._test_partial_flags(test, del_source=True)
+
+    def test_partial_hash(self):
+        def test(name, mapping, bytecode_path):
+            with self.assertRaises(EOFError):
+                self.import_(bytecode_path, name)
+
+        self._test_partial_hash(test, del_source=True)
+
     def test_partial_size(self):
         def test(name, mapping, bytecode_path):
             with self.assertRaises(EOFError):
diff --git a/Lib/test/test_importlib/test_abc.py b/Lib/test/test_importlib/test_abc.py
index 54b2da6..4ba28c6 100644
--- a/Lib/test/test_importlib/test_abc.py
+++ b/Lib/test/test_importlib/test_abc.py
@@ -673,6 +673,7 @@
         if magic is None:
             magic = self.util.MAGIC_NUMBER
         data = bytearray(magic)
+        data.extend(self.init._w_long(0))
         data.extend(self.init._w_long(self.source_mtime))
         data.extend(self.init._w_long(self.source_size))
         code_object = compile(self.source, self.path, 'exec',
@@ -836,6 +837,7 @@
         if bytecode_written:
             self.assertIn(self.cached, self.loader.written)
             data = bytearray(self.util.MAGIC_NUMBER)
+            data.extend(self.init._w_long(0))
             data.extend(self.init._w_long(self.loader.source_mtime))
             data.extend(self.init._w_long(self.loader.source_size))
             data.extend(marshal.dumps(code_object))
diff --git a/Lib/test/test_py_compile.py b/Lib/test/test_py_compile.py
index 4a6caa5..bcb686c 100644
--- a/Lib/test/test_py_compile.py
+++ b/Lib/test/test_py_compile.py
@@ -122,6 +122,24 @@
         # Specifying optimized bytecode should lead to a path reflecting that.
         self.assertIn('opt-2', py_compile.compile(self.source_path, optimize=2))
 
+    def test_invalidation_mode(self):
+        py_compile.compile(
+            self.source_path,
+            invalidation_mode=py_compile.PycInvalidationMode.CHECKED_HASH,
+        )
+        with open(self.cache_path, 'rb') as fp:
+            flags = importlib._bootstrap_external._classify_pyc(
+                fp.read(), 'test', {})
+        self.assertEqual(flags, 0b11)
+        py_compile.compile(
+            self.source_path,
+            invalidation_mode=py_compile.PycInvalidationMode.UNCHECKED_HASH,
+        )
+        with open(self.cache_path, 'rb') as fp:
+            flags = importlib._bootstrap_external._classify_pyc(
+                fp.read(), 'test', {})
+        self.assertEqual(flags, 0b1)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/Lib/test/test_zipimport.py b/Lib/test/test_zipimport.py
index 67ca39b..901bebd 100644
--- a/Lib/test/test_zipimport.py
+++ b/Lib/test/test_zipimport.py
@@ -40,7 +40,7 @@
         else:
             mtime = int(-0x100000000 + int(mtime))
     pyc = (importlib.util.MAGIC_NUMBER +
-        struct.pack("<ii", int(mtime), size & 0xFFFFFFFF) + data)
+        struct.pack("<iii", 0, int(mtime), size & 0xFFFFFFFF) + data)
     return pyc
 
 def module_path_to_dotted_name(path):
@@ -187,6 +187,20 @@
                  TESTMOD + pyc_ext: (NOW, test_pyc)}
         self.doTest(pyc_ext, files, TESTMOD)
 
+    def testUncheckedHashBasedPyc(self):
+        source = b"state = 'old'"
+        source_hash = importlib.util.source_hash(source)
+        bytecode = importlib._bootstrap_external._code_to_hash_pyc(
+            compile(source, "???", "exec"),
+            source_hash,
+            False, # unchecked
+        )
+        files = {TESTMOD + ".py": (NOW, "state = 'new'"),
+                 TESTMOD + ".pyc": (NOW - 20, bytecode)}
+        def check(mod):
+            self.assertEqual(mod.state, 'old')
+        self.doTest(None, files, TESTMOD, call=check)
+
     def testEmptyPy(self):
         files = {TESTMOD + ".py": (NOW, "")}
         self.doTest(None, files, TESTMOD)
@@ -215,7 +229,7 @@
         badtime_pyc = bytearray(test_pyc)
         # flip the second bit -- not the first as that one isn't stored in the
         # .py's mtime in the zip archive.
-        badtime_pyc[7] ^= 0x02
+        badtime_pyc[11] ^= 0x02
         files = {TESTMOD + ".py": (NOW, test_src),
                  TESTMOD + pyc_ext: (NOW, badtime_pyc)}
         self.doTest(".py", files, TESTMOD)