blob: 30dac8f91755d0e0a1ebfbd3495c5e8d5a42932a [file] [log] [blame]
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +00001# Tests for rich comparisons
2
Walter Dörwald721adf92003-04-29 21:31:19 +00003import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00004from test import support
Walter Dörwald721adf92003-04-29 21:31:19 +00005
6import operator
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +00007
class Number:
    """Wrapper whose rich comparisons forward to the wrapped value.

    Every comparison operator is delegated to ``self.x`` compared against
    the raw right-hand operand, so the result is exactly whatever that
    underlying comparison produces.
    """

    def __init__(self, x):
        self.x = x

    def __lt__(self, other):
        return operator.lt(self.x, other)

    def __le__(self, other):
        return operator.le(self.x, other)

    def __eq__(self, other):
        return operator.eq(self.x, other)

    def __ne__(self, other):
        return operator.ne(self.x, other)

    def __gt__(self, other):
        return operator.gt(self.x, other)

    def __ge__(self, other):
        return operator.ge(self.x, other)

    def __cmp__(self, other):
        # Legacy hook: the tests expect it never to be invoked.
        raise support.TestFailed("Number.__cmp__() should not be called")

    def __repr__(self):
        return "Number(%r)" % (self.x,)
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +000036
class Vector:
    """Sequence wrapper whose comparisons are evaluated element by element.

    Each rich comparison returns a new ``Vector`` holding the per-element
    results instead of a single bool.  The right-hand operand may be another
    ``Vector`` or any sized sequence of the same length; a length mismatch
    raises ValueError.  Vectors refuse both hashing and truth testing.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, v):
        self.data[i] = v

    def __hash__(self):
        raise TypeError("Vectors cannot be hashed")

    def __bool__(self):
        raise TypeError("Vectors cannot be used in Boolean contexts")

    def __cmp__(self, other):
        # Legacy hook: the tests expect it never to be invoked.
        raise support.TestFailed("Vector.__cmp__() should not be called")

    def __repr__(self):
        return "Vector(%r)" % (self.data,)

    def __lt__(self, other):
        return self.__pairwise(operator.lt, other)

    def __le__(self, other):
        return self.__pairwise(operator.le, other)

    def __eq__(self, other):
        return self.__pairwise(operator.eq, other)

    def __ne__(self, other):
        return self.__pairwise(operator.ne, other)

    def __gt__(self, other):
        return self.__pairwise(operator.gt, other)

    def __ge__(self, other):
        return self.__pairwise(operator.ge, other)

    def __pairwise(self, compare, other):
        # Apply `compare` to corresponding elements of self and `other`.
        rhs = self.__cast(other)
        return Vector([compare(lhs, r) for lhs, r in zip(self.data, rhs)])

    def __cast(self, other):
        # Unwrap a Vector operand and insist on matching lengths.
        peer = other.data if isinstance(other, Vector) else other
        if len(self.data) != len(peer):
            raise ValueError("Cannot compare vectors of different length")
        return peer
87
Christian Heimes5fb7c2a2007-12-24 08:52:31 +000088
class SimpleOrder:
    """
    Orderable via ``<`` and ``>`` only; equality and hashing are inherited.

    Against an operand that is not a SimpleOrder it always reports itself
    as the smaller one.
    """

    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        if isinstance(other, SimpleOrder):
            return self.value < other.value
        return True

    def __gt__(self, other):
        if isinstance(other, SimpleOrder):
            return self.value > other.value
        return False
106
107
class DumbEqualityWithoutHash:
    """
    Defines ``__eq__`` but not ``__hash__``, so instances are unhashable.

    Python sets ``__hash__`` to ``None`` on a class that overrides
    ``__eq__`` without also defining ``__hash__``.
    """

    def __eq__(self, other):
        # Deliberately unequal to everything, including itself.
        return False
115
116
# Map each comparison name to three interchangeable callables: a lambda
# using the operator syntax, the operator-module function, and its
# dunder-named alias.  Tests loop over all three.
opmap = dict(
    lt=(lambda a, b: a < b, operator.lt, operator.__lt__),
    le=(lambda a, b: a <= b, operator.le, operator.__le__),
    eq=(lambda a, b: a == b, operator.eq, operator.__eq__),
    ne=(lambda a, b: a != b, operator.ne, operator.__ne__),
    gt=(lambda a, b: a > b, operator.gt, operator.__gt__),
    ge=(lambda a, b: a >= b, operator.ge, operator.__ge__),
)
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +0000125
class VectorTest(unittest.TestCase):
    """Element-wise comparisons through Vector's rich-comparison methods."""

    def checkfail(self, error, opname, *args):
        """Assert that every callable registered for `opname` raises `error`."""
        for op in opmap[opname]:
            self.assertRaises(error, op, *args)

    def checkequal(self, opname, a, b, expres):
        """Assert that every callable for `opname` yields the bools `expres`.

        The result is a Vector, so it is checked element by element; each
        element is a bool and is compared by identity.
        """
        for op in opmap[opname]:
            realres = op(a, b)
            # can't use assertEqual(realres, expres) here: comparing a Vector
            # yields another Vector, whose __bool__ raises TypeError.
            self.assertEqual(len(realres), len(expres))
            for i, expected in enumerate(expres):
                self.assertIs(realres[i], expected)

    def test_mixed(self):
        # check that comparisons involving Vector objects
        # which return rich results (i.e. Vectors with itemwise
        # comparison results) work
        a = Vector(range(2))
        b = Vector(range(3))
        # all comparisons should fail for different length
        for opname in opmap:
            self.checkfail(ValueError, opname, a, b)

        a = list(range(5))
        b = 5 * [2]
        # try mixed arguments (but not (a, b) as that won't return a bool vector)
        args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))]
        for (left, right) in args:
            self.checkequal("lt", left, right, [True, True, False, False, False])
            self.checkequal("le", left, right, [True, True, True, False, False])
            self.checkequal("eq", left, right, [False, False, True, False, False])
            self.checkequal("ne", left, right, [True, True, False, True, True ])
            self.checkequal("gt", left, right, [False, False, False, True, True ])
            self.checkequal("ge", left, right, [False, False, True, True, True ])

        # Using a comparison result as a truth value calls Vector.__bool__,
        # which must raise TypeError.
        va, vb = Vector(a), Vector(b)
        for ops in opmap.values():
            for op in ops:
                self.assertRaises(TypeError, bool, op(va, vb))
