# Tests for rich comparisons

from test.test_support import TestFailed, verify, verbose

class Number:
    """Wrap a value and delegate all six rich comparisons to it.

    The right-hand operand is used exactly as given (it is not unwrapped),
    so a Number-vs-Number comparison is resolved through the right
    operand's reflected method.  __cmp__ raises TestFailed because these
    tests expect only the rich comparison methods to be used.
    """

    def __init__(self, x):
        self.x = x

    def __repr__(self):
        return "Number(%r)" % (self.x,)

    # Rich comparisons: each compares the wrapped value with `other`.
    def __lt__(self, other):
        return self.x < other

    def __le__(self, other):
        return self.x <= other

    def __eq__(self, other):
        return self.x == other

    def __ne__(self, other):
        return self.x != other

    def __gt__(self, other):
        return self.x > other

    def __ge__(self, other):
        return self.x >= other

    def __cmp__(self, other):
        # Reaching the three-way fallback means rich comparison was bypassed.
        raise TestFailed("Number.__cmp__() should not be called")

class Vector:
    """A sequence whose rich comparisons work element by element.

    Comparing a Vector with another Vector, or with a plain sequence of the
    same length, returns a new Vector of the per-element results.  Such a
    result has no single truth value, so hashing a Vector or using one in
    a Boolean context raises TypeError.
    """

    def __init__(self, data):
        self.data = data

    def __repr__(self):
        return "Vector(%r)" % (self.data,)

    # Sequence protocol: delegate straight to the wrapped data.
    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, v):
        self.data[i] = v

    # Deliberately unsupported operations.
    def __hash__(self):
        raise TypeError("Vectors cannot be hashed")

    def __nonzero__(self):
        raise TypeError("Vectors cannot be used in Boolean contexts")

    def __cmp__(self, other):
        # Every rich comparison is defined, so this must never be reached.
        raise TestFailed("Vector.__cmp__() should not be called")

    # Elementwise rich comparisons.
    def __lt__(self, other):
        return self.__elementwise(other, lambda x, y: x < y)

    def __le__(self, other):
        return self.__elementwise(other, lambda x, y: x <= y)

    def __eq__(self, other):
        return self.__elementwise(other, lambda x, y: x == y)

    def __ne__(self, other):
        return self.__elementwise(other, lambda x, y: x != y)

    def __gt__(self, other):
        return self.__elementwise(other, lambda x, y: x > y)

    def __ge__(self, other):
        return self.__elementwise(other, lambda x, y: x >= y)

    def __elementwise(self, other, compare):
        """Return a Vector of compare(x, y) over self paired with other."""
        rhs = self.__cast(other)
        return Vector([compare(x, y) for x, y in zip(self.data, rhs)])

    def __cast(self, other):
        """Return the raw sequence behind other (unwrapping a Vector).

        Raises ValueError when its length differs from self's.
        """
        seq = other
        if isinstance(seq, Vector):
            seq = seq.data
        if len(self.data) != len(seq):
            raise ValueError("Cannot compare vectors of different length")
        return seq

# The six rich comparison operators, in the order the tests apply them.
operators = ("<", "<=", "==", "!=", ">", ">=")
# Map each operator symbol to a two-argument function that applies it
# literally ("lambda a, b: a <op> b"), so calls go through the normal
# comparison machinery for whatever operands are passed in.
opmap = dict([(op, eval("lambda a, b: a %s b" % op)) for op in operators])

