blob: 5ade8ede5699915715beb0723bcda67139cb2858 [file] [log] [blame]
# Tests for rich comparisons
2
Walter Dörwald721adf92003-04-29 21:31:19 +00003import unittest
4from test import test_support
5
6import operator
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +00007
class Number:
    """Wrapper around a value that is compared via rich comparisons only.

    Each rich comparison method forwards to the wrapped value ``x`` and
    returns whatever that comparison returns.  ``__cmp__`` raises, so any
    fallback to three-way comparison shows up as a test failure.
    """

    def __init__(self, x):
        self.x = x

    def __repr__(self):
        return "Number(%r)" % (self.x, )

    def __cmp__(self, other):
        # The rich comparison methods below must always be preferred.
        raise test_support.TestFailed("Number.__cmp__() should not be called")

    # Equality family.

    def __eq__(self, other):
        return self.x == other

    def __ne__(self, other):
        return self.x != other

    # Ordering family.

    def __lt__(self, other):
        return self.x < other

    def __le__(self, other):
        return self.x <= other

    def __gt__(self, other):
        return self.x > other

    def __ge__(self, other):
        return self.x >= other
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +000036
class Vector:
    """Sequence wrapper whose rich comparisons work item by item.

    Comparing a Vector with another Vector, or with any sized sequence of
    the same length, returns a new Vector holding one comparison result
    per position.  Hashing, truth testing and ``__cmp__`` all raise.
    """

    def __init__(self, data):
        self.data = data

    def __repr__(self):
        return "Vector(%r)" % (self.data, )

    # Minimal sequence protocol, delegated to the wrapped data.

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, v):
        self.data[i] = v

    # Operations that are deliberately unsupported.

    def __hash__(self):
        raise TypeError("Vectors cannot be hashed")

    def __nonzero__(self):
        raise TypeError("Vectors cannot be used in Boolean contexts")

    def __cmp__(self, other):
        raise test_support.TestFailed("Vector.__cmp__() should not be called")

    # Itemwise rich comparisons.

    def __cast(self, other):
        """Return the raw items of *other*, checking that the lengths match.

        Raises ValueError when *other* does not have as many items as
        this vector.
        """
        if isinstance(other, Vector):
            items = other.data
        else:
            items = other
        if len(self.data) != len(items):
            raise ValueError("Cannot compare vectors of different length")
        return items

    def __itemwise(self, compare, other):
        """Apply *compare* to each pair of items and wrap the results."""
        pairs = zip(self.data, self.__cast(other))
        return Vector([compare(mine, theirs) for mine, theirs in pairs])

    def __lt__(self, other):
        return self.__itemwise(operator.lt, other)

    def __le__(self, other):
        return self.__itemwise(operator.le, other)

    def __eq__(self, other):
        return self.__itemwise(operator.eq, other)

    def __ne__(self, other):
        return self.__itemwise(operator.ne, other)

    def __gt__(self, other):
        return self.__itemwise(operator.gt, other)

    def __ge__(self, other):
        return self.__itemwise(operator.ge, other)
87
# Maps each comparison name to the three equivalent callables exercised by
# the tests: the operator syntax (wrapped in a lambda) and both spellings
# of the matching function in the operator module.
opmap = {
    "lt": (lambda x, y: x < y, operator.lt, operator.__lt__),
    "le": (lambda x, y: x <= y, operator.le, operator.__le__),
    "eq": (lambda x, y: x == y, operator.eq, operator.__eq__),
    "ne": (lambda x, y: x != y, operator.ne, operator.__ne__),
    "gt": (lambda x, y: x > y, operator.gt, operator.__gt__),
    "ge": (lambda x, y: x >= y, operator.ge, operator.__ge__),
}
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +000096
class VectorTest(unittest.TestCase):
    """Checks the itemwise rich comparisons implemented by Vector."""

    def checkfail(self, error, opname, *args):
        """Assert that every spelling of *opname* raises *error* for *args*."""
        for op in opmap[opname]:
            self.assertRaises(error, op, *args)

    def checkequal(self, opname, a, b, expres):
        """Assert that every spelling of *opname* gives *expres* for (a, b)."""
        for op in opmap[opname]:
            realres = op(a, b)
            # can't use assertEqual(realres, expres) here, so compare the
            # lengths first and then the items one by one
            self.assertEqual(len(realres), len(expres))
            for index, expected in enumerate(expres):
                # results are bool, so an identity check is sufficient
                self.assert_(realres[index] is expected)

    def test_mixed(self):
        # Comparisons involving Vector objects return rich results
        # (i.e. Vectors holding itemwise comparison results).
        shorter = Vector(range(2))
        longer = Vector(range(3))
        # every comparison must fail when the lengths differ
        for opname in opmap:
            self.checkfail(ValueError, opname, shorter, longer)

        numbers = range(5)
        twos = 5 * [2]
        # try mixed arguments (but not two plain lists, as that won't
        # return a bool vector)
        args = [
            (numbers, Vector(twos)),
            (Vector(numbers), twos),
            (Vector(numbers), Vector(twos)),
        ]
        expectations = [
            ("lt", [True, True, False, False, False]),
            ("le", [True, True, True, False, False]),
            ("eq", [False, False, True, False, False]),
            ("ne", [True, True, False, True, True]),
            ("gt", [False, False, False, True, True]),
            ("ge", [False, False, True, True, True]),
        ]
        for left, right in args:
            for opname, expres in expectations:
                self.checkequal(opname, left, right, expres)

            for ops in opmap.values():
                for op in ops:
                    # bool() calls __nonzero__, which should fail
                    self.assertRaises(TypeError, bool, op(left, right))