167
class NumberTest(unittest.TestCase):
    """Number wrappers must compare exactly like the ints they wrap."""

    def test_basic(self):
        # Check that comparisons involving Number objects
        # give the same results as comparing the
        # corresponding ints
        for a in range(3):
            for b in range(3):
                for typea in (int, Number):
                    for typeb in (int, Number):
                        if typea == typeb == int:
                            continue  # the combination int, int is useless
                        ta = typea(a)
                        tb = typeb(b)
                        for ops in opmap.values():
                            for op in ops:
                                realoutcome = op(a, b)
                                testoutcome = op(ta, tb)
                                self.assertEqual(realoutcome, testoutcome)

    def checkvalue(self, opname, a, b, expres):
        """Assert `opname` applied to every int/Number pairing gives `expres`."""
        for typea in (int, Number):
            for typeb in (int, Number):
                ta = typea(a)
                tb = typeb(b)
                for op in opmap[opname]:
                    realres = op(ta, tb)
                    # unwrap a Number-like result (anything with an ``x``)
                    realres = getattr(realres, "x", realres)
                    self.assertIs(realres, expres)

    def test_values(self):
        # check all operators and all comparison results
        self.checkvalue("lt", 0, 0, False)
        self.checkvalue("le", 0, 0, True )
        self.checkvalue("eq", 0, 0, True )
        self.checkvalue("ne", 0, 0, False)
        self.checkvalue("gt", 0, 0, False)
        self.checkvalue("ge", 0, 0, True )

        self.checkvalue("lt", 0, 1, True )
        self.checkvalue("le", 0, 1, True )
        self.checkvalue("eq", 0, 1, False)
        self.checkvalue("ne", 0, 1, True )
        self.checkvalue("gt", 0, 1, False)
        self.checkvalue("ge", 0, 1, False)

        self.checkvalue("lt", 1, 0, False)
        self.checkvalue("le", 1, 0, False)
        self.checkvalue("eq", 1, 0, False)
        self.checkvalue("ne", 1, 0, True )
        self.checkvalue("gt", 1, 0, True )
        self.checkvalue("ge", 1, 0, True )
