blob: 71ffff750fd41b6b643f8bba9499292a142b831f [file] [log] [blame]
Guido van Rossum7736b5b2008-01-15 21:44:53 +00001# Originally contributed by Sjoerd Mullender.
2# Significantly modified by Jeffrey Yasskin <jyasskin at gmail.com>.
3
4"""Rational, infinite-precision, real numbers."""
5
Guido van Rossum7736b5b2008-01-15 21:44:53 +00006import math
7import numbers
8import operator
Christian Heimes587c2bf2008-01-19 16:21:02 +00009import re
Guido van Rossum7736b5b2008-01-15 21:44:53 +000010
# The only public name this module exports.
__all__ = ['Rational']

# Short alias for the abstract numbers.Rational base class (PEP 3141).  It is
# the base of the concrete Rational class and the target of isinstance checks,
# which therefore also accept ints and any other registered rational type.
RationalAbc = numbers.Rational
14
15
def _gcd(a, b):
    """Return the greatest common divisor of a and b (Euclid's algorithm).

    Unless b == 0, the result has the same sign as b, so dividing b by
    it always gives a positive quotient.  _gcd(a, 0) returns a unchanged.
    """
    x, y = a, b
    while y:
        # Python's % takes the sign of the divisor, which keeps every
        # later remainder (and so the final result) on b's side of zero.
        x, y = y, x % y
    return x
25
26
def _binary_float_to_ratio(x):
    """x -> (top, bot), a pair of ints s.t. x = top/bot.

    The conversion is done exactly, without rounding, and the pair is
    returned in lowest terms.
    bot > 0 guaranteed.
    Some form of binary fp is assumed.

    Raises ValueError for NaNs and infinities, which have no exact
    rational value.

    >>> _binary_float_to_ratio(10.0)
    (10, 1)
    >>> _binary_float_to_ratio(0.0)
    (0, 1)
    >>> _binary_float_to_ratio(-.25)
    (-1, 4)
    """
    if math.isnan(x) or math.isinf(x):
        raise ValueError("Cannot convert %r to an exact ratio." % (x,))
    if x == 0:
        return 0, 1
    f, e = math.frexp(x)
    signbit = 1
    if f < 0:
        f = -f
        signbit = -1
    assert 0.5 <= f < 1.0
    # x = signbit * f * 2**e exactly

    # Suck up CHUNK bits at a time; 28 is enough so that we suck
    # up all bits in 2 iterations for all known binary double-
    # precision formats, and small enough to fit in an int.
    CHUNK = 28
    top = 0
    # invariant: x = signbit * (top + f) * 2**e exactly
    while f:
        f = math.ldexp(f, CHUNK)
        # f is non-negative here, so int() truncates exactly like trunc()
        # would; a bare trunc() builtin does not exist in released Python 3.
        digit = int(f)
        assert digit >> CHUNK == 0
        top = (top << CHUNK) | digit
        f = f - digit
        assert 0.0 <= f < 1.0
        e = e - CHUNK
    assert top

    # Whole CHUNKs are pulled in, so top usually ends in zero bits that
    # cancel against the 2**-e denominator; strip them so the pair is in
    # lowest terms (the denominator is a power of two, so only factors of
    # two can be shared).
    while e < 0 and not top & 1:
        top >>= 1
        e += 1

    # Add in the sign bit.
    top = signbit * top

    # now x = top * 2**e exactly; fold in 2**e
    if e > 0:
        return (top * 2**e, 1)
    else:
        return (top, 2 ** -e)
77
78
# Grammar accepted by Rational() for string arguments: optional surrounding
# whitespace, an optional sign, the numerator digits, and an optional
# "/denominator" part.  Named groups: sign ('' when absent), num, and denom
# (None when absent).
_RATIONAL_FORMAT = re.compile(r'^\s*(?P<sign>[-+]?)(?P<num>\d+)'
                              r'(?:/(?P<denom>\d+))?\s*$')
