blob: 4e7d45925d220431dac67ae5204c5906751d3fbb [file] [log] [blame]
# Tests for rich comparisons
2
Guido van Rossum890f2092001-01-18 16:21:57 +00003from test_support import TestFailed, verify, verbose
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +00004
class Number:
    """Wrapper around a value that forwards rich comparisons to it.

    Every rich-comparison method compares the wrapped value ``self.x``
    with ``other`` exactly as passed in (``other`` is not unwrapped here).
    ``__cmp__`` is defined only as a tripwire: it fails the test if
    three-way comparison is ever used on a Number.
    """

    def __init__(self, x):
        # x: the wrapped value that all comparisons are delegated to.
        self.x = x

    def __lt__(self, other):
        return self.x < other

    def __le__(self, other):
        return self.x <= other

    def __eq__(self, other):
        return self.x == other

    def __ne__(self, other):
        return self.x != other

    def __gt__(self, other):
        return self.x > other

    def __ge__(self, other):
        return self.x >= other

    def __cmp__(self, other):
        # Call-form raise, matching the style already used in dicts().
        raise TestFailed("Number.__cmp__() should not be called")

    def __repr__(self):
        return "Number(%s)" % repr(self.x)
33
class Vector:
    """Sequence wrapper whose rich comparisons work element by element.

    Comparing a Vector with another Vector, or with any object supporting
    ``len()`` and iteration, returns a new Vector holding the per-element
    results.  Hashing and truth testing both raise TypeError, and
    ``__cmp__`` fails the test if three-way comparison is ever attempted.
    """

    def __init__(self, data):
        # data: the underlying sequence; stored as given, not copied.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, v):
        self.data[i] = v

    def __hash__(self):
        raise TypeError("Vectors cannot be hashed")

    def __nonzero__(self):
        # A comparison result is itself a Vector, so using one in an
        # ``if`` is rejected here instead of yielding a truth value.
        raise TypeError("Vectors cannot be used in Boolean contexts")

    def __cmp__(self, other):
        raise TestFailed("Vector.__cmp__() should not be called")

    def __repr__(self):
        return "Vector(%s)" % repr(self.data)

    def __lt__(self, other):
        return Vector([a < b for a, b in zip(self.data, self.__cast(other))])

    def __le__(self, other):
        return Vector([a <= b for a, b in zip(self.data, self.__cast(other))])

    def __eq__(self, other):
        return Vector([a == b for a, b in zip(self.data, self.__cast(other))])

    def __ne__(self, other):
        return Vector([a != b for a, b in zip(self.data, self.__cast(other))])

    def __gt__(self, other):
        return Vector([a > b for a, b in zip(self.data, self.__cast(other))])

    def __ge__(self, other):
        return Vector([a >= b for a, b in zip(self.data, self.__cast(other))])

    def __cast(self, other):
        """Return the raw sequence behind ``other`` after a length check.

        A Vector operand is unwrapped to its ``data``; anything else is
        used as is.  Raises ValueError when the lengths differ.
        """
        if isinstance(other, Vector):
            other = other.data
        if len(self.data) != len(other):
            raise ValueError("Cannot compare vectors of different length")
        return other
84
# The six rich-comparison operators, in the order the tests walk them.
operators = ("<", "<=", "==", "!=", ">", ">=")

# Map each operator symbol to a two-argument function applying it.  The
# functions are built with eval() from a fixed template, so only the
# constant strings listed above are ever evaluated.
opmap = {}
for op in operators:
    opmap[op] = eval("lambda a, b: a %s b" % op)
89
90def testvector():
91 a = Vector(range(2))
92 b = Vector(range(3))
93 for op in operators:
94 try:
95 opmap[op](a, b)
96 except ValueError:
97 pass
98 else:
99 raise TestFailed, "a %s b for different length should fail" % op
100 a = Vector(range(5))
101 b = Vector(5 * [2])
102 for op in operators:
103 print "%23s %-2s %-23s -> %s" % (a, op, b, opmap[op](a, b))
104 print "%23s %-2s %-23s -> %s" % (a, op, b.data, opmap[op](a, b.data))
105 print "%23s %-2s %-23s -> %s" % (a.data, op, b, opmap[op](a.data, b))
106 try:
107 if opmap[op](a, b):
108 raise TestFailed, "a %s b shouldn't be true" % op
109 else:
110 raise TestFailed, "a %s b shouldn't be false" % op
111 except TypeError:
112 pass
113
114def testop(a, b, op):
115 try:
116 ax = a.x
117 except AttributeError:
118 ax = a
119 try:
120 bx = b.x
121 except AttributeError:
122 bx = b
123 opfunc = opmap[op]
124 realoutcome = opfunc(ax, bx)
125 testoutcome = opfunc(a, b)
126 if realoutcome != testoutcome:
127 print "Error for", a, op, b, ": expected", realoutcome,
128 print "but got", testoutcome
129## else:
130## print a, op, b, "-->", testoutcome # and "true" or "false"
131
def testit(a, b):
    """Run testop() on the pair (a, b) for each of the six operators."""
    for symbol in ("<", "<=", "==", "!=", ">", ">="):
        testop(a, b, symbol)
139
def basic():
    """Compare all pairs drawn from 0..2 in three wrapped/raw combinations.

    For every (a, b) the checks run on Number/Number, int/Number and
    Number/int, in that order.
    """
    for a in range(3):
        for b in range(3):
            combos = (
                (Number(a), Number(b)),
                (a, Number(b)),
                (Number(a), b),
            )
            for left, right in combos:
                testit(left, right)