138
class NumberTest(unittest.TestCase):
    """Checks that Number objects compare like the ints they wrap."""

    def test_basic(self):
        # Comparisons involving Number objects must give the same
        # results as comparing the corresponding ints.
        kinds = (int, Number)
        # the combination (int, int) is useless, so leave it out
        type_pairs = [
            (typea, typeb)
            for typea in kinds
            for typeb in kinds
            if not (typea is int and typeb is int)
        ]
        for a in range(3):
            for b in range(3):
                for typea, typeb in type_pairs:
                    ta = typea(a)
                    tb = typeb(b)
                    for ops in opmap.values():
                        for op in ops:
                            realoutcome = op(a, b)
                            testoutcome = op(ta, tb)
                            self.assertEqual(realoutcome, testoutcome)

    def checkvalue(self, opname, a, b, expres):
        """Assert *opname* on (a, b) is *expres* for every int/Number mix."""
        kinds = (int, Number)
        for typea in kinds:
            for typeb in kinds:
                ta = typea(a)
                tb = typeb(b)
                for op in opmap[opname]:
                    outcome = op(ta, tb)
                    # use the result's "x" attribute when it has one
                    outcome = getattr(outcome, "x", outcome)
                    self.assert_(outcome is expres)

    def test_values(self):
        # check all operators and all comparison results
        opnames = ("lt", "le", "eq", "ne", "gt", "ge")
        cases = [
            # a, b, expected results in the order of opnames
            (0, 0, (False, True, True, False, False, True)),
            (0, 1, (True, True, False, True, False, False)),
            (1, 0, (False, False, False, True, True, True)),
        ]
        for a, b, outcomes in cases:
            for opname, expres in zip(opnames, outcomes):
                self.checkvalue(opname, a, b, expres)
191
class MiscTest(unittest.TestCase):
    """Assorted corner cases of the rich comparison machinery."""

    def test_misbehavin(self):
        # <, == and > are implemented and return 0.  <=, >= and != must
        # never be reached by the comparisons below, and __cmp__ raises
        # RuntimeError, which cmp() has to propagate.
        class Misb:
            def __lt__(self, other): return 0
            def __gt__(self, other): return 0
            def __eq__(self, other): return 0
            # TestFailed lives in test_support; the bare name is not
            # defined in this module and would raise NameError instead.
            def __le__(self, other):
                raise test_support.TestFailed("This shouldn't happen")
            def __ge__(self, other):
                raise test_support.TestFailed("This shouldn't happen")
            def __ne__(self, other):
                raise test_support.TestFailed("This shouldn't happen")
            def __cmp__(self, other):
                raise RuntimeError("expected")
        a = Misb()
        b = Misb()
        self.assertEqual(a<b, 0)
        self.assertEqual(a==b, 0)
        self.assertEqual(a>b, 0)
        self.assertRaises(RuntimeError, cmp, a, b)

    def test_not(self):
        # Check that exceptions in __nonzero__ are properly
        # propagated by the not operator
        class Exc:
            pass
        class Bad:
            def __nonzero__(self):
                raise Exc

        def do(bad):
            # evaluated only for its side effect: the call to __nonzero__
            not bad

        # exercise both the "not" syntax and its function form
        for func in (do, operator.not_):
            self.assertRaises(Exc, func, Bad())

    def test_recursion(self):
        # Check comparison for recursive objects
        from UserList import UserList
        # two lists that each contain themselves
        a = UserList(); a.append(a)
        b = UserList(); b.append(b)

        self.assert_(a == b)
        self.assert_(not a != b)
        a.append(1)
        self.assert_(a == a[0])
        self.assert_(not a != a[0])
        self.assert_(a != b)
        self.assert_(not a == b)
        b.append(0)
        self.assert_(a != b)
        self.assert_(not a == b)
        a[1] = -1
        self.assert_(a != b)
        self.assert_(not a == b)

        # two lists that contain each other
        a = UserList()
        b = UserList()
        a.append(b)
        b.append(a)
        self.assert_(a == b)
        self.assert_(not a != b)

        b.append(17)
        self.assert_(a != b)
        self.assert_(not a == b)
        a.append(17)
        self.assert_(a == b)
        self.assert_(not a != b)

    def test_recursion2(self):
        # This test exercises the circular structure handling code
        # in PyObject_RichCompare()
        class Weird(object):
            # every comparison is defined in terms of its counterpart
            def __eq__(self, other):
                return self != other
            def __ne__(self, other):
                return self == other
            def __lt__(self, other):
                return self > other
            def __gt__(self, other):
                return self < other

        self.assert_(Weird() == Weird())
        self.assert_(not (Weird() != Weird()))

        for op in opmap["lt"]:
            self.assertRaises(ValueError, op, Weird(), Weird())