220
class MiscTest(unittest.TestCase):
    """Miscellaneous rich-comparison behaviour."""

    def test_misbehavin(self):
        # Comparison methods may return non-bool results (here 0), and only
        # the method for the operator actually used may be called.  The
        # inner methods take `self_` so `self.fail` reaches the TestCase.
        class Misb:
            def __lt__(self_, other): return 0
            def __gt__(self_, other): return 0
            def __eq__(self_, other): return 0
            def __le__(self_, other): self.fail("This shouldn't happen")
            def __ge__(self_, other): self.fail("This shouldn't happen")
            def __ne__(self_, other): self.fail("This shouldn't happen")
        a = Misb()
        b = Misb()
        self.assertEqual(a<b, 0)
        self.assertEqual(a==b, 0)
        self.assertEqual(a>b, 0)

    def test_not(self):
        # Check that exceptions in __bool__ are properly
        # propagated by the not operator
        class Exc(Exception):
            pass
        class Bad:
            def __bool__(self):
                raise Exc

        def do(bad):
            not bad

        for func in (do, operator.not_):
            self.assertRaises(Exc, func, Bad())

    def test_recursion(self):
        # Check that comparison for recursive objects fails gracefully
        # (RecursionError is a subclass of RuntimeError).
        from collections import UserList
        a = UserList()
        b = UserList()
        a.append(b)
        b.append(a)
        self.assertRaises(RuntimeError, operator.eq, a, b)
        self.assertRaises(RuntimeError, operator.ne, a, b)
        self.assertRaises(RuntimeError, operator.lt, a, b)
        self.assertRaises(RuntimeError, operator.le, a, b)
        self.assertRaises(RuntimeError, operator.gt, a, b)
        self.assertRaises(RuntimeError, operator.ge, a, b)

        b.append(17)
        # Even recursive lists of different lengths are different,
        # but they cannot be ordered
        self.assertFalse(a == b)
        self.assertTrue(a != b)
        self.assertRaises(RuntimeError, operator.lt, a, b)
        self.assertRaises(RuntimeError, operator.le, a, b)
        self.assertRaises(RuntimeError, operator.gt, a, b)
        self.assertRaises(RuntimeError, operator.ge, a, b)
        a.append(17)
        self.assertRaises(RuntimeError, operator.eq, a, b)
        self.assertRaises(RuntimeError, operator.ne, a, b)
        a.insert(0, 11)
        b.insert(0, 12)
        self.assertFalse(a == b)
        self.assertTrue(a != b)
        self.assertTrue(a < b)
Walter Dörwald721adf92003-04-29 21:31:19 +0000286
class DictTest(unittest.TestCase):
    """Equality and (lack of) ordering for dicts of complex numbers."""

    def test_dicts(self):
        # Verify that __eq__ and __ne__ work for dicts even if the keys and
        # values don't support anything other than __eq__ and __ne__ (and
        # __hash__).  Complex numbers are a fine example of that.
        import random
        imag1a = {}
        for _ in range(50):
            imag1a[random.randrange(100)*1j] = random.randrange(100)*1j
        items = list(imag1a.items())
        random.shuffle(items)
        # Same contents as imag1a, inserted in a different order.
        imag1b = dict(items)
        imag2 = imag1b.copy()
        # Change one value so that imag2 differs from imag1a.
        k, v = items[-1]
        imag2[k] = v + 1.0
        self.assertEqual(imag1a, imag1a)
        self.assertEqual(imag1a, imag1b)
        self.assertEqual(imag2, imag2)
        self.assertNotEqual(imag1a, imag2)
        # dicts define no ordering, so every ordering operator must fail.
        for opname in ("lt", "le", "gt", "ge"):
            for op in opmap[opname]:
                self.assertRaises(TypeError, op, imag1a, imag2)
