Issue #10242: backport of more fixes to unittest.TestCase.assertItemsEqual
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py
index cd8f4fa..ecb6a3e 100644
--- a/Lib/unittest/case.py
+++ b/Lib/unittest/case.py
@@ -10,9 +10,11 @@
from . import result
from .util import (
- strclass, safe_repr, sorted_list_difference, unorderable_list_difference
+ strclass, safe_repr, unorderable_list_difference,
+ _count_diff_all_purpose, _count_diff_hashable
)
+
__unittest = True
@@ -863,6 +865,7 @@
- [0, 1, 1] and [1, 0, 1] compare equal.
- [0, 0, 1] and [0, 1] compare unequal.
"""
+ first_seq, second_seq = list(actual_seq), list(expected_seq)
with warnings.catch_warnings():
if sys.py3kwarning:
# Silence Py3k warning raised during the sorting
@@ -871,29 +874,23 @@
"comparing unequal types"]:
warnings.filterwarnings("ignore", _msg, DeprecationWarning)
try:
- actual = collections.Counter(iter(actual_seq))
- expected = collections.Counter(iter(expected_seq))
+ first = collections.Counter(first_seq)
+ second = collections.Counter(second_seq)
except TypeError:
- # Unsortable items (example: set(), complex(), ...)
- actual = list(actual_seq)
- expected = list(expected_seq)
- missing, unexpected = unorderable_list_difference(expected, actual)
+ # Handle case with unhashable elements
+ differences = _count_diff_all_purpose(first_seq, second_seq)
else:
- if actual == expected:
+ if first == second:
return
- missing = list(expected - actual)
- unexpected = list(actual - expected)
+ differences = _count_diff_hashable(first_seq, second_seq)
- errors = []
- if missing:
- errors.append('Expected, but missing:\n %s' %
- safe_repr(missing))
- if unexpected:
- errors.append('Unexpected, but present:\n %s' %
- safe_repr(unexpected))
- if errors:
- standardMsg = '\n'.join(errors)
- self.fail(self._formatMessage(msg, standardMsg))
+ if differences:
+ standardMsg = 'Element counts were not equal:\n'
+ lines = ['First has %d, Second has %d: %r' % diff for diff in differences]
+ diffMsg = '\n'.join(lines)
+ standardMsg = self._truncateMessage(standardMsg, diffMsg)
+ msg = self._formatMessage(msg, standardMsg)
+ self.fail(msg)
def assertMultiLineEqual(self, first, second, msg=None):
"""Assert that two multi-line strings are equal."""
diff --git a/Lib/unittest/test/test_assertions.py b/Lib/unittest/test/test_assertions.py
index e85ca91..e1ba614 100644
--- a/Lib/unittest/test/test_assertions.py
+++ b/Lib/unittest/test/test_assertions.py
@@ -228,12 +228,6 @@
"^Missing: 'key'$",
"^Missing: 'key' : oops$"])
- def testAssertItemsEqual(self):
- self.assertMessages('assertItemsEqual', ([], [None]),
- [r"\[None\]$", "^oops$",
- r"\[None\]$",
- r"\[None\] : oops$"])
-
def testAssertMultiLineEqual(self):
self.assertMessages('assertMultiLineEqual', ("", "foo"),
[r"\+ foo$", "^oops$",
diff --git a/Lib/unittest/test/test_case.py b/Lib/unittest/test/test_case.py
index 250e905..06eeda1 100644
--- a/Lib/unittest/test/test_case.py
+++ b/Lib/unittest/test/test_case.py
@@ -686,20 +686,19 @@
# Test that sequences of unhashable objects can be tested for sameness:
self.assertItemsEqual([[1, 2], [3, 4], 0], [False, [3, 4], [1, 2]])
- with test_support.check_warnings(quiet=True) as w:
- # hashable types, but not orderable
- self.assertRaises(self.failureException, self.assertItemsEqual,
- [], [divmod, 'x', 1, 5j, 2j, frozenset()])
- # comparing dicts raises a py3k warning
- self.assertItemsEqual([{'a': 1}, {'b': 2}], [{'b': 2}, {'a': 1}])
- # comparing heterogenous non-hashable sequences raises a py3k warning
- self.assertItemsEqual([1, 'x', divmod, []], [divmod, [], 'x', 1])
- self.assertRaises(self.failureException, self.assertItemsEqual,
- [], [divmod, [], 'x', 1, 5j, 2j, set()])
- # fail the test if warnings are not silenced
- if w.warnings:
- self.fail('assertItemsEqual raised a warning: ' +
- str(w.warnings[0]))
+ # Test that iterator of unhashable objects can be tested for sameness:
+ self.assertItemsEqual(iter([1, 2, [], 3, 4]),
+ iter([1, 2, [], 3, 4]))
+
+ # hashable types, but not orderable
+ self.assertRaises(self.failureException, self.assertItemsEqual,
+ [], [divmod, 'x', 1, 5j, 2j, frozenset()])
+ # comparing dicts
+ self.assertItemsEqual([{'a': 1}, {'b': 2}], [{'b': 2}, {'a': 1}])
+ # comparing heterogenous non-hashable sequences
+ self.assertItemsEqual([1, 'x', divmod, []], [divmod, [], 'x', 1])
+ self.assertRaises(self.failureException, self.assertItemsEqual,
+ [], [divmod, [], 'x', 1, 5j, 2j, set()])
self.assertRaises(self.failureException, self.assertItemsEqual,
[[1]], [[2]])
@@ -717,6 +716,19 @@
b = a[::-1]
self.assertItemsEqual(a, b)
+ # test utility functions supporting assertItemsEqual()
+
+ diffs = set(unittest.util._count_diff_all_purpose('aaabccd', 'abbbcce'))
+ expected = {(3,1,'a'), (1,3,'b'), (1,0,'d'), (0,1,'e')}
+ self.assertEqual(diffs, expected)
+
+ diffs = unittest.util._count_diff_all_purpose([[]], [])
+ self.assertEqual(diffs, [(1, 0, [])])
+
+ diffs = set(unittest.util._count_diff_hashable('aaabccd', 'abbbcce'))
+ expected = {(3,1,'a'), (1,3,'b'), (1,0,'d'), (0,1,'e')}
+ self.assertEqual(diffs, expected)
+
def testAssertSetEqual(self):
set1 = set()
set2 = set()
diff --git a/Lib/unittest/util.py b/Lib/unittest/util.py
index d201657..220a024 100644
--- a/Lib/unittest/util.py
+++ b/Lib/unittest/util.py
@@ -1,4 +1,6 @@
"""Various utility functions."""
+from collections import namedtuple, OrderedDict
+
__unittest = True
@@ -92,3 +94,68 @@
 
     # anything left in actual is unexpected
     return missing, actual
+
+# One entry of the result of the _count_diff_* helpers: how many times
+# `value` occurs in each of the two sequences being compared.
+_Mismatch = namedtuple('Mismatch', 'actual expected value')
+
+def _count_diff_all_purpose(actual, expected):
+    'Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ'
+    # elements need not be hashable
+    s, t = list(actual), list(expected)
+    m, n = len(s), len(t)
+    # Private sentinel written over every slot once it has been counted, so
+    # each element is tallied exactly once using only == comparisons.
+    NULL = object()
+    result = []
+    for i, elem in enumerate(s):
+        if elem is NULL:
+            continue
+        cnt_s = cnt_t = 0
+        for j in range(i, m):
+            if s[j] == elem:
+                cnt_s += 1
+                s[j] = NULL
+        for j, other_elem in enumerate(t):
+            if other_elem == elem:
+                cnt_t += 1
+                t[j] = NULL
+        if cnt_s != cnt_t:
+            diff = _Mismatch(cnt_s, cnt_t, elem)
+            result.append(diff)
+
+    # Anything still live in t never compared equal to an element of actual.
+    for i, elem in enumerate(t):
+        if elem is NULL:
+            continue
+        cnt_t = 0
+        for j in range(i, n):
+            if t[j] == elem:
+                cnt_t += 1
+                t[j] = NULL
+        diff = _Mismatch(0, cnt_t, elem)
+        result.append(diff)
+    return result
+
+def _ordered_count(iterable):
+    'Return dict of element counts, in the order they were first seen'
+    c = OrderedDict()
+    for elem in iterable:
+        c[elem] = c.get(elem, 0) + 1
+    return c
+
+def _count_diff_hashable(actual, expected):
+    'Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ'
+    # elements must be hashable
+    s, t = _ordered_count(actual), _ordered_count(expected)
+    result = []
+    for elem, cnt_s in s.items():
+        cnt_t = t.get(elem, 0)
+        if cnt_s != cnt_t:
+            diff = _Mismatch(cnt_s, cnt_t, elem)
+            result.append(diff)
+    for elem, cnt_t in t.items():
+        if elem not in s:
+            diff = _Mismatch(0, cnt_t, elem)
+            result.append(diff)
+    return result