Convert pickle tests to use unittest.

Extend tests to cover a few more cases.  For cPickle, test several of
the undocumented features.
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py
index fa3ddf4..0f36d66 100644
--- a/Lib/test/pickletester.py
+++ b/Lib/test/pickletester.py
@@ -1,9 +1,27 @@
-# test_pickle and test_cpickle both use this.
-
+import unittest
 from test_support import TestFailed, have_unicode
-import sys
 
-# break into multiple strings to please font-lock-mode
+class C:
+    def __cmp__(self, other):
+        return cmp(self.__dict__, other.__dict__)
+
+import __main__
+__main__.C = C
+C.__module__ = "__main__"
+
+class myint(int):
+    def __init__(self, x):
+        self.str = str(x)
+
+class initarg(C):
+    def __init__(self, a, b):
+        self.a = a
+        self.b = b
+
+    def __getinitargs__(self):
+        return self.a, self.b
+
+# break into multiple strings to avoid confusing font-lock-mode
 DATA = """(lp1
 I0
 aL1L
@@ -57,18 +75,8 @@
           '\x00\x80J\x00\x00\x00\x80(U\x03abcq\x04h\x04(c__main__\n' + \
           'C\nq\x05oq\x06}q\x07(U\x03fooq\x08K\x01U\x03barq\tK\x02ubh' + \
           '\x06tq\nh\nK\x05e.'
-
-class C:
-    def __cmp__(self, other):
-        return cmp(self.__dict__, other.__dict__)
-
-import __main__
-__main__.C = C
-C.__module__ = "__main__"
-
-# Call this with the module to be tested (pickle or cPickle).
-
-def dotest(pickle):
+    
+def create_data():
     c = C()
     c.foo = 1
     c.bar = 2
@@ -86,153 +94,159 @@
     x.append(y)
     x.append(y)
     x.append(5)
-    r = []
-    r.append(r)
+    return x
 
-    print "dumps()"
-    s = pickle.dumps(x)
+class AbstractPickleTests(unittest.TestCase):
 
-    print "loads()"
-    x2 = pickle.loads(s)
-    if x2 == x:
-        print "ok"
-    else:
-        print "bad"
+    _testdata = create_data()
 
-    print "loads() DATA"
-    x2 = pickle.loads(DATA)
-    if x2 == x:
-        print "ok"
-    else:
-        print "bad"
-
-    print "dumps() binary"
-    s = pickle.dumps(x, 1)
-
-    print "loads() binary"
-    x2 = pickle.loads(s)
-    if x2 == x:
-        print "ok"
-    else:
-        print "bad"
-
-    print "loads() BINDATA"
-    x2 = pickle.loads(BINDATA)
-    if x2 == x:
-        print "ok"
-    else:
-        print "bad"
-
-    print "dumps() RECURSIVE"
-    s = pickle.dumps(r)
-    x2 = pickle.loads(s)
-    if x2 == r:
-        print "ok"
-    else:
-        print "bad"
-
-    # don't create cyclic garbage
-    del x2[0]
-    del r[0]
-
-    # Test protection against closed files
-    import tempfile, os
-    fn = tempfile.mktemp()
-    f = open(fn, "w")
-    f.close()
-    try:
-        pickle.dump(123, f)
-    except ValueError:
+    def setUp(self):
+        # subclass must define self.dumps, self.loads, self.error
         pass
-    else:
-        print "dump to closed file should raise ValueError"
 
-    f = open(fn, "r")
-    f.close()
-    try:
-        pickle.load(f)
-    except ValueError:
-        pass
-    else:
-        print "load from closed file should raise ValueError"
-    os.remove(fn)
+    def test_misc(self):
+        # test various datatypes not tested by testdata
+        x = myint(4)
+        s = self.dumps(x)
+        y = self.loads(s)
+        self.assertEqual(x, y)
 
-    # Test specific bad cases
-    for i in range(10):
-        try:
-            x = pickle.loads('garyp')
-        except KeyError, y:
-            # pickle
-            del y
-        except pickle.BadPickleGet, y:
-            # cPickle
-            del y
-        else:
-            print "unexpected success!"
-            break
+        x = (1, ())
+        s = self.dumps(x)
+        y = self.loads(s)
+        self.assertEqual(x, y)
 
-    # Test insecure strings
-    insecure = ["abc", "2 + 2", # not quoted
-                "'abc' + 'def'", # not a single quoted string
-                "'abc", # quote is not closed
-                "'abc\"", # open quote and close quote don't match
-                "'abc'   ?", # junk after close quote
-                # some tests of the quoting rules
-                "'abc\"\''",
-                "'\\\\a\'\'\'\\\'\\\\\''",
-                ]
-    for s in insecure:
-        buf = "S" + s + "\012p0\012."
-        try:
-            x = pickle.loads(buf)
-        except ValueError:
-            pass
-        else:
-            print "accepted insecure string: %s" % repr(buf)
+        x = initarg(1, x)
+        s = self.dumps(x)
+        y = self.loads(s)
+        self.assertEqual(x, y)
 
-    # Test some Unicode end cases
+        # XXX test __reduce__ protocol?
+
+    def test_identity(self):
+        s = self.dumps(self._testdata)
+        x = self.loads(s)
+        self.assertEqual(x, self._testdata)
+
+    def test_constant(self):
+        x = self.loads(DATA)
+        self.assertEqual(x, self._testdata)
+        x = self.loads(BINDATA)
+        self.assertEqual(x, self._testdata)
+
+    def test_binary(self):
+        s = self.dumps(self._testdata, 1)
+        x = self.loads(s)
+        self.assertEqual(x, self._testdata)
+
+    def test_recursive_list(self):
+        l = []
+        l.append(l)
+        s = self.dumps(l)
+        x = self.loads(s)
+        self.assertEqual(x, l)
+        self.assertEqual(x, x[0])
+        self.assertEqual(id(x), id(x[0]))
+
+    def test_recursive_dict(self):
+        d = {}
+        d[1] = d
+        s = self.dumps(d)
+        x = self.loads(s)
+        self.assertEqual(x, d)
+        self.assertEqual(x[1], x)
+        self.assertEqual(id(x[1]), id(x))
+
+    def test_recursive_inst(self):
+        i = C()
+        i.attr = i
+        s = self.dumps(i)
+        x = self.loads(s)
+        self.assertEqual(x, i)
+        self.assertEqual(x.attr, x)
+        self.assertEqual(id(x.attr), id(x))
+
+    def test_recursive_multi(self):
+        l = []
+        d = {1:l}
+        i = C()
+        i.attr = d
+        l.append(i)
+        s = self.dumps(l)
+        x = self.loads(s)
+        self.assertEqual(x, l)
+        self.assertEqual(x[0], i)
+        self.assertEqual(x[0].attr, d)
+        self.assertEqual(x[0].attr[1], x)
+        self.assertEqual(x[0].attr[1][0], i)
+        self.assertEqual(x[0].attr[1][0].attr, d)
+
+    def test_garyp(self):
+        self.assertRaises(self.error, self.loads, 'garyp')
+
+    def test_insecure_strings(self):
+        insecure = ["abc", "2 + 2", # not quoted
+                    "'abc' + 'def'", # not a single quoted string
+                    "'abc", # quote is not closed
+                    "'abc\"", # open quote and close quote don't match
+                    "'abc'   ?", # junk after close quote
+                    # some tests of the quoting rules
+                    "'abc\"\''",
+                    "'\\\\a\'\'\'\\\'\\\\\''",
+                    ]
+        for s in insecure:
+            buf = "S" + s + "\012p0\012."
+            self.assertRaises(ValueError, self.loads, buf)
+
     if have_unicode:
-        endcases = [unicode(''), unicode('<\\u>'), unicode('<\\\u1234>'),
-                    unicode('<\n>'),  unicode('<\\>')]
-    else:
-        endcases = []
-    for u in endcases:
-        try:
-            u2 = pickle.loads(pickle.dumps(u))
-        except Exception, msg:
-            print "Endcase exception: %s => %s(%s)" % \
-                  (`u`, msg.__class__.__name__, str(msg))
-        else:
-            if u2 != u:
-                print "Endcase failure: %s => %s" % (`u`, `u2`)
+        def test_unicode(self):
+            endcases = [unicode(''), unicode('<\\u>'), unicode('<\\\u1234>'),
+                        unicode('<\n>'),  unicode('<\\>')]
+            for u in endcases:
+                p = self.dumps(u)
+                u2 = self.loads(p)
+                self.assertEqual(u2, u)
 
-    # Test the full range of Python ints.
-    n = sys.maxint
-    while n:
-        for expected in (-n, n):
-            for binary_mode in (0, 1):
-                s = pickle.dumps(expected, binary_mode)
-                got = pickle.loads(s)
-                if expected != got:
-                    raise TestFailed("for %s-mode pickle of %d, pickle "
-                                     "string is %s, loaded back as %s" % (
-                                     binary_mode and "binary" or "text",
-                                     expected,
-                                     repr(s),
-                                     got))
-        n = n >> 1
+    def test_ints(self):
+        import sys
+        n = sys.maxint
+        while n:
+            for expected in (-n, n):
+                s = self.dumps(expected)
+                n2 = self.loads(s)
+                self.assertEqual(expected, n2)
+            n = n >> 1
 
-    # Fake a pickle from a sizeof(long)==8 box.
-    maxint64 = (1L << 63) - 1
-    data = 'I' + str(maxint64) + '\n.'
-    got = pickle.loads(data)
-    if maxint64 != got:
-        raise TestFailed("maxint64 test failed %r %r" % (maxint64, got))
-    # Try too with a bogus literal.
-    data = 'I' + str(maxint64) + 'JUNK\n.'
-    try:
-        got = pickle.loads(data)
-    except ValueError:
+    def test_maxint64(self):
+        maxint64 = (1L << 63) - 1
+        data = 'I' + str(maxint64) + '\n.'
+        got = self.loads(data)
+        self.assertEqual(got, maxint64)
+
+        # Try too with a bogus literal.
+        data = 'I' + str(maxint64) + 'JUNK\n.'
+        self.assertRaises(ValueError, self.loads, data)
+
+    def test_reduce(self):
         pass
-    else:
-        raise TestFailed("should have raised error on bogus INT literal")
+
+    def test_getinitargs(self):
+        pass
+
+class AbstractPickleModuleTests(unittest.TestCase):
+
+    def test_dump_closed_file(self):
+        import tempfile, os
+        fn = tempfile.mktemp()
+        f = open(fn, "w")
+        f.close()
+        self.assertRaises(ValueError, self.module.dump, 123, f)
+        os.remove(fn)
+
+    def test_load_closed_file(self):
+        import tempfile, os
+        fn = tempfile.mktemp()
+        f = open(fn, "w")  # "w" creates the file so os.remove() below succeeds
+        f.close()
+        self.assertRaises(ValueError, self.module.load, f)  # load, not dump
+        os.remove(fn)