| 90 | def testvector(): |
| 91 | a = Vector(range(2)) |
| 92 | b = Vector(range(3)) |
| 93 | for op in operators: |
| 94 | try: |
| 95 | opmap[op](a, b) |
| 96 | except ValueError: |
| 97 | pass |
| 98 | else: |
| 99 | raise TestFailed, "a %s b for different length should fail" % op |
| 100 | a = Vector(range(5)) |
| 101 | b = Vector(5 * [2]) |
| 102 | for op in operators: |
| 103 | print "%23s %-2s %-23s -> %s" % (a, op, b, opmap[op](a, b)) |
| 104 | print "%23s %-2s %-23s -> %s" % (a, op, b.data, opmap[op](a, b.data)) |
| 105 | print "%23s %-2s %-23s -> %s" % (a.data, op, b, opmap[op](a.data, b)) |
| 106 | try: |
| 107 | if opmap[op](a, b): |
| 108 | raise TestFailed, "a %s b shouldn't be true" % op |
| 109 | else: |
| 110 | raise TestFailed, "a %s b shouldn't be false" % op |
| 111 | except TypeError: |
| 112 | pass |
| 113 | |
| 114 | def testop(a, b, op): |
| 115 | try: |
| 116 | ax = a.x |
| 117 | except AttributeError: |
| 118 | ax = a |
| 119 | try: |
| 120 | bx = b.x |
| 121 | except AttributeError: |
| 122 | bx = b |
| 123 | opfunc = opmap[op] |
| 124 | realoutcome = opfunc(ax, bx) |
| 125 | testoutcome = opfunc(a, b) |
| 126 | if realoutcome != testoutcome: |
| 127 | print "Error for", a, op, b, ": expected", realoutcome, |
| 128 | print "but got", testoutcome |
| 129 | ## else: |
| 130 | ## print a, op, b, "-->", testoutcome # and "true" or "false" |
| 131 | |
def testit(a, b):
    """Run testop on a and b for each of the six comparison operators."""
    for op in ("<", "<=", "==", "!=", ">", ">="):
        testop(a, b, op)

def basic():
    """Check Number comparisons against plain-int results for 0..2.

    Every ordered pair (x, y) is tried as Number/Number, int/Number and
    Number/int.
    """
    values = range(3)
    for x in values:
        for y in values:
            testit(Number(x), Number(y))
            testit(x, Number(y))
            testit(Number(x), y)