278
class DictTest(unittest.TestCase):
    """Checks comparisons of dicts whose items only support == and !=."""

    def test_dicts(self):
        # __eq__ and __ne__ must work for dicts even if the keys and
        # values don't support anything other than __eq__ and __ne__.
        # Complex numbers are a fine example of that.
        import random
        imag1a = {}
        for _ in range(50):
            # draw the value before the key, matching the evaluation
            # order of a subscript assignment
            value = random.randrange(100) * 1j
            key = random.randrange(100) * 1j
            imag1a[key] = value
        items = imag1a.items()
        random.shuffle(items)
        # same contents as imag1a, inserted in shuffled order
        imag1b = {}
        for k, v in items:
            imag1b[k] = v
        # k and v still hold the last shuffled pair; changing that one
        # value makes imag2 differ from the other two dicts
        imag2 = imag1b.copy()
        imag2[k] = v + 1.0
        self.assert_(imag1a == imag1a)
        self.assert_(imag1a == imag1b)
        self.assert_(imag2 == imag2)
        self.assert_(imag1a != imag2)
        # every ordering comparison between the dicts must raise
        ordering_ops = [
            op
            for opname in ("lt", "le", "gt", "ge")
            for op in opmap[opname]
        ]
        for op in ordering_ops:
            self.assertRaises(TypeError, op, imag1a, imag2)
303
class ListTest(unittest.TestCase):
    """Checks the rich comparisons of built-in lists."""

    def assertIs(self, a, b):
        """Assert that *a* and *b* are the very same object."""
        self.assert_(a is b)

    def test_coverage(self):
        # exercise all comparisons for lists
        single = [42]
        # a list compared with itself
        self.assertIs(single < single, False)
        self.assertIs(single <= single, True)
        self.assertIs(single == single, True)
        self.assertIs(single != single, False)
        self.assertIs(single > single, False)
        self.assertIs(single >= single, True)
        double = [42, 42]
        # a list compared with a longer list that starts with it
        self.assertIs(single < double, True)
        self.assertIs(single <= double, True)
        self.assertIs(single == double, False)
        self.assertIs(single != double, True)
        self.assertIs(single > double, False)
        self.assertIs(single >= double, False)

    def test_badentry(self):
        # make sure that exceptions for item comparison are properly
        # propagated in list comparisons
        class Exc:
            pass

        class Bad:
            def __eq__(self, other):
                raise Exc

        left = [Bad()]
        right = [Bad()]

        for op in opmap["eq"]:
            self.assertRaises(Exc, op, left, right)

    def test_goodentry(self):
        # This test exercises the final call to PyObject_RichCompare()
        # in Objects/listobject.c::list_richcompare()
        class Good:
            def __lt__(self, other):
                return True

        left = [Good()]
        right = [Good()]

        for op in opmap["lt"]:
            self.assertIs(op(left, right), True)
Guido van Rossum9710bd52001-01-18 15:55:59 +0000353
def test_main():
    """Run every test case class of this module via test_support."""
    test_classes = (VectorTest, NumberTest, MiscTest, DictTest, ListTest)
    test_support.run_unittest(*test_classes)


# Allow running this file directly as a script.
if __name__ == "__main__":
    test_main()