311
312class ListTest(unittest.TestCase):
313
314 def assertIs(self, a, b):
315 self.assert_(a is b)
316
317 def test_coverage(self):
318 # exercise all comparisons for lists
319 x = [42]
320 self.assertIs(x<x, False)
321 self.assertIs(x<=x, True)
322 self.assertIs(x==x, True)
323 self.assertIs(x!=x, False)
324 self.assertIs(x>x, False)
325 self.assertIs(x>=x, True)
326 y = [42, 42]
327 self.assertIs(x<y, True)
328 self.assertIs(x<=y, True)
329 self.assertIs(x==y, False)
330 self.assertIs(x!=y, True)
331 self.assertIs(x>y, False)
332 self.assertIs(x>=y, False)
333
334 def test_badentry(self):
335 # make sure that exceptions for item comparison are properly
336 # propagated in list comparisons
Neal Norwitz0fb43762006-03-24 07:02:16 +0000337 class Exc(Exception):
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +0000338 pass
Walter Dörwald721adf92003-04-29 21:31:19 +0000339 class Bad:
340 def __eq__(self, other):
341 raise Exc
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +0000342
Walter Dörwald721adf92003-04-29 21:31:19 +0000343 x = [Bad()]
344 y = [Bad()]
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +0000345
Walter Dörwald721adf92003-04-29 21:31:19 +0000346 for op in opmap["eq"]:
347 self.assertRaises(Exc, op, x, y)
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +0000348
Walter Dörwald721adf92003-04-29 21:31:19 +0000349 def test_goodentry(self):
350 # This test exercises the final call to PyObject_RichCompare()
351 # in Objects/listobject.c::list_richcompare()
352 class Good:
353 def __lt__(self, other):
354 return True
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +0000355
Walter Dörwald721adf92003-04-29 21:31:19 +0000356 x = [Good()]
357 y = [Good()]
Tim Peters8880f6d2001-01-19 06:12:17 +0000358
Walter Dörwald721adf92003-04-29 21:31:19 +0000359 for op in opmap["lt"]:
360 self.assertIs(op(x, y), True)
Guido van Rossum9710bd52001-01-18 15:55:59 +0000361
Christian Heimes5fb7c2a2007-12-24 08:52:31 +0000362
class HashableTest(unittest.TestCase):
    """
    Test hashability of classes with rich operators defined.
    """

    def test_simpleOrderHashable(self):
        """
        A class that only defines __gt__ and/or __lt__ should be hashable.
        """
        a = SimpleOrder(1)
        b = SimpleOrder(2)
        self.assertTrue(a < b)
        self.assertTrue(b > a)
        self.assertIsNotNone(a.__hash__)

    def test_notHashableException(self):
        """
        If a class is not hashable, it should raise a TypeError with an
        understandable message.
        """
        a = DumbEqualityWithoutHash()
        with self.assertRaises(TypeError) as cm:
            hash(a)
        self.assertEqual(str(cm.exception),
                         "unhashable type: 'DumbEqualityWithoutHash'")
Benjamin Petersonee8712c2008-05-20 21:35:26 +0000390 raise support.TestFailed("Should not be here")
Christian Heimes5fb7c2a2007-12-24 08:52:31 +0000391
392
def test_main():
    """Run every rich-comparison TestCase in this module via test.support."""
    test_classes = (VectorTest, NumberTest, MiscTest,
                    DictTest, ListTest, HashableTest)
    support.run_unittest(*test_classes)


if __name__ == "__main__":
    test_main()