| 147 | def tabulate(c1=Number, c2=Number): |
| 148 | for op in operators: |
| 149 | opfunc = opmap[op] |
| 150 | print |
| 151 | print "operator:", op |
| 152 | print |
| 153 | print "%9s" % "", |
| 154 | for b in range(3): |
| 155 | b = c2(b) |
| 156 | print "| %9s" % b, |
| 157 | print "|" |
| 158 | print '----------+-' * 4 |
| 159 | for a in range(3): |
| 160 | a = c1(a) |
| 161 | print "%9s" % a, |
| 162 | for b in range(3): |
| 163 | b = c2(b) |
| 164 | print "| %9s" % opfunc(a, b), |
| 165 | print "|" |
| 166 | print '----------+-' * 4 |
| 167 | print |
| 168 | print '*' * 50 |
Tim Peters | 8880f6d | 2001-01-19 06:12:17 +0000 | [diff] [blame] | 169 | |
Guido van Rossum | 9710bd5 | 2001-01-18 15:55:59 +0000 | [diff] [blame] | 170 | def misbehavin(): |
| 171 | class Misb: |
| 172 | def __lt__(self, other): return 0 |
| 173 | def __gt__(self, other): return 0 |
| 174 | def __eq__(self, other): return 0 |
| 175 | def __le__(self, other): raise TestFailed, "This shouldn't happen" |
| 176 | def __ge__(self, other): raise TestFailed, "This shouldn't happen" |
| 177 | def __ne__(self, other): raise TestFailed, "This shouldn't happen" |
| 178 | def __cmp__(self, other): raise RuntimeError, "expected" |
| 179 | a = Misb() |
| 180 | b = Misb() |
| 181 | verify((a<b) == 0) |
| 182 | verify((a==b) == 0) |
| 183 | verify((a>b) == 0) |
| 184 | try: |
| 185 | print cmp(a, b) |
| 186 | except RuntimeError: |
| 187 | pass |
| 188 | else: |
| 189 | raise TestFailed, "cmp(Misb(), Misb()) didn't raise RuntimeError" |
| 190 | |
Guido van Rossum | 890f209 | 2001-01-18 16:21:57 +0000 | [diff] [blame] | 191 | def recursion(): |
| 192 | from UserList import UserList |
| 193 | a = UserList(); a.append(a) |
| 194 | b = UserList(); b.append(b) |
| 195 | def check(s, a=a, b=b): |
| 196 | if verbose: |
Guido van Rossum | 4e8db2e | 2001-01-18 21:52:26 +0000 | [diff] [blame] | 197 | print "check", s |
| 198 | try: |
| 199 | if not eval(s): |
| 200 | raise TestFailed, s + " was false but expected to be true" |
| 201 | except RuntimeError, msg: |
| 202 | raise TestFailed, str(msg) |
Guido van Rossum | 890f209 | 2001-01-18 16:21:57 +0000 | [diff] [blame] | 203 | if verbose: |
| 204 | print "recursion tests: a=%s, b=%s" % (a, b) |
| 205 | check('a==b') |
Guido van Rossum | 890f209 | 2001-01-18 16:21:57 +0000 | [diff] [blame] | 206 | check('not a!=b') |
Guido van Rossum | 890f209 | 2001-01-18 16:21:57 +0000 | [diff] [blame] | 207 | a.append(1) |
Guido van Rossum | 4e8db2e | 2001-01-18 21:52:26 +0000 | [diff] [blame] | 208 | if verbose: |
| 209 | print "recursion tests: a=%s, b=%s" % (a, b) |
| 210 | check('a!=b') |
| 211 | check('not a==b') |
Guido van Rossum | 890f209 | 2001-01-18 16:21:57 +0000 | [diff] [blame] | 212 | b.append(0) |
| 213 | if verbose: |
| 214 | print "recursion tests: a=%s, b=%s" % (a, b) |
Guido van Rossum | 890f209 | 2001-01-18 16:21:57 +0000 | [diff] [blame] | 215 | check('a!=b') |
Guido van Rossum | 890f209 | 2001-01-18 16:21:57 +0000 | [diff] [blame] | 216 | check('not a==b') |
Guido van Rossum | 890f209 | 2001-01-18 16:21:57 +0000 | [diff] [blame] | 217 | a[1] = -1 |
| 218 | if verbose: |
| 219 | print "recursion tests: a=%s, b=%s" % (a, b) |
Guido van Rossum | 890f209 | 2001-01-18 16:21:57 +0000 | [diff] [blame] | 220 | check('a!=b') |
Guido van Rossum | 890f209 | 2001-01-18 16:21:57 +0000 | [diff] [blame] | 221 | check('not a==b') |
Guido van Rossum | 890f209 | 2001-01-18 16:21:57 +0000 | [diff] [blame] | 222 | if verbose: print "recursion tests ok" |
| 223 | |
def dicts():
    """Check dict ==/!= when keys and values support only ==/!=.

    Complex numbers define equality but no ordering, so comparing dicts of
    them for (in)equality must work while <, <=, > and >= must raise
    TypeError.  Keys and values are random on each run.
    """
    import random
    imag1a = {}
    for i in range(50):
        imag1a[random.randrange(100)*1j] = random.randrange(100)*1j
    # The same mapping, built with a shuffled insertion order.
    items = imag1a.items()
    random.shuffle(items)
    imag1b = dict(items)
    # A copy that differs only in the value of the last-inserted key.
    k, v = items[-1]
    imag2 = imag1b.copy()
    imag2[k] = v + 1.0

    verify(imag1a == imag1a, "imag1a == imag1a should have worked")
    verify(imag1a == imag1b, "imag1a == imag1b should have worked")
    verify(imag2 == imag2, "imag2 == imag2 should have worked")
    verify(imag1a != imag2, "imag1a != imag2 should have worked")

    # Ordering the dicts needs ordering of complex numbers: must fail.
    for op in "<", "<=", ">", ">=":
        try:
            opmap[op](imag1a, imag2)
        except TypeError:
            pass
        else:
            raise TestFailed("expected TypeError from imag1a %s imag2" % op)

def main():
    """Run every rich-comparison check in this module, in a fixed order."""
    basic()
    # Operator tables for Number/Number, int/Number and Number/int operands.
    for kwargs in ({}, {'c1': int}, {'c2': int}):
        tabulate(**kwargs)
    for check in (testvector, misbehavin, recursion, dicts):
        check()

# Runs at import time: the checks execute whenever this module is imported.
main()