Issue #14166: Pickler objects now have an optional `dispatch_table` attribute, which allows custom per-pickler reduction functions to be set.
Patch by sbt.
diff --git a/Lib/pickle.py b/Lib/pickle.py
index c01a6af..20b3646 100644
--- a/Lib/pickle.py
+++ b/Lib/pickle.py
@@ -297,8 +297,8 @@
f(self, obj) # Call unbound method with explicit self
return
- # Check copyreg.dispatch_table
- reduce = dispatch_table.get(t)
+ # Check private dispatch table if any, or else copyreg.dispatch_table
+ reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
if reduce:
rv = reduce(obj)
else:
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py
index 831306f..1a551c8 100644
--- a/Lib/test/pickletester.py
+++ b/Lib/test/pickletester.py
@@ -1605,6 +1605,105 @@
self.assertEqual(unpickler.load(), data)
+# Tests for dispatch_table attribute
+
+REDUCE_A = 'reduce_A'
+
+class AAA(object):
+ def __reduce__(self):
+ return str, (REDUCE_A,)
+
+class BBB(object):
+ pass
+
+class AbstractDispatchTableTests(unittest.TestCase):
+
+ def test_default_dispatch_table(self):
+ # No dispatch_table attribute by default
+ f = io.BytesIO()
+ p = self.pickler_class(f, 0)
+ with self.assertRaises(AttributeError):
+ p.dispatch_table
+ self.assertFalse(hasattr(p, 'dispatch_table'))
+
+ def test_class_dispatch_table(self):
+ # A dispatch_table attribute can be specified class-wide
+ dt = self.get_dispatch_table()
+
+ class MyPickler(self.pickler_class):
+ dispatch_table = dt
+
+ def dumps(obj, protocol=None):
+ f = io.BytesIO()
+ p = MyPickler(f, protocol)
+ self.assertEqual(p.dispatch_table, dt)
+ p.dump(obj)
+ return f.getvalue()
+
+ self._test_dispatch_table(dumps, dt)
+
+ def test_instance_dispatch_table(self):
+ # A dispatch_table attribute can also be specified instance-wide
+ dt = self.get_dispatch_table()
+
+ def dumps(obj, protocol=None):
+ f = io.BytesIO()
+ p = self.pickler_class(f, protocol)
+ p.dispatch_table = dt
+ self.assertEqual(p.dispatch_table, dt)
+ p.dump(obj)
+ return f.getvalue()
+
+ self._test_dispatch_table(dumps, dt)
+
+ def _test_dispatch_table(self, dumps, dispatch_table):
+ def custom_load_dump(obj):
+ return pickle.loads(dumps(obj, 0))
+
+ def default_load_dump(obj):
+ return pickle.loads(pickle.dumps(obj, 0))
+
+ # pickling complex numbers using protocol 0 relies on copyreg
+ # so check pickling a complex number still works
+ z = 1 + 2j
+ self.assertEqual(custom_load_dump(z), z)
+ self.assertEqual(default_load_dump(z), z)
+
+ # modify pickling of complex
+ REDUCE_1 = 'reduce_1'
+ def reduce_1(obj):
+ return str, (REDUCE_1,)
+ dispatch_table[complex] = reduce_1
+ self.assertEqual(custom_load_dump(z), REDUCE_1)
+ self.assertEqual(default_load_dump(z), z)
+
+ # check picklability of AAA and BBB
+ a = AAA()
+ b = BBB()
+ self.assertEqual(custom_load_dump(a), REDUCE_A)
+ self.assertIsInstance(custom_load_dump(b), BBB)
+ self.assertEqual(default_load_dump(a), REDUCE_A)
+ self.assertIsInstance(default_load_dump(b), BBB)
+
+ # modify pickling of BBB
+ dispatch_table[BBB] = reduce_1
+ self.assertEqual(custom_load_dump(a), REDUCE_A)
+ self.assertEqual(custom_load_dump(b), REDUCE_1)
+ self.assertEqual(default_load_dump(a), REDUCE_A)
+ self.assertIsInstance(default_load_dump(b), BBB)
+
+ # revert pickling of BBB and modify pickling of AAA
+ REDUCE_2 = 'reduce_2'
+ def reduce_2(obj):
+ return str, (REDUCE_2,)
+ dispatch_table[AAA] = reduce_2
+ del dispatch_table[BBB]
+ self.assertEqual(custom_load_dump(a), REDUCE_2)
+ self.assertIsInstance(custom_load_dump(b), BBB)
+ self.assertEqual(default_load_dump(a), REDUCE_A)
+ self.assertIsInstance(default_load_dump(b), BBB)
+
+
if __name__ == "__main__":
# Print some stuff that can be used to rewrite DATA{0,1,2}
from pickletools import dis
diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py
index 9da2cae..f52d4bd 100644
--- a/Lib/test/test_pickle.py
+++ b/Lib/test/test_pickle.py
@@ -1,5 +1,6 @@
import pickle
import io
+import collections
from test import support
@@ -7,6 +8,7 @@
from test.pickletester import AbstractPickleModuleTests
from test.pickletester import AbstractPersistentPicklerTests
from test.pickletester import AbstractPicklerUnpicklerObjectTests
+from test.pickletester import AbstractDispatchTableTests
from test.pickletester import BigmemPickleTests
try:
@@ -80,6 +82,18 @@
unpickler_class = pickle._Unpickler
+class PyDispatchTableTests(AbstractDispatchTableTests):
+ pickler_class = pickle._Pickler
+ def get_dispatch_table(self):
+ return pickle.dispatch_table.copy()
+
+
+class PyChainDispatchTableTests(AbstractDispatchTableTests):
+ pickler_class = pickle._Pickler
+ def get_dispatch_table(self):
+ return collections.ChainMap({}, pickle.dispatch_table)
+
+
if has_c_implementation:
class CPicklerTests(PyPicklerTests):
pickler = _pickle.Pickler
@@ -101,14 +115,26 @@
pickler_class = _pickle.Pickler
unpickler_class = _pickle.Unpickler
+ class CDispatchTableTests(AbstractDispatchTableTests):
+ pickler_class = pickle.Pickler
+ def get_dispatch_table(self):
+ return pickle.dispatch_table.copy()
+
+ class CChainDispatchTableTests(AbstractDispatchTableTests):
+ pickler_class = pickle.Pickler
+ def get_dispatch_table(self):
+ return collections.ChainMap({}, pickle.dispatch_table)
+
def test_main():
- tests = [PickleTests, PyPicklerTests, PyPersPicklerTests]
+ tests = [PickleTests, PyPicklerTests, PyPersPicklerTests,
+ PyDispatchTableTests, PyChainDispatchTableTests]
if has_c_implementation:
tests.extend([CPicklerTests, CPersPicklerTests,
CDumpPickle_LoadPickle, DumpPickle_CLoadPickle,
PyPicklerUnpicklerObjectTests,
CPicklerUnpicklerObjectTests,
+ CDispatchTableTests, CChainDispatchTableTests,
InMemoryPickleTests])
support.run_unittest(*tests)
support.run_doctest(pickle)