blob: ad6838628e606de10b3e509293e010934e83c1c4 [file] [log] [blame]
# Tests for rich comparisons
2
Walter Dörwald721adf92003-04-29 21:31:19 +00003import unittest
4from test import test_support
5
6import operator
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +00007
class Number:
    """Wrapper whose rich comparisons all delegate to the wrapped value.

    The right-hand operand is passed through unchanged (a Number on the
    right is not unwrapped here).  ``__cmp__`` exists only as a tripwire:
    it fails the test run if three-way comparison is ever used instead of
    the rich comparison methods.
    """

    def __init__(self, x):
        self.x = x

    def __lt__(self, other):
        return self.x < other

    def __le__(self, other):
        return self.x <= other

    def __eq__(self, other):
        return self.x == other

    def __ne__(self, other):
        return self.x != other

    def __gt__(self, other):
        return self.x > other

    def __ge__(self, other):
        return self.x >= other

    def __cmp__(self, other):
        raise test_support.TestFailed("Number.__cmp__() should not be called")

    def __repr__(self):
        return "Number(%r)" % (self.x,)
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +000036
class Vector:
    """Sequence wrapper whose comparisons are evaluated item by item.

    Each rich comparison returns a new Vector holding the per-item results
    instead of a single bool.  Truth-testing a Vector raises TypeError, and
    ``__cmp__`` is a tripwire that fails the run if three-way comparison is
    ever used.  Vectors are deliberately unhashable.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, v):
        self.data[i] = v

    __hash__ = None  # Vectors cannot be hashed

    def __nonzero__(self):
        raise TypeError("Vectors cannot be used in Boolean contexts")

    def __cmp__(self, other):
        raise test_support.TestFailed("Vector.__cmp__() should not be called")

    def __repr__(self):
        return "Vector(%r)" % (self.data,)

    def __lt__(self, other):
        return self.__elementwise(other, lambda a, b: a < b)

    def __le__(self, other):
        return self.__elementwise(other, lambda a, b: a <= b)

    def __eq__(self, other):
        return self.__elementwise(other, lambda a, b: a == b)

    def __ne__(self, other):
        return self.__elementwise(other, lambda a, b: a != b)

    def __gt__(self, other):
        return self.__elementwise(other, lambda a, b: a > b)

    def __ge__(self, other):
        return self.__elementwise(other, lambda a, b: a >= b)

    def __elementwise(self, other, compare):
        # Validate/unwrap the other operand, then apply `compare` pairwise.
        rhs = self.__cast(other)
        return Vector([compare(left, right)
                       for left, right in zip(self.data, rhs)])

    def __cast(self, other):
        # Accept either another Vector or a plain sequence of equal length.
        seq = other.data if isinstance(other, Vector) else other
        if len(self.data) != len(seq):
            raise ValueError("Cannot compare vectors of different length")
        return seq
86
# For each comparison name, three interchangeable ways of performing it:
# the operator syntax itself, the ``operator`` module function, and that
# function's dunder alias.  The tests run every one of them.
opmap = dict(
    lt=(lambda a, b: a < b, operator.lt, operator.__lt__),
    le=(lambda a, b: a <= b, operator.le, operator.__le__),
    eq=(lambda a, b: a == b, operator.eq, operator.__eq__),
    ne=(lambda a, b: a != b, operator.ne, operator.__ne__),
    gt=(lambda a, b: a > b, operator.gt, operator.__gt__),
    ge=(lambda a, b: a >= b, operator.ge, operator.__ge__),
)
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +000095
class VectorTest(unittest.TestCase):
    """Comparisons involving Vector operands, which yield Vectors of bools."""

    def checkfail(self, error, opname, *args):
        """Assert that every spelling of *opname* raises *error* on *args*."""
        for op in opmap[opname]:
            self.assertRaises(error, op, *args)

    def checkequal(self, opname, a, b, expres):
        """Assert every spelling of *opname* on (a, b) gives item-wise *expres*.

        assertEqual(realres, expres) cannot be used: comparing a Vector with
        a list yields another Vector, and truth-testing that raises
        TypeError.  So compare the lengths, then each item by identity (the
        items are bools, so ``is`` is exact).
        """
        for op in opmap[opname]:
            realres = op(a, b)
            self.assertEqual(len(realres), len(expres))
            for got, want in zip(realres, expres):
                self.assertTrue(got is want)

    def test_mixed(self):
        # Check that comparisons involving Vector objects which return rich
        # results (i.e. Vectors with item-wise comparison results) work.
        a = Vector(range(2))
        b = Vector(range(3))
        # all comparisons should fail for different length
        for opname in opmap:
            self.checkfail(ValueError, opname, a, b)

        a = range(5)
        b = 5 * [2]
        # Try mixed arguments (but not (a, b): list vs list returns a single
        # bool, not a bool Vector).
        args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))]
        for lhs, rhs in args:
            self.checkequal("lt", lhs, rhs, [True,  True,  False, False, False])
            self.checkequal("le", lhs, rhs, [True,  True,  True,  False, False])
            self.checkequal("eq", lhs, rhs, [False, False, True,  False, False])
            self.checkequal("ne", lhs, rhs, [True,  True,  False, True,  True ])
            self.checkequal("gt", lhs, rhs, [False, False, False, True,  True ])
            self.checkequal("ge", lhs, rhs, [False, False, True,  True,  True ])

            # Check every operand combination, not just the last one.
            for ops in opmap.itervalues():
                for op in ops:
                    # bool() calls __nonzero__, which should fail
                    self.assertRaises(TypeError, bool, op(lhs, rhs))
137
class NumberTest(unittest.TestCase):
    """Comparisons mixing Number wrappers with plain ints."""

    def test_basic(self):
        # Every comparison with at least one Number operand must give the
        # same result as the comparison of the corresponding ints.
        kinds = (int, Number)
        # The (int, int) pairing compares ints with ints and proves nothing.
        pairings = [(ka, kb) for ka in kinds for kb in kinds
                    if not (ka is int and kb is int)]
        for a in range(3):
            for b in range(3):
                for typea, typeb in pairings:
                    ta, tb = typea(a), typeb(b)
                    for ops in opmap.values():
                        for op in ops:
                            self.assertEqual(op(a, b), op(ta, tb))

    def checkvalue(self, opname, a, b, expres):
        """Assert *opname* on (a, b) is exactly *expres* for all int/Number mixes."""
        kinds = (int, Number)
        for typea in kinds:
            for typeb in kinds:
                ta, tb = typea(a), typeb(b)
                for op in opmap[opname]:
                    outcome = op(ta, tb)
                    # If the result exposes an ``x`` attribute (as a Number
                    # would), check the wrapped value instead.
                    outcome = getattr(outcome, "x", outcome)
                    self.assertTrue(outcome is expres)

    def test_values(self):
        # Check all operators against all comparison results, one row per
        # operand pair; columns follow the order of `names`.
        names = ("lt", "le", "eq", "ne", "gt", "ge")
        table = [
            (0, 0, (False, True,  True,  False, False, True )),
            (0, 1, (True,  True,  False, True,  False, False)),
            (1, 0, (False, False, False, True,  True,  True )),
        ]
        for a, b, results in table:
            for opname, expected in zip(names, results):
                self.checkvalue(opname, a, b, expected)
190
class MiscTest(unittest.TestCase):
    """Assorted rich-comparison edge cases."""

    def test_misbehavin(self):
        # a<b, a==b and a>b must use the rich comparison methods below, which
        # all return 0; __le__, __ge__ and __ne__ must not be reached, and
        # cmp() is expected to raise the RuntimeError from __cmp__.
        # The inner methods call their instance ``self_`` so that ``self``
        # still names the TestCase and self.fail() reports a real failure
        # (the bare name TestFailed is not defined in this module).
        class Misb:
            def __lt__(self_, other): return 0
            def __gt__(self_, other): return 0
            def __eq__(self_, other): return 0
            def __le__(self_, other): self.fail("This shouldn't happen")
            def __ge__(self_, other): self.fail("This shouldn't happen")
            def __ne__(self_, other): self.fail("This shouldn't happen")
            def __cmp__(self_, other): raise RuntimeError("expected")
        a = Misb()
        b = Misb()
        self.assertEqual(a < b, 0)
        self.assertEqual(a == b, 0)
        self.assertEqual(a > b, 0)
        self.assertRaises(RuntimeError, cmp, a, b)

    def test_not(self):
        # Check that exceptions in __nonzero__ are properly propagated by
        # the not operator and by operator.not_ (module-level import).
        class Exc(Exception):
            pass

        class Bad:
            def __nonzero__(self):
                raise Exc

        def do(bad):
            not bad

        for func in (do, operator.not_):
            self.assertRaises(Exc, func, Bad())

    def test_recursion(self):
        # Check that comparison for recursive objects fails gracefully
        from UserList import UserList
        a = UserList()
        b = UserList()
        a.append(b)
        b.append(a)
        self.assertRaises(RuntimeError, operator.eq, a, b)
        self.assertRaises(RuntimeError, operator.ne, a, b)
        self.assertRaises(RuntimeError, operator.lt, a, b)
        self.assertRaises(RuntimeError, operator.le, a, b)
        self.assertRaises(RuntimeError, operator.gt, a, b)
        self.assertRaises(RuntimeError, operator.ge, a, b)

        b.append(17)
        # Even recursive lists of different lengths are different,
        # but they cannot be ordered
        self.assertFalse(a == b)
        self.assertTrue(a != b)
        self.assertRaises(RuntimeError, operator.lt, a, b)
        self.assertRaises(RuntimeError, operator.le, a, b)
        self.assertRaises(RuntimeError, operator.gt, a, b)
        self.assertRaises(RuntimeError, operator.ge, a, b)
        a.append(17)
        self.assertRaises(RuntimeError, operator.eq, a, b)
        self.assertRaises(RuntimeError, operator.ne, a, b)
        # Differing first items decide the comparison before recursion.
        a.insert(0, 11)
        b.insert(0, 12)
        self.assertFalse(a == b)
        self.assertTrue(a != b)
        self.assertTrue(a < b)
Walter Dörwald721adf92003-04-29 21:31:19 +0000256
class DictTest(unittest.TestCase):
    """Equality and ordering of dicts whose contents are only equality-comparable."""

    def test_dicts(self):
        # Verify that __eq__ and __ne__ work for dicts even if the keys and
        # values don't support anything other than __eq__ and __ne__ (and
        # __hash__); complex numbers are a fine example of that.  Ordering
        # comparisons between such dicts must raise TypeError.
        import random

        # Up to 50 random complex key/value pairs (duplicate keys collapse).
        reference = {}
        for _ in range(50):
            reference[random.randrange(100)*1j] = random.randrange(100)*1j

        # Rebuild an equal dict by inserting the same pairs in shuffled order.
        pairs = list(reference.items())
        random.shuffle(pairs)
        rebuilt = {}
        for key, value in pairs:
            rebuilt[key] = value

        # A copy that differs only in the value of the last pair inserted.
        altered = rebuilt.copy()
        altered[key] = value + 1.0

        self.assertTrue(reference == reference)
        self.assertTrue(reference == rebuilt)
        self.assertTrue(altered == altered)
        self.assertTrue(reference != altered)
        for opname in ("lt", "le", "gt", "ge"):
            for op in opmap[opname]:
                self.assertRaises(TypeError, op, reference, altered)
281
class ListTest(unittest.TestCase):
    """Rich comparisons on lists, including item-level comparison behavior."""

    def assertIs(self, a, b):
        # Local identity assertion used by the tests below.
        self.assertTrue(a is b)

    def test_coverage(self):
        # Exercise all six comparisons for lists: a list against itself,
        # then against a longer list sharing its prefix.  Result columns
        # are <, <=, ==, !=, >, >=.
        x = [42]
        y = [42, 42]
        cases = [
            ((x < x, x <= x, x == x, x != x, x > x, x >= x),
             (False, True, True, False, False, True)),
            ((x < y, x <= y, x == y, x != y, x > y, x >= y),
             (True, True, False, True, False, False)),
        ]
        for outcomes, expected in cases:
            for got, want in zip(outcomes, expected):
                self.assertIs(got, want)

    def test_badentry(self):
        # An exception raised while comparing list items must propagate
        # out of the list comparison.
        class Exc(Exception):
            pass

        class Bad:
            def __eq__(self, other):
                raise Exc

        left, right = [Bad()], [Bad()]
        for op in opmap["eq"]:
            self.assertRaises(Exc, op, left, right)

    def test_goodentry(self):
        # This test exercises the final call to PyObject_RichCompare()
        # in Objects/listobject.c::list_richcompare()
        class Good:
            def __lt__(self, other):
                return True

        left, right = [Good()], [Good()]
        for op in opmap["lt"]:
            self.assertIs(op(left, right), True)
Guido van Rossum9710bd52001-01-18 15:55:59 +0000331
def test_main():
    # Hand the TestCase classes defined above to the regression-test runner.
    test_support.run_unittest(
        VectorTest,
        NumberTest,
        MiscTest,
        DictTest,
        ListTest,
    )


if __name__ == "__main__":
    test_main()