blob: f412a89055cdb117b19f2215632e9d2c1bb7c3d3 [file] [log] [blame]
# Tests for rich comparisons
2
Walter Dörwald721adf92003-04-29 21:31:19 +00003import unittest
4from test import test_support
5
6import operator
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +00007
class Number:
    """Classic-class wrapper around a value that implements rich comparisons.

    Each rich comparison compares the wrapped value ``self.x`` against
    *other* and returns whatever that comparison yields.  ``__cmp__`` raises
    TestFailed so the tests notice if it is ever called.
    """

    def __init__(self, x):
        self.x = x

    # Each rich comparison applies the matching operator-module function to
    # (wrapped value, other operand).
    def __lt__(self, other):
        return operator.lt(self.x, other)

    def __le__(self, other):
        return operator.le(self.x, other)

    def __eq__(self, other):
        return operator.eq(self.x, other)

    def __ne__(self, other):
        return operator.ne(self.x, other)

    def __gt__(self, other):
        return operator.gt(self.x, other)

    def __ge__(self, other):
        return operator.ge(self.x, other)

    def __cmp__(self, other):
        # The tests expect this fallback never to be reached.
        raise test_support.TestFailed("Number.__cmp__() should not be called")

    def __repr__(self):
        return "Number(%r)" % (self.x,)
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +000036
class Vector:
    """Classic-class sequence whose comparisons are evaluated item by item.

    Every rich comparison returns a new Vector holding the per-item results
    instead of a single truth value.  Hashing and truth-testing raise
    TypeError, and ``__cmp__`` raises TestFailed so the tests notice if it
    is ever called.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, v):
        self.data[i] = v

    def __hash__(self):
        raise TypeError("Vectors cannot be hashed")

    def __nonzero__(self):
        raise TypeError("Vectors cannot be used in Boolean contexts")

    def __cmp__(self, other):
        raise test_support.TestFailed("Vector.__cmp__() should not be called")

    def __repr__(self):
        return "Vector(%r)" % (self.data,)

    def __itemwise(self, compare, other):
        # Unwrap/validate the other operand first, so a length mismatch
        # raises ValueError before any item is compared.
        rhs = self.__cast(other)
        return Vector([compare(x, y) for x, y in zip(self.data, rhs)])

    def __lt__(self, other):
        return self.__itemwise(operator.lt, other)

    def __le__(self, other):
        return self.__itemwise(operator.le, other)

    def __eq__(self, other):
        return self.__itemwise(operator.eq, other)

    def __ne__(self, other):
        return self.__itemwise(operator.ne, other)

    def __gt__(self, other):
        return self.__itemwise(operator.gt, other)

    def __ge__(self, other):
        return self.__itemwise(operator.ge, other)

    def __cast(self, other):
        # Accept another Vector (use its underlying data) or a plain
        # sequence; both sides must have the same length.
        if isinstance(other, Vector):
            other = other.data
        if len(self.data) != len(other):
            raise ValueError("Cannot compare vectors of different length")
        return other
87
# Maps each comparison name to three equivalent spellings of that operator:
# the infix syntax (wrapped in a lambda), the operator-module function, and
# the operator module's dunder alias.
opmap = dict(
    lt=(lambda a, b: a < b, operator.lt, operator.__lt__),
    le=(lambda a, b: a <= b, operator.le, operator.__le__),
    eq=(lambda a, b: a == b, operator.eq, operator.__eq__),
    ne=(lambda a, b: a != b, operator.ne, operator.__ne__),
    gt=(lambda a, b: a > b, operator.gt, operator.__gt__),
    ge=(lambda a, b: a >= b, operator.ge, operator.__ge__),
)
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +000096
class VectorTest(unittest.TestCase):
    """Comparisons involving Vector must yield itemwise Vector results."""

    def checkfail(self, error, opname, *args):
        # Every spelling of the operator in opmap[opname] must raise *error*
        # when applied to *args*.
        for func in opmap[opname]:
            self.assertRaises(error, func, *args)

    def checkequal(self, opname, a, b, expres):
        # Every spelling of the operator must produce a result whose items
        # are exactly the objects in *expres*.
        for func in opmap[opname]:
            outcome = func(a, b)
            # assertEqual(outcome, expres) is unusable here: == on a Vector
            # returns another Vector, and truth-testing that raises
            # TypeError.  Compare the length, then each item by identity
            # (the expected items are bools).
            self.assertEqual(len(outcome), len(expres))
            for idx, wanted in enumerate(expres):
                self.assertTrue(outcome[idx] is wanted)

    def test_mixed(self):
        # Vectors of different lengths cannot be compared with any operator.
        short = Vector(range(2))
        longer = Vector(range(3))
        for opname in opmap:
            self.checkfail(ValueError, opname, short, longer)

        left = range(5)
        right = 5 * [2]
        expectations = [
            ("lt", [True,  True,  False, False, False]),
            ("le", [True,  True,  True,  False, False]),
            ("eq", [False, False, True,  False, False]),
            ("ne", [True,  True,  False, True,  True ]),
            ("gt", [False, False, False, True,  True ]),
            ("ge", [False, False, True,  True,  True ]),
        ]
        # Every pair has at least one Vector operand; (left, right) alone is
        # skipped because two plain lists do not yield a bool vector.
        pairs = [(left, Vector(right)), (Vector(left), right),
                 (Vector(left), Vector(right))]
        for a, b in pairs:
            for opname, expres in expectations:
                self.checkequal(opname, a, b, expres)
            for funcs in opmap.values():
                for func in funcs:
                    # bool() calls Vector.__nonzero__, which must raise.
                    self.assertRaises(TypeError, bool, func(a, b))
138
class NumberTest(unittest.TestCase):
    """Comparisons involving Number must agree with plain int comparisons."""

    def test_basic(self):
        # Every comparison with at least one Number operand must give the
        # same result as comparing the corresponding ints.
        kinds = (int, Number)
        for a in range(3):
            for b in range(3):
                for typea in kinds:
                    for typeb in kinds:
                        if typea is int and typeb is int:
                            continue  # int vs int exercises nothing new
                        ta = typea(a)
                        tb = typeb(b)
                        for funcs in opmap.values():
                            for func in funcs:
                                self.assertEqual(func(a, b), func(ta, tb))

    def checkvalue(self, opname, a, b, expres):
        # For every int/Number combination of operands, each spelling of the
        # operator must return exactly the object *expres*.
        kinds = (int, Number)
        for typea in kinds:
            for typeb in kinds:
                ta = typea(a)
                tb = typeb(b)
                for func in opmap[opname]:
                    outcome = func(ta, tb)
                    # If the result has an ``x`` attribute (as a Number
                    # does), check that payload instead of the wrapper.
                    outcome = getattr(outcome, "x", outcome)
                    self.assertTrue(outcome is expres)

    def test_values(self):
        # Check every operator for each ordering of the operands 0 and 1.
        table = [
            ("lt", 0, 0, False),
            ("le", 0, 0, True ),
            ("eq", 0, 0, True ),
            ("ne", 0, 0, False),
            ("gt", 0, 0, False),
            ("ge", 0, 0, True ),

            ("lt", 0, 1, True ),
            ("le", 0, 1, True ),
            ("eq", 0, 1, False),
            ("ne", 0, 1, True ),
            ("gt", 0, 1, False),
            ("ge", 0, 1, False),

            ("lt", 1, 0, False),
            ("le", 1, 0, False),
            ("eq", 1, 0, False),
            ("ne", 1, 0, True ),
            ("gt", 1, 0, True ),
            ("ge", 1, 0, True ),
        ]
        for opname, a, b, expres in table:
            self.checkvalue(opname, a, b, expres)
191
class MiscTest(unittest.TestCase):
    """Assorted rich-comparison edge cases: dispatch, ``not``, recursion."""

    def test_misbehavin(self):
        # Each operator must call only its own rich comparison method.  The
        # methods for <=, >= and != fail the test if they are ever reached.
        # __cmp__ is invoked only through the explicit cmp() call below.
        class Misb:
            def __lt__(self, other):
                return 0

            def __gt__(self, other):
                return 0

            def __eq__(self, other):
                return 0

            # Qualified as test_support.TestFailed: the bare name
            # TestFailed is not imported into this module, so using it
            # would raise NameError instead of the intended failure.
            def __le__(self, other):
                raise test_support.TestFailed("This shouldn't happen")

            def __ge__(self, other):
                raise test_support.TestFailed("This shouldn't happen")

            def __ne__(self, other):
                raise test_support.TestFailed("This shouldn't happen")

            def __cmp__(self, other):
                raise RuntimeError("expected")

        a = Misb()
        b = Misb()
        self.assertEqual(a < b, 0)
        self.assertEqual(a == b, 0)
        self.assertEqual(a > b, 0)
        self.assertRaises(RuntimeError, cmp, a, b)

    def test_not(self):
        # Exceptions raised by __nonzero__ must propagate through both the
        # ``not`` operator and operator.not_ (module-level import).
        class Exc(Exception):
            pass

        class Bad:
            def __nonzero__(self):
                raise Exc

        def do(bad):
            not bad

        for func in (do, operator.not_):
            self.assertRaises(Exc, func, Bad())

    def test_recursion(self):
        # Comparing mutually recursive containers must fail gracefully with
        # RuntimeError.
        from UserList import UserList
        all_ops = (operator.eq, operator.ne, operator.lt, operator.le,
                   operator.gt, operator.ge)
        ordering_ops = (operator.lt, operator.le, operator.gt, operator.ge)

        a = UserList()
        b = UserList()
        a.append(b)
        b.append(a)
        for op in all_ops:
            self.assertRaises(RuntimeError, op, a, b)

        b.append(17)
        # Even recursive lists of different lengths are different,
        # but they cannot be ordered
        self.assertTrue(not (a == b))
        self.assertTrue(a != b)
        for op in ordering_ops:
            self.assertRaises(RuntimeError, op, a, b)

        a.append(17)
        # Lengths match again, so equality tests recurse and fail again.
        self.assertRaises(RuntimeError, operator.eq, a, b)
        self.assertRaises(RuntimeError, operator.ne, a, b)

        a.insert(0, 11)
        b.insert(0, 12)
        # The first items now differ, and all comparisons succeed.
        self.assertTrue(not (a == b))
        self.assertTrue(a != b)
        self.assertTrue(a < b)
Walter Dörwald721adf92003-04-29 21:31:19 +0000257
class DictTest(unittest.TestCase):
    """Equality of dicts whose keys and values are not orderable."""

    def test_dicts(self):
        # Dict == and != must work even if the keys and values support
        # nothing beyond ==, != (and hashing); complex numbers are such a
        # type.  The ordering operators must raise TypeError.
        import random
        imag1a = {}
        for _ in range(50):
            imag1a[random.randrange(100)*1j] = random.randrange(100)*1j
        shuffled = imag1a.items()
        random.shuffle(shuffled)
        # Same contents as imag1a, inserted in a different order.
        imag1b = dict(shuffled)
        # imag2 differs from imag1b only in the value of the last-inserted key.
        k, v = shuffled[-1]
        imag2 = imag1b.copy()
        imag2[k] = v + 1.0
        for lhs, rhs in ((imag1a, imag1a), (imag1a, imag1b), (imag2, imag2)):
            self.assertTrue(lhs == rhs)
        self.assertTrue(imag1a != imag2)
        for opname in ("lt", "le", "gt", "ge"):
            for func in opmap[opname]:
                self.assertRaises(TypeError, func, imag1a, imag2)
282
283class ListTest(unittest.TestCase):
284
285 def assertIs(self, a, b):
286 self.assert_(a is b)
287
288 def test_coverage(self):
289 # exercise all comparisons for lists
290 x = [42]
291 self.assertIs(x<x, False)
292 self.assertIs(x<=x, True)
293 self.assertIs(x==x, True)
294 self.assertIs(x!=x, False)
295 self.assertIs(x>x, False)
296 self.assertIs(x>=x, True)
297 y = [42, 42]
298 self.assertIs(x<y, True)
299 self.assertIs(x<=y, True)
300 self.assertIs(x==y, False)
301 self.assertIs(x!=y, True)
302 self.assertIs(x>y, False)
303 self.assertIs(x>=y, False)
304
305 def test_badentry(self):
306 # make sure that exceptions for item comparison are properly
307 # propagated in list comparisons
Neal Norwitz5a822fb2006-03-24 07:03:44 +0000308 class Exc(Exception):
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +0000309 pass
Walter Dörwald721adf92003-04-29 21:31:19 +0000310 class Bad:
311 def __eq__(self, other):
312 raise Exc
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +0000313
Walter Dörwald721adf92003-04-29 21:31:19 +0000314 x = [Bad()]
315 y = [Bad()]
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +0000316
Walter Dörwald721adf92003-04-29 21:31:19 +0000317 for op in opmap["eq"]:
318 self.assertRaises(Exc, op, x, y)
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +0000319
Walter Dörwald721adf92003-04-29 21:31:19 +0000320 def test_goodentry(self):
321 # This test exercises the final call to PyObject_RichCompare()
322 # in Objects/listobject.c::list_richcompare()
323 class Good:
324 def __lt__(self, other):
325 return True
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +0000326
Walter Dörwald721adf92003-04-29 21:31:19 +0000327 x = [Good()]
328 y = [Good()]
Tim Peters8880f6d2001-01-19 06:12:17 +0000329
Walter Dörwald721adf92003-04-29 21:31:19 +0000330 for op in opmap["lt"]:
331 self.assertIs(op(x, y), True)
Guido van Rossum9710bd52001-01-18 15:55:59 +0000332
def test_main():
    """Run every test case class defined in this module."""
    cases = (VectorTest, NumberTest, MiscTest, DictTest, ListTest)
    test_support.run_unittest(*cases)


if __name__ == "__main__":
    test_main()