81
82
class Rational(RationalAbc):
    """This class implements rational numbers.

    Rational(8, 6) will produce a rational number equivalent to
    4/3. Both arguments must be Integral. The numerator defaults to 0
    and the denominator defaults to 1 so that Rational(3) == 3 and
    Rational() == 0.

    Rationals can also be constructed from strings of the form
    '[-+]?[0-9]+(/[0-9]+)?', optionally surrounded by spaces.

    Values are stored reduced to lowest terms, with the sign carried by
    the numerator and a positive denominator.

    """

    __slots__ = ('_numerator', '_denominator')

    # We're immutable, so use __new__ not __init__
    def __new__(cls, numerator=0, denominator=1):
        """Constructs a Rational.

        Takes a string, another Rational, or a numerator/denominator pair.

        Raises:
            ValueError: a string argument is not a valid rational literal.
            TypeError: the (possibly parsed) arguments are not Integral.
            ZeroDivisionError: the denominator is zero.

        """
        self = super(Rational, cls).__new__(cls)

        if denominator == 1:
            if isinstance(numerator, str):
                # Handle construction from strings.
                text = numerator
                m = _RATIONAL_FORMAT.match(text)
                if m is None:
                    raise ValueError('Invalid literal for Rational: ' + text)
                numerator = int(m.group('num'))
                # Default denominator to 1. That's the only optional group.
                denominator = int(m.group('denom') or 1)
                if m.group('sign') == '-':
                    numerator = -numerator

            elif (not isinstance(numerator, numbers.Integral) and
                  isinstance(numerator, RationalAbc)):
                # Handle copies from other rationals.
                other_rational = numerator
                numerator = other_rational.numerator
                denominator = other_rational.denominator

        if (not isinstance(numerator, numbers.Integral) or
            not isinstance(denominator, numbers.Integral)):
            raise TypeError("Rational(%s, %s):"
                            " Both arguments must be integral." %
                            (numerator, denominator))

        if denominator == 0:
            raise ZeroDivisionError('Rational(%s, 0)' % numerator)

        # _gcd returns a divisor with the sign of the denominator, so the
        # stored denominator always comes out positive.
        g = _gcd(numerator, denominator)
        self._numerator = int(numerator // g)
        self._denominator = int(denominator // g)
        return self

    @classmethod
    def from_float(cls, f):
        """Converts a finite float to a rational number, exactly.

        Beware that Rational.from_float(0.3) != Rational(3, 10).

        Raises TypeError for non-floats and for NaNs or infinities.

        """
        if not isinstance(f, float):
            raise TypeError("%s.from_float() only takes floats, not %r (%s)" %
                            (cls.__name__, f, type(f).__name__))
        if math.isnan(f) or math.isinf(f):
            raise TypeError("Cannot convert %r to %s." % (f, cls.__name__))
        return cls(*_binary_float_to_ratio(f))

    @classmethod
    def from_decimal(cls, dec):
        """Converts a finite Decimal instance to a rational number, exactly.

        Raises TypeError for non-Decimals and for NaNs or infinities.

        """
        from decimal import Decimal
        if not isinstance(dec, Decimal):
            raise TypeError(
                "%s.from_decimal() only takes Decimals, not %r (%s)" %
                (cls.__name__, dec, type(dec).__name__))
        if not dec.is_finite():
            # Catches infinities and nans.
            raise TypeError("Cannot convert %s to %s." % (dec, cls.__name__))
        sign, digits, exp = dec.as_tuple()
        digits = int(''.join(map(str, digits)))
        if sign:
            digits = -digits
        if exp >= 0:
            return cls(digits * 10 ** exp)
        else:
            return cls(digits, 10 ** -exp)

    @property
    def numerator(a):
        return a._numerator

    @property
    def denominator(a):
        return a._denominator

    def __repr__(self):
        """repr(self)"""
        return ('Rational(%r,%r)' % (self.numerator, self.denominator))

    def __str__(self):
        """str(self)"""
        if self.denominator == 1:
            return str(self.numerator)
        else:
            return '%s/%s' % (self.numerator, self.denominator)

    def _operator_fallbacks(monomorphic_operator, fallback_operator):
        """Generates forward and reverse operators given a purely-rational
        operator and a function from the operator module.

        Use this like:
        __op__, __rop__ = _operator_fallbacks(just_rational_op, operator.op)

        Rational operands use monomorphic_operator exactly; float and
        complex operands fall back to fallback_operator on converted values.

        """
        def forward(a, b):
            if isinstance(b, RationalAbc):
                # Includes ints.
                return monomorphic_operator(a, b)
            elif isinstance(b, float):
                return fallback_operator(float(a), b)
            elif isinstance(b, complex):
                return fallback_operator(complex(a), b)
            else:
                return NotImplemented
        forward.__name__ = '__' + fallback_operator.__name__ + '__'
        forward.__doc__ = monomorphic_operator.__doc__

        def reverse(b, a):
            if isinstance(a, RationalAbc):
                # Includes ints.
                return monomorphic_operator(a, b)
            elif isinstance(a, numbers.Real):
                return fallback_operator(float(a), float(b))
            elif isinstance(a, numbers.Complex):
                return fallback_operator(complex(a), complex(b))
            else:
                return NotImplemented
        reverse.__name__ = '__r' + fallback_operator.__name__ + '__'
        reverse.__doc__ = monomorphic_operator.__doc__

        return forward, reverse

    def _add(a, b):
        """a + b"""
        return Rational(a.numerator * b.denominator +
                        b.numerator * a.denominator,
                        a.denominator * b.denominator)

    __add__, __radd__ = _operator_fallbacks(_add, operator.add)

    def _sub(a, b):
        """a - b"""
        return Rational(a.numerator * b.denominator -
                        b.numerator * a.denominator,
                        a.denominator * b.denominator)

    __sub__, __rsub__ = _operator_fallbacks(_sub, operator.sub)

    def _mul(a, b):
        """a * b"""
        return Rational(a.numerator * b.numerator, a.denominator * b.denominator)

    __mul__, __rmul__ = _operator_fallbacks(_mul, operator.mul)

    def _div(a, b):
        """a / b"""
        return Rational(a.numerator * b.denominator,
                        a.denominator * b.numerator)

    __truediv__, __rtruediv__ = _operator_fallbacks(_div, operator.truediv)

    def __floordiv__(a, b):
        """a // b"""
        return math.floor(a / b)

    def __rfloordiv__(b, a):
        """a // b"""
        return math.floor(a / b)

    @classmethod
    def _mod(cls, a, b):
        # a % b == a - b * (a // b), which also works when either side
        # is a plain int or float.
        div = a // b
        return a - b * div

    def __mod__(a, b):
        """a % b"""
        return a._mod(a, b)

    def __rmod__(b, a):
        """a % b"""
        return b._mod(a, b)

    def __pow__(a, b):
        """a ** b

        If b is not an integer, the result will be a float or complex
        since roots are generally irrational. If b is an integer, the
        result will be rational.

        """
        if isinstance(b, RationalAbc):
            if b.denominator == 1:
                power = b.numerator
                if power >= 0:
                    return Rational(a.numerator ** power,
                                    a.denominator ** power)
                else:
                    return Rational(a.denominator ** -power,
                                    a.numerator ** -power)
            else:
                # A fractional power will generally produce an
                # irrational number.
                return float(a) ** float(b)
        else:
            return float(a) ** b

    def __rpow__(b, a):
        """a ** b"""
        if b.denominator == 1 and b.numerator >= 0:
            # If a is an int, keep it that way if possible.
            return a ** b.numerator

        if isinstance(a, RationalAbc):
            return Rational(a.numerator, a.denominator) ** b

        if b.denominator == 1:
            return a ** b.numerator

        return a ** float(b)

    def __pos__(a):
        """+a: Coerces a subclass instance to Rational"""
        return Rational(a.numerator, a.denominator)

    def __neg__(a):
        """-a"""
        return Rational(-a.numerator, a.denominator)

    def __abs__(a):
        """abs(a)"""
        return Rational(abs(a.numerator), a.denominator)

    def __trunc__(a):
        """trunc(a)"""
        if a.numerator < 0:
            return -(-a.numerator // a.denominator)
        else:
            return a.numerator // a.denominator

    def __floor__(a):
        """math.floor(a)"""
        return a.numerator // a.denominator

    def __ceil__(a):
        """math.ceil(a)"""
        # The negations cleverly convince floordiv to return the ceiling.
        return -(-a.numerator // a.denominator)

    def __round__(self, ndigits=None):
        """round(self, ndigits)

        Rounds half toward even.  With ndigits None an int is returned;
        otherwise the result is a Rational rounded to ndigits decimal
        places (ndigits may be zero or negative).
        """
        if ndigits is None:
            floor, remainder = divmod(self.numerator, self.denominator)
            if remainder * 2 < self.denominator:
                return floor
            elif remainder * 2 > self.denominator:
                return floor + 1
            # Deal with the half case:
            elif floor % 2 == 0:
                return floor
            else:
                return floor + 1
        shift = 10**abs(ndigits)
        # See _operator_fallbacks.forward to check that the results of
        # these operations will always be Rational and therefore have
        # round().
        if ndigits > 0:
            return Rational(round(self * shift), shift)
        else:
            return Rational(round(self / shift) * shift)

    def __hash__(self):
        """hash(self)

        Tricky because values that are exactly representable as a
        float must have the same hash as that float.

        """
        if self.denominator == 1:
            # Get integers right.
            return hash(self.numerator)
        try:
            as_float = float(self)
        except OverflowError:
            # Too large in magnitude for any float, so no float can equal
            # self; previously this exception escaped from hash().
            as_float = None
        # Expensive check, but definitely correct.
        if as_float is not None and self == as_float:
            return hash(as_float)
        # Use tuple's hash to avoid a high collision rate on
        # simple fractions.
        return hash((self.numerator, self.denominator))

    def __eq__(a, b):
        """a == b"""
        if isinstance(b, RationalAbc):
            return (a.numerator == b.numerator and
                    a.denominator == b.denominator)
        if isinstance(b, numbers.Complex) and b.imag == 0:
            b = b.real
        if isinstance(b, float):
            if math.isnan(b) or math.isinf(b):
                # A finite rational never equals a NaN or an infinity, and
                # from_float() would raise TypeError for them.
                return False
            return a == a.from_float(b)
        else:
            # XXX: If b.__eq__ is implemented like this method, it may
            # give the wrong answer after float(a) changes a's
            # value. Better ways of doing this are welcome.
            return float(a) == b

    def _subtractAndCompareToZero(a, b, op):
        """Helper function for comparison operators.

        Subtracts b from a, exactly if possible, and compares the
        result with 0 using op, in such a way that the comparison
        won't recurse. If the difference raises a TypeError, returns
        NotImplemented instead.

        """
        if isinstance(b, numbers.Complex) and b.imag == 0:
            b = b.real
        if isinstance(b, float):
            if math.isnan(b) or math.isinf(b):
                # a is finite, so "a op b" has the same answer as
                # "0.0 op b": the obvious result for +-inf and always
                # False for NaN.  from_float() would raise TypeError here.
                return op(0.0, b)
            b = a.from_float(b)
        try:
            # XXX: If b <: Real but not <: RationalAbc, this is likely
            # to fall back to a float. If the actual values differ by
            # less than MIN_FLOAT, this could falsely call them equal,
            # which would make <= inconsistent with ==. Better ways of
            # doing this are welcome.
            diff = a - b
        except TypeError:
            return NotImplemented
        if isinstance(diff, RationalAbc):
            return op(diff.numerator, 0)
        return op(diff, 0)

    def __lt__(a, b):
        """a < b"""
        return a._subtractAndCompareToZero(b, operator.lt)

    def __gt__(a, b):
        """a > b"""
        return a._subtractAndCompareToZero(b, operator.gt)

    def __le__(a, b):
        """a <= b"""
        return a._subtractAndCompareToZero(b, operator.le)

    def __ge__(a, b):
        """a >= b"""
        return a._subtractAndCompareToZero(b, operator.ge)

    def __bool__(a):
        """a != 0"""
        return a.numerator != 0