blob: db6d31ff9527bec6f1cbd0f1dd605e9d392ba28a [file] [log] [blame]
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +00001# Tests for rich comparisons
2
Walter Dörwald721adf92003-04-29 21:31:19 +00003import unittest
4from test import test_support
5
6import operator
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +00007
class Number:
    """Wrap a value and forward every rich comparison to it.

    ``__cmp__`` is a tripwire: it raises TestFailed, so any code path that
    falls back to three-way comparison instead of the rich methods fails.
    """

    def __init__(self, x):
        self.x = x

    # Each rich comparison delegates to the wrapped value via the
    # equivalent operator-module function.
    def __lt__(self, other):
        return operator.lt(self.x, other)

    def __le__(self, other):
        return operator.le(self.x, other)

    def __eq__(self, other):
        return operator.eq(self.x, other)

    def __ne__(self, other):
        return operator.ne(self.x, other)

    def __gt__(self, other):
        return operator.gt(self.x, other)

    def __ge__(self, other):
        return operator.ge(self.x, other)

    def __cmp__(self, other):
        raise test_support.TestFailed("Number.__cmp__() should not be called")

    def __repr__(self):
        return "Number(%r)" % (self.x,)
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +000036
class Vector:
    """Sequence wrapper whose rich comparisons work element-wise.

    Comparing against another Vector, or against a plain sequence of the
    same length, returns a new Vector holding the per-element results
    rather than a single bool.  Hashing, truth-testing and ``__cmp__`` all
    raise, so the tests notice if any of them is used by accident.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, v):
        self.data[i] = v

    def __hash__(self):
        raise TypeError("Vectors cannot be hashed")

    def __nonzero__(self):
        raise TypeError("Vectors cannot be used in Boolean contexts")

    def __cmp__(self, other):
        raise test_support.TestFailed("Vector.__cmp__() should not be called")

    def __repr__(self):
        return "Vector(%r)" % (self.data,)

    def _elementwise(self, other, compare):
        # Validate/unwrap the right-hand side first, then apply `compare`
        # pairwise and wrap the results in a fresh Vector.
        rhs = self.__cast(other)
        return Vector([compare(lhs, r) for lhs, r in zip(self.data, rhs)])

    def __lt__(self, other):
        return self._elementwise(other, operator.lt)

    def __le__(self, other):
        return self._elementwise(other, operator.le)

    def __eq__(self, other):
        return self._elementwise(other, operator.eq)

    def __ne__(self, other):
        return self._elementwise(other, operator.ne)

    def __gt__(self, other):
        return self._elementwise(other, operator.gt)

    def __ge__(self, other):
        return self._elementwise(other, operator.ge)

    def __cast(self, other):
        # Accept a Vector (use its raw data) or any sized sequence; lengths
        # must match for an element-wise comparison to make sense.
        raw = other.data if isinstance(other, Vector) else other
        if len(self.data) != len(raw):
            raise ValueError("Cannot compare vectors of different length")
        return raw
87
Guido van Rossum0b7b6fd2007-12-19 22:51:13 +000088
class SimpleOrder(object):
    """
    Orderable through __lt__/__gt__ only: no __eq__, __ne__ or __hash__.

    Any object that is not a SimpleOrder compares as larger, i.e.
    ``self < other`` is True and ``self > other`` is False.
    """

    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        return self.value < other.value if isinstance(other, SimpleOrder) else True

    def __gt__(self, other):
        return self.value > other.value if isinstance(other, SimpleOrder) else False
106
107
class DumbEqualityWithoutHash(object):
    """
    A class that defines __eq__ but not __hash__; it should not be hashable.
    """

    def __eq__(self, other):
        # Deliberately reports inequality with everything, itself included.
        return False
115
116
# Each comparison name maps to three equivalent spellings of that comparison:
# the infix operator (wrapped in a lambda), the operator-module function and
# the operator-module dunder alias.  Tests run every spelling.
opmap = dict(
    lt=(lambda a, b: a < b, operator.lt, operator.__lt__),
    le=(lambda a, b: a <= b, operator.le, operator.__le__),
    eq=(lambda a, b: a == b, operator.eq, operator.__eq__),
    ne=(lambda a, b: a != b, operator.ne, operator.__ne__),
    gt=(lambda a, b: a > b, operator.gt, operator.__gt__),
    ge=(lambda a, b: a >= b, operator.ge, operator.__ge__),
)
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +0000125
class VectorTest(unittest.TestCase):

    def checkfail(self, error, opname, *args):
        """Assert that every spelling of comparison `opname` raises `error`."""
        for op in opmap[opname]:
            self.assertRaises(error, op, *args)

    def checkequal(self, opname, a, b, expres):
        """Assert every spelling of `opname` on (a, b) yields the bools in `expres`.

        The result is a rich Vector value, so it cannot be checked with
        assertEqual(realres, expres): that truth-tests ``realres == expres``,
        which is itself a Vector, and Vector.__nonzero__ raises TypeError.
        Compare item by item instead.
        """
        for op in opmap[opname]:
            realres = op(a, b)
            self.assertEqual(len(realres), len(expres))
            for i in range(len(realres)):
                # results are bool, so we can use "is" here
                self.assertTrue(realres[i] is expres[i])

    def test_mixed(self):
        # check that comparisons involving Vector objects
        # which return rich results (i.e. Vectors with itemwise
        # comparison results) work
        a = Vector(range(2))
        b = Vector(range(3))
        # all comparisons should fail for different length
        for opname in opmap:
            self.checkfail(ValueError, opname, a, b)

        a = range(5)
        b = 5 * [2]
        # try mixed arguments (but not (a, b) as that won't return a bool vector)
        args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))]
        for (x, y) in args:
            self.checkequal("lt", x, y, [True, True, False, False, False])
            self.checkequal("le", x, y, [True, True, True, False, False])
            self.checkequal("eq", x, y, [False, False, True, False, False])
            self.checkequal("ne", x, y, [True, True, False, True, True ])
            self.checkequal("gt", x, y, [False, False, False, True, True ])
            self.checkequal("ge", x, y, [False, False, True, True, True ])

            # Every rich result is a Vector; bool() calls its __nonzero__,
            # which must raise.  Check this for each mixed pair explicitly
            # instead of relying on loop variables leaked from this loop.
            for ops in opmap.values():
                for op in ops:
                    self.assertRaises(TypeError, bool, op(x, y))
167
class NumberTest(unittest.TestCase):

    def test_basic(self):
        """Number-wrapped comparisons must agree with plain-int comparisons.

        Every (a, b) in 0..2 x 0..2 is compared with at least one side
        wrapped in Number, using all spellings of all six operators.
        """
        wrappers = (int, Number)
        cases = [(a, b, typea, typeb)
                 for a in range(3)
                 for b in range(3)
                 for typea in wrappers
                 for typeb in wrappers
                 # the combination int, int is useless
                 if not (typea == typeb == int)]
        for a, b, typea, typeb in cases:
            ta, tb = typea(a), typeb(b)
            for ops in opmap.values():
                for op in ops:
                    self.assertEqual(op(a, b), op(ta, tb))

    def checkvalue(self, opname, a, b, expres):
        """Assert `opname` on every int/Number mix of (a, b) is exactly `expres`.

        A result carrying an ``x`` attribute (i.e. shaped like a Number) is
        unwrapped before the identity check against the expected bool.
        """
        wrappers = (int, Number)
        for typea in wrappers:
            for typeb in wrappers:
                ta, tb = typea(a), typeb(b)
                for op in opmap[opname]:
                    outcome = op(ta, tb)
                    outcome = getattr(outcome, "x", outcome)
                    self.assertTrue(outcome is expres)

    def test_values(self):
        # check all operators and all comparison results
        expectations = [
            # (opname, a, b, expected)
            ("lt", 0, 0, False),
            ("le", 0, 0, True),
            ("eq", 0, 0, True),
            ("ne", 0, 0, False),
            ("gt", 0, 0, False),
            ("ge", 0, 0, True),

            ("lt", 0, 1, True),
            ("le", 0, 1, True),
            ("eq", 0, 1, False),
            ("ne", 0, 1, True),
            ("gt", 0, 1, False),
            ("ge", 0, 1, False),

            ("lt", 1, 0, False),
            ("le", 1, 0, False),
            ("eq", 1, 0, False),
            ("ne", 1, 0, True),
            ("gt", 1, 0, True),
            ("ge", 1, 0, True),
        ]
        for opname, a, b, expected in expectations:
            self.checkvalue(opname, a, b, expected)
220
class MiscTest(unittest.TestCase):

    def test_misbehavin(self):
        # Misb only answers <, == and > (always with 0).  The expressions
        # below must use exactly those methods, and cmp() is expected to go
        # through __cmp__ (raising RuntimeError), never __le__/__ge__/__ne__.
        class Misb:
            def __lt__(self, other): return 0
            def __gt__(self, other): return 0
            def __eq__(self, other): return 0
            # TestFailed lives in test_support; a bare TestFailed name is not
            # defined in this module and would surface as a NameError.
            def __le__(self, other):
                raise test_support.TestFailed("This shouldn't happen")
            def __ge__(self, other):
                raise test_support.TestFailed("This shouldn't happen")
            def __ne__(self, other):
                raise test_support.TestFailed("This shouldn't happen")
            def __cmp__(self, other): raise RuntimeError("expected")
        a = Misb()
        b = Misb()
        self.assertEqual(a<b, 0)
        self.assertEqual(a==b, 0)
        self.assertEqual(a>b, 0)
        self.assertRaises(RuntimeError, cmp, a, b)

    def test_not(self):
        # Check that exceptions in __nonzero__ are properly
        # propagated by the not operator (both the keyword and
        # operator.not_, using the module-level operator import)
        class Exc(Exception):
            pass
        class Bad:
            def __nonzero__(self):
                raise Exc

        def do(bad):
            not bad

        for func in (do, operator.not_):
            self.assertRaises(Exc, func, Bad())

    def test_recursion(self):
        # Check that comparison for recursive objects fails gracefully
        # (with RuntimeError rather than crashing)
        from UserList import UserList
        all_ops = (operator.eq, operator.ne, operator.lt,
                   operator.le, operator.gt, operator.ge)
        ordering_ops = (operator.lt, operator.le, operator.gt, operator.ge)

        a = UserList()
        b = UserList()
        a.append(b)
        b.append(a)
        for op in all_ops:
            self.assertRaises(RuntimeError, op, a, b)

        b.append(17)
        # Even recursive lists of different lengths are different,
        # but they cannot be ordered
        self.assertTrue(not (a == b))
        self.assertTrue(a != b)
        for op in ordering_ops:
            self.assertRaises(RuntimeError, op, a, b)
        a.append(17)
        self.assertRaises(RuntimeError, operator.eq, a, b)
        self.assertRaises(RuntimeError, operator.ne, a, b)
        # A differing non-recursive first element decides the comparison
        # before recursion is reached.
        a.insert(0, 11)
        b.insert(0, 12)
        self.assertTrue(not (a == b))
        self.assertTrue(a != b)
        self.assertTrue(a < b)
Walter Dörwald721adf92003-04-29 21:31:19 +0000286
class DictTest(unittest.TestCase):

    def test_dicts(self):
        """dict ==/!= must work when keys and values only support equality.

        Complex numbers hash and compare for (in)equality but (per the
        assertions below) cannot be ordered, so ordering comparisons between
        such dicts are expected to raise TypeError.
        """
        import random
        imag1a = {}
        for _ in range(50):
            value = random.randrange(100) * 1j
            imag1a[random.randrange(100) * 1j] = value
        items = imag1a.items()
        random.shuffle(items)
        # Same contents as imag1a, inserted in shuffled order.
        imag1b = dict(items)
        # imag2 differs from imag1b only in the value of the last shuffled key.
        last_key, last_value = items[-1]
        imag2 = imag1b.copy()
        imag2[last_key] = last_value + 1.0

        self.assertTrue(imag1a == imag1a)
        self.assertTrue(imag1a == imag1b)
        self.assertTrue(imag2 == imag2)
        self.assertTrue(imag1a != imag2)
        for opname in ("lt", "le", "gt", "ge"):
            for op in opmap[opname]:
                self.assertRaises(TypeError, op, imag1a, imag2)
311
class ListTest(unittest.TestCase):

    def assertIs(self, a, b):
        # Local identity assertion used by the checks below.
        self.assertTrue(a is b)

    def test_coverage(self):
        """Exercise all six comparisons for lists."""
        x = [42]
        y = [42, 42]
        # x against itself: only the reflexive comparisons hold.
        # x against y: x is a proper prefix of y, so x orders first.
        checks = [
            (x < x, False), (x <= x, True), (x == x, True),
            (x != x, False), (x > x, False), (x >= x, True),
            (x < y, True), (x <= y, True), (x == y, False),
            (x != y, True), (x > y, False), (x >= y, False),
        ]
        for actual, expected in checks:
            self.assertIs(actual, expected)

    def test_badentry(self):
        # An exception raised while comparing list items must propagate
        # out of the list comparison itself.
        class Exc(Exception):
            pass
        class Bad:
            def __eq__(self, other):
                raise Exc

        lhs, rhs = [Bad()], [Bad()]
        for op in opmap["eq"]:
            self.assertRaises(Exc, op, lhs, rhs)

    def test_goodentry(self):
        # This test exercises the final call to PyObject_RichCompare()
        # in Objects/listobject.c::list_richcompare()
        class Good:
            def __lt__(self, other):
                return True

        lhs, rhs = [Good()], [Good()]
        for op in opmap["lt"]:
            self.assertIs(op(lhs, rhs), True)
Guido van Rossum9710bd52001-01-18 15:55:59 +0000361
Guido van Rossum0b7b6fd2007-12-19 22:51:13 +0000362
class HashableTest(unittest.TestCase):
    """
    Test hashability of classes with rich operators defined.
    """

    def test_simpleOrderHashable(self):
        """
        A class that only defines __gt__ and/or __lt__ should be hashable.
        """
        a = SimpleOrder(1)
        b = SimpleOrder(2)
        self.assertTrue(a < b)
        self.assertTrue(b > a)
        self.assertTrue(a.__hash__ is not None)
        # Having a __hash__ attribute is not enough: actually hash the
        # instances and use them as set members.
        hash(a)
        hash(b)
        self.assertEqual(len(set([a, b])), 2)

    def test_notHashableException(self):
        """
        If a class is not hashable, it should raise a TypeError with an
        understandable message.
        """
        a = DumbEqualityWithoutHash()
        try:
            hash(a)
        except TypeError as e:
            self.assertEqual(str(e),
                             "unhashable type: 'DumbEqualityWithoutHash'")
        else:
            self.fail("Should not be here")
391
392
def test_main():
    """Run every rich-comparison TestCase in this module via test_support."""
    cases = (VectorTest, NumberTest, MiscTest,
             DictTest, ListTest, HashableTest)
    test_support.run_unittest(*cases)
Guido van Rossum890f2092001-01-18 16:21:57 +0000395
# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_main()