146
147def tabulate(c1=Number, c2=Number):
148 for op in operators:
149 opfunc = opmap[op]
150 print
151 print "operator:", op
152 print
153 print "%9s" % "",
154 for b in range(3):
155 b = c2(b)
156 print "| %9s" % b,
157 print "|"
158 print '----------+-' * 4
159 for a in range(3):
160 a = c1(a)
161 print "%9s" % a,
162 for b in range(3):
163 b = c2(b)
164 print "| %9s" % opfunc(a, b),
165 print "|"
166 print '----------+-' * 4
167 print
168 print '*' * 50
Tim Peters8880f6d2001-01-19 06:12:17 +0000169
Guido van Rossum9710bd52001-01-18 15:55:59 +0000170def misbehavin():
171 class Misb:
172 def __lt__(self, other): return 0
173 def __gt__(self, other): return 0
174 def __eq__(self, other): return 0
175 def __le__(self, other): raise TestFailed, "This shouldn't happen"
176 def __ge__(self, other): raise TestFailed, "This shouldn't happen"
177 def __ne__(self, other): raise TestFailed, "This shouldn't happen"
178 def __cmp__(self, other): raise RuntimeError, "expected"
179 a = Misb()
180 b = Misb()
181 verify((a<b) == 0)
182 verify((a==b) == 0)
183 verify((a>b) == 0)
184 try:
185 print cmp(a, b)
186 except RuntimeError:
187 pass
188 else:
189 raise TestFailed, "cmp(Misb(), Misb()) didn't raise RuntimeError"
190
Guido van Rossum890f2092001-01-18 16:21:57 +0000191def recursion():
192 from UserList import UserList
193 a = UserList(); a.append(a)
194 b = UserList(); b.append(b)
195 def check(s, a=a, b=b):
196 if verbose:
Guido van Rossum4e8db2e2001-01-18 21:52:26 +0000197 print "check", s
198 try:
199 if not eval(s):
200 raise TestFailed, s + " was false but expected to be true"
201 except RuntimeError, msg:
202 raise TestFailed, str(msg)
Guido van Rossum890f2092001-01-18 16:21:57 +0000203 if verbose:
204 print "recursion tests: a=%s, b=%s" % (a, b)
205 check('a==b')
Guido van Rossum890f2092001-01-18 16:21:57 +0000206 check('not a!=b')
Guido van Rossum890f2092001-01-18 16:21:57 +0000207 a.append(1)
Guido van Rossum4e8db2e2001-01-18 21:52:26 +0000208 if verbose:
209 print "recursion tests: a=%s, b=%s" % (a, b)
210 check('a!=b')
211 check('not a==b')
Guido van Rossum890f2092001-01-18 16:21:57 +0000212 b.append(0)
213 if verbose:
214 print "recursion tests: a=%s, b=%s" % (a, b)
Guido van Rossum890f2092001-01-18 16:21:57 +0000215 check('a!=b')
Guido van Rossum890f2092001-01-18 16:21:57 +0000216 check('not a==b')
Guido van Rossum890f2092001-01-18 16:21:57 +0000217 a[1] = -1
218 if verbose:
219 print "recursion tests: a=%s, b=%s" % (a, b)
Guido van Rossum890f2092001-01-18 16:21:57 +0000220 check('a!=b')
Guido van Rossum890f2092001-01-18 16:21:57 +0000221 check('not a==b')
Guido van Rossum890f2092001-01-18 16:21:57 +0000222 if verbose: print "recursion tests ok"
223
def dicts():
    """Check dict ==/!= when keys and values only support ==/!=.

    Complex numbers are used as such keys and values.  The ordering
    operators on these dicts are expected to raise TypeError.

    Raises TestFailed if an ordering comparison does not raise TypeError.
    """
    import random
    imag1a = {}
    for i in range(50):
        imag1a[random.randrange(100)*1j] = random.randrange(100)*1j
    # Rebuild the same mapping with a shuffled insertion order.
    pairs = imag1a.items()
    random.shuffle(pairs)
    imag1b = {}
    for k, v in pairs:
        imag1b[k] = v
    # k and v still hold the last pair inserted; changing that one value
    # makes imag2 differ from imag1a and imag1b.
    imag2 = imag1b.copy()
    imag2[k] = v + 1.0
    verify(imag1a == imag1a, "imag1a == imag1a should have worked")
    verify(imag1a == imag1b, "imag1a == imag1b should have worked")
    verify(imag2 == imag2, "imag2 == imag2 should have worked")
    verify(imag1a != imag2, "imag1a != imag2 should have worked")
    for op in ("<", "<=", ">", ">="):
        # The names imag1a and imag2 are looked up by eval() in this scope.
        try:
            eval("imag1a %s imag2" % op)
        except TypeError:
            continue
        raise TestFailed("expected TypeError from imag1a %s imag2" % op)
250
def main():
    """Run every test group in this module, in a fixed order."""
    basic()
    # Tables for the default operand types, then with int rows, then
    # with int columns.
    for overrides in ({}, {"c1": int}, {"c2": int}):
        tabulate(**overrides)
    for group in (testvector, misbehavin, recursion, dicts):
        group()
Guido van Rossumc4a6e8b2001-01-18 15:48:05 +0000260
# Run the tests unconditionally: on import as well as on direct execution.
main()