- Issue #13703: oCERT-2011-003: add -R command-line option and PYTHONHASHSEED
environment variable, to provide an opt-in way to protect against denial of
service attacks due to hash collisions within the dict and set types. Patch
by David Malcolm, based on work by Victor Stinner.
diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py
index efef74f..28362df 100644
--- a/Lib/test/test_cmd_line.py
+++ b/Lib/test/test_cmd_line.py
@@ -103,6 +103,20 @@
self.exit_code('-c', 'pass'),
0)
+    def test_hash_randomization(self):
+        # Verify that -R enables hash randomization: two separate
+        # interpreter runs must report different hashes for the same string.
+        self.verify_valid_flag('-R')
+        code = 'print(hash("spam"))'
+        hashes = [self.start_python('-R', '-c', code) for _ in range(2)]
+        self.assertNotEqual(hashes[0], hashes[1])
+
+        # Verify that sys.flags contains hash_randomization.  assertIn
+        # reports the child's actual output when the flag is missing.
+        code = 'import sys; print sys.flags'
+        data = self.start_python('-R', '-c', code)
+        self.assertIn('hash_randomization=1', data)
def test_main():
test.test_support.run_unittest(CmdLineTest)
diff --git a/Lib/test/test_hash.py b/Lib/test/test_hash.py
index 7ce40b9..1a982c4 100644
--- a/Lib/test/test_hash.py
+++ b/Lib/test/test_hash.py
@@ -3,10 +3,18 @@
#
# Also test that hash implementations are inherited as expected
+import os
+import sys
+import struct
+import datetime
import unittest
+import subprocess
+
from test import test_support
from collections import Hashable
+IS_64BIT = (struct.calcsize('l') == 8)
+
class HashEqualityTestCase(unittest.TestCase):
@@ -133,10 +141,100 @@
for obj in self.hashes_to_check:
self.assertEqual(hash(obj), _default_hash(obj))
+class HashRandomizationTests(unittest.TestCase):
+
+ # Each subclass should define a field "repr_", containing the repr() of
+ # an object to be tested
+
+ # Build the one-line program run by the child interpreter: it prints the
+ # hash of the expression spelled by repr_.
+ def get_hash_command(self, repr_):
+ return 'print(hash(%s))' % repr_
+
+ # Run a fresh interpreter and return, as an int, the hash it prints for
+ # repr_.  A given seed is exported to the child as PYTHONHASHSEED;
+ # with seed=None any inherited PYTHONHASHSEED is removed instead.
+ def get_hash(self, repr_, seed=None):
+ env = os.environ.copy()
+ if seed is not None:
+ env['PYTHONHASHSEED'] = str(seed)
+ else:
+ env.pop('PYTHONHASHSEED', None)
+ cmd_line = [sys.executable, '-c', self.get_hash_command(repr_)]
+ # stderr is redirected into stdout, so "out" below carries both streams.
+ p = subprocess.Popen(cmd_line, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
+ env=env)
+ out, err = p.communicate()
+ out = test_support.strip_python_stderr(out)
+ # int() raises ValueError if anything other than the number remains.
+ return int(out.strip())
+
+ def test_randomized_hash(self):
+ # two runs should return different hashes
+ run1 = self.get_hash(self.repr_, seed='random')
+ run2 = self.get_hash(self.repr_, seed='random')
+ self.assertNotEqual(run1, run2)
+
+# Shared checks for string-like types.  The expected constants below are
+# compared against the hash of whatever repr_ the subclass defines.
+class StringlikeHashRandomizationTests(HashRandomizationTests):
+ def test_null_hash(self):
+ # PYTHONHASHSEED=0 disables the randomized hash
+ # IS_64BIT selects the expected value for the platform's C long size.
+ if IS_64BIT:
+ known_hash_of_obj = 1453079729188098211
+ else:
+ known_hash_of_obj = -1600925533
+
+ # Randomization is disabled by default:
+ self.assertEqual(self.get_hash(self.repr_), known_hash_of_obj)
+
+ # It can also be disabled by setting the seed to 0:
+ self.assertEqual(self.get_hash(self.repr_, seed=0), known_hash_of_obj)
+
+ def test_fixed_hash(self):
+ # test a fixed seed for the randomized hash
+ # Note that all types share the same values:
+ if IS_64BIT:
+ h = -4410911502303878509
+ else:
+ h = -206076799
+ self.assertEqual(self.get_hash(self.repr_, seed=42), h)
+
+# Runs the inherited hash checks against the byte string 'abc'.
+class StrHashRandomizationTests(StringlikeHashRandomizationTests):
+ repr_ = repr('abc')
+
+ def test_empty_string(self):
+ # Evaluated in the current process, not in a child interpreter.
+ self.assertEqual(hash(""), 0)
+
+# Runs the inherited hash checks against the unicode string u'abc'.
+class UnicodeHashRandomizationTests(StringlikeHashRandomizationTests):
+ repr_ = repr(u'abc')
+
+ def test_empty_string(self):
+ # Evaluated in the current process, not in a child interpreter.
+ self.assertEqual(hash(u""), 0)
+
+# Runs the inherited hash checks against buffer("abc").
+class BufferHashRandomizationTests(StringlikeHashRandomizationTests):
+ # Spelled as source text: the child interpreter evaluates this expression.
+ repr_ = 'buffer("abc")'
+
+ def test_empty_string(self):
+ # Evaluated in the current process, not in a child interpreter.
+ self.assertEqual(hash(buffer("")), 0)
+
+# Base for the datetime cases: the child program has to import datetime
+# before it can evaluate a repr_ such as "datetime.date(...)".
+class DatetimeTests(HashRandomizationTests):
+ def get_hash_command(self, repr_):
+ return 'import datetime; print(hash(%s))' % repr_
+
+# Applies the inherited randomized-hash check to a datetime.date value.
+class DatetimeDateTests(DatetimeTests):
+ repr_ = repr(datetime.date(1066, 10, 14))
+
+# Applies the inherited randomized-hash check to a datetime.datetime value.
+class DatetimeDatetimeTests(DatetimeTests):
+ repr_ = repr(datetime.datetime(1, 2, 3, 4, 5, 6, 7))
+
+# Applies the inherited randomized-hash check to a datetime.time value.
+class DatetimeTimeTests(DatetimeTests):
+ repr_ = repr(datetime.time(0))
+
+
def test_main():
test_support.run_unittest(HashEqualityTestCase,
HashInheritanceTestCase,
- HashBuiltinsTestCase)
+ HashBuiltinsTestCase,
+ StrHashRandomizationTests,
+ UnicodeHashRandomizationTests,
+ BufferHashRandomizationTests,
+ DatetimeDateTests,
+ DatetimeDatetimeTests,
+ DatetimeTimeTests)
+
if __name__ == "__main__":
diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py
index db7e9b4..0561499 100644
--- a/Lib/test/test_os.py
+++ b/Lib/test/test_os.py
@@ -6,6 +6,8 @@
import unittest
import warnings
import sys
+import subprocess
+
from test import test_support
warnings.filterwarnings("ignore", "tempnam", RuntimeWarning, __name__)
@@ -499,18 +501,46 @@
class URandomTests (unittest.TestCase):
def test_urandom(self):
- try:
- with test_support.check_warnings():
- self.assertEqual(len(os.urandom(1)), 1)
- self.assertEqual(len(os.urandom(10)), 10)
- self.assertEqual(len(os.urandom(100)), 100)
- self.assertEqual(len(os.urandom(1000)), 1000)
- # see http://bugs.python.org/issue3708
- self.assertEqual(len(os.urandom(0.9)), 0)
- self.assertEqual(len(os.urandom(1.1)), 1)
- self.assertEqual(len(os.urandom(2.0)), 2)
- except NotImplementedError:
- pass
+ with test_support.check_warnings():
+ self.assertEqual(len(os.urandom(1)), 1)
+ self.assertEqual(len(os.urandom(10)), 10)
+ self.assertEqual(len(os.urandom(100)), 100)
+ self.assertEqual(len(os.urandom(1000)), 1000)
+ # see http://bugs.python.org/issue3708
+ self.assertEqual(len(os.urandom(0.9)), 0)
+ self.assertEqual(len(os.urandom(1.1)), 1)
+ self.assertEqual(len(os.urandom(2.0)), 2)
+
+ def test_urandom_length(self):
+ # os.urandom(n) must return exactly n bytes, including for n == 0.
+ self.assertEqual(len(os.urandom(0)), 0)
+ self.assertEqual(len(os.urandom(1)), 1)
+ self.assertEqual(len(os.urandom(10)), 10)
+ self.assertEqual(len(os.urandom(100)), 100)
+ self.assertEqual(len(os.urandom(1000)), 1000)
+
+ def test_urandom_value(self):
+ # Two consecutive 16-byte reads in this process are expected to differ.
+ data1 = os.urandom(16)
+ data2 = os.urandom(16)
+ self.assertNotEqual(data1, data2)
+
+    def get_urandom_subprocess(self, count):
+        # Return `count` random bytes produced by os.urandom() in a fresh
+        # interpreter.
+        #
+        # The bytes are hex-encoded for the trip through the pipe.  Raw
+        # random data can begin or end with whitespace bytes, which any
+        # stripping of the output would eat, and it is subject to newline
+        # translation on a text-mode stdout under Windows.  Either would
+        # make the length check below fail at random.
+        import binascii
+        code = '\n'.join((
+            'import binascii, os, sys',
+            'data = os.urandom(%s)' % count,
+            'sys.stdout.write(binascii.hexlify(data))',
+            'sys.stdout.flush()'))
+        cmd_line = [sys.executable, '-c', code]
+        # Keep stderr on its own pipe so that anything the interpreter
+        # prints there cannot end up inside the payload.
+        p = subprocess.Popen(cmd_line, stdin=subprocess.PIPE,
+                             stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        out, err = p.communicate()
+        # Surface the child's stderr if it did not exit cleanly.
+        self.assertEqual(p.returncode, 0, err)
+        out = binascii.unhexlify(out.strip())
+        self.assertEqual(len(out), count)
+        return out
+
+ def test_urandom_subprocess(self):
+ # Two separate child interpreters are expected to yield different data.
+ data1 = self.get_urandom_subprocess(16)
+ data2 = self.get_urandom_subprocess(16)
+ self.assertNotEqual(data1, data2)
class Win32ErrorTests(unittest.TestCase):
def test_rename(self):
diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py
index 3539a14..18822ca 100644
--- a/Lib/test/test_set.py
+++ b/Lib/test/test_set.py
@@ -6,7 +6,6 @@
import operator
import copy
import pickle
-import os
from random import randrange, shuffle
import sys
import collections
@@ -688,6 +687,17 @@
if self.repr is not None:
self.assertEqual(repr(self.set), self.repr)
+    def check_repr_against_values(self):
+        # Check repr(self.set) against self.values without relying on the
+        # order in which the set lists its elements.
+        #
+        # Both spellings of a set's repr are accepted: "set([...])", which
+        # is what Python 2 produces, and the "{...}" display form.
+        text = repr(self.set)
+        if text.startswith('{'):
+            opener, closer = '{', '}'
+        else:
+            opener, closer = 'set([', '])'
+        # Pass the repr as the message so a mismatch shows what was seen.
+        self.assertTrue(text.startswith(opener), text)
+        self.assertTrue(text.endswith(closer), text)
+
+        result = sorted(text[len(opener):-len(closer)].split(', '))
+        sorted_repr_values = sorted(repr(value) for value in self.values)
+        self.assertEqual(result, sorted_repr_values)
+
def test_print(self):
fo = open(test_support.TESTFN, "wb")
try:
@@ -837,6 +847,46 @@
self.length = 3
self.repr = None
+#------------------------------------------------------------------------------
+
+# Basic set operations on a set of three byte strings.
+class TestBasicOpsString(TestBasicOps):
+ def setUp(self):
+ self.case = "string set"
+ self.values = ["a", "b", "c"]
+ self.set = set(self.values)
+ self.dup = set(self.values)
+ self.length = 3
+
+ def test_repr(self):
+ # Compare element reprs with self.values instead of one fixed string.
+ self.check_repr_against_values()
+
+#------------------------------------------------------------------------------
+
+# Basic set operations on a set of three unicode strings.
+class TestBasicOpsUnicode(TestBasicOps):
+ def setUp(self):
+ self.case = "unicode set"
+ self.values = [u"a", u"b", u"c"]
+ self.set = set(self.values)
+ self.dup = set(self.values)
+ self.length = 3
+
+ def test_repr(self):
+ # Compare element reprs with self.values instead of one fixed string.
+ self.check_repr_against_values()
+
+#------------------------------------------------------------------------------
+
+# Basic set operations on a set mixing byte strings and unicode strings.
+class TestBasicOpsMixedStringUnicode(TestBasicOps):
+    def setUp(self):
+        self.case = "string and unicode set"
+        # A str and a unicode object holding the same ASCII text compare
+        # equal and hash alike in Python 2, so set(["a", u"a"]) has a single
+        # element.  All four values must differ for self.length and the
+        # repr check to hold.
+        self.values = ["a", "b", u"c", u"d"]
+        self.set = set(self.values)
+        self.dup = set(self.values)
+        self.length = 4
+
+    def test_repr(self):
+        # Compare element reprs with self.values instead of one fixed string.
+        with test_support.check_warnings():
+            self.check_repr_against_values()
+
#==============================================================================
def baditer():
diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py
index 2212fce..b572f9a 100644
--- a/Lib/test/test_support.py
+++ b/Lib/test/test_support.py
@@ -24,7 +24,7 @@
"captured_stdout", "TransientResource", "transient_internet",
"run_with_locale", "set_memlimit", "bigmemtest", "bigaddrspacetest",
"BasicTestRunner", "run_unittest", "run_doctest", "threading_setup",
- "threading_cleanup", "reap_children"]
+ "threading_cleanup", "reap_children", "strip_python_stderr"]
class Error(Exception):
"""Base class for regression test exceptions."""
@@ -893,3 +893,13 @@
break
except:
break
+
+def strip_python_stderr(stderr):
+ """Strip the stderr of a Python process from potential debug output
+ emitted by the interpreter.
+
+ This will typically be run on the result of the communicate() method
+ of a subprocess.Popen object.
+ """
+ # Remove one trailing "[<digits> refs]" marker (with an optional line
+ # ending), then strip leading and trailing whitespace from what remains.
+ # NOTE(review): relies on a module-level "import re" that is not visible
+ # in this hunk -- confirm test_support already imports it.
+ stderr = re.sub(br"\[\d+ refs\]\r?\n?$", b"", stderr).strip()
+ return stderr
diff --git a/Lib/test/test_symtable.py b/Lib/test/test_symtable.py
index 0b4190d..4b54f55 100644
--- a/Lib/test/test_symtable.py
+++ b/Lib/test/test_symtable.py
@@ -105,10 +105,11 @@
def test_function_info(self):
func = self.spam
- self.assertEqual(func.get_parameters(), ("a", "b", "kw", "var"))
- self.assertEqual(func.get_locals(),
+ self.assertEqual(
+ tuple(sorted(func.get_parameters())), ("a", "b", "kw", "var"))
+ self.assertEqual(tuple(sorted(func.get_locals())),
("a", "b", "bar", "internal", "kw", "var", "x"))
- self.assertEqual(func.get_globals(), ("bar", "glob"))
+ self.assertEqual(tuple(sorted(func.get_globals())), ("bar", "glob"))
self.assertEqual(self.internal.get_frees(), ("x",))
def test_globals(self):
diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py
index fd6fb2b..e82569a 100644
--- a/Lib/test/test_sys.py
+++ b/Lib/test/test_sys.py
@@ -384,7 +384,7 @@
attrs = ("debug", "py3k_warning", "division_warning", "division_new",
"inspect", "interactive", "optimize", "dont_write_bytecode",
"no_site", "ignore_environment", "tabcheck", "verbose",
- "unicode", "bytes_warning")
+ "unicode", "bytes_warning", "hash_randomization")
for attr in attrs:
self.assert_(hasattr(sys.flags, attr), attr)
self.assertEqual(type(getattr(sys.flags, attr)), int, attr)