blob: 5be70e5ebf4ebb5dfe1f17598a94c76c720c9bf1 [file] [log] [blame]
Larry Hastingsf5e987b2013-10-19 11:50:09 -07001"""
2Basic statistics module.
3
4This module provides functions for calculating statistics of data, including
5averages, variance, and standard deviation.
6
7Calculating averages
8--------------------
9
Raymond Hettinger9013ccf2019-04-23 00:06:35 -070010================== ==================================================
Larry Hastingsf5e987b2013-10-19 11:50:09 -070011Function Description
Raymond Hettinger9013ccf2019-04-23 00:06:35 -070012================== ==================================================
Larry Hastingsf5e987b2013-10-19 11:50:09 -070013mean Arithmetic mean (average) of data.
Raymond Hettinger72800482019-04-23 01:35:16 -070014fmean Fast, floating point arithmetic mean.
Raymond Hettinger6463ba32019-04-07 09:20:03 -070015geometric_mean Geometric mean of data.
Steven D'Apranoa474afd2016-08-09 12:49:01 +100016harmonic_mean Harmonic mean of data.
Larry Hastingsf5e987b2013-10-19 11:50:09 -070017median Median (middle value) of data.
18median_low Low median of data.
19median_high High median of data.
20median_grouped Median, or 50th percentile, of grouped data.
21mode Mode (most common value) of data.
Raymond Hettinger6463ba32019-04-07 09:20:03 -070022multimode List of modes (most common values of data).
Raymond Hettinger9013ccf2019-04-23 00:06:35 -070023quantiles Divide data into intervals with equal probability.
24================== ==================================================
Larry Hastingsf5e987b2013-10-19 11:50:09 -070025
26Calculate the arithmetic mean ("the average") of data:
27
28>>> mean([-1.0, 2.5, 3.25, 5.75])
292.625
30
31
32Calculate the standard median of discrete data:
33
34>>> median([2, 3, 4, 5])
353.5
36
37
38Calculate the median, or 50th percentile, of data grouped into class intervals
39centred on the data values provided. E.g. if your data points are rounded to
40the nearest whole number:
41
42>>> median_grouped([2, 2, 3, 3, 3, 4]) #doctest: +ELLIPSIS
432.8333333333...
44
45This should be interpreted in this way: you have two data points in the class
46interval 1.5-2.5, three data points in the class interval 2.5-3.5, and one in
47the class interval 3.5-4.5. The median of these data points is 2.8333...
48
49
50Calculating variability or spread
51---------------------------------
52
53================== =============================================
54Function Description
55================== =============================================
56pvariance Population variance of data.
57variance Sample variance of data.
58pstdev Population standard deviation of data.
59stdev Sample standard deviation of data.
60================== =============================================
61
62Calculate the standard deviation of sample data:
63
64>>> stdev([2.5, 3.25, 5.5, 11.25, 11.75]) #doctest: +ELLIPSIS
654.38961843444...
66
67If you have previously calculated the mean, you can pass it as the optional
68second argument to the four "spread" functions to avoid recalculating it:
69
70>>> data = [1, 2, 2, 4, 4, 4, 5, 6]
71>>> mu = mean(data)
72>>> pvariance(data, mu)
732.5
74
75
76Exceptions
77----------
78
79A single exception is defined: StatisticsError is a subclass of ValueError.
80
81"""
82
Raymond Hettinger9013ccf2019-04-23 00:06:35 -070083__all__ = [ 'StatisticsError', 'NormalDist', 'quantiles',
Larry Hastingsf5e987b2013-10-19 11:50:09 -070084 'pstdev', 'pvariance', 'stdev', 'variance',
85 'median', 'median_low', 'median_high', 'median_grouped',
Raymond Hettingerfc06a192019-03-12 00:43:27 -070086 'mean', 'mode', 'multimode', 'harmonic_mean', 'fmean',
Raymond Hettinger6463ba32019-04-07 09:20:03 -070087 'geometric_mean',
Larry Hastingsf5e987b2013-10-19 11:50:09 -070088 ]
89
Larry Hastingsf5e987b2013-10-19 11:50:09 -070090import math
Steven D'Apranoa474afd2016-08-09 12:49:01 +100091import numbers
Raymond Hettinger11c79532019-02-23 14:44:07 -080092import random
Larry Hastingsf5e987b2013-10-19 11:50:09 -070093
94from fractions import Fraction
95from decimal import Decimal
Victor Stinnerd6debb22017-03-27 16:05:26 +020096from itertools import groupby
Steven D'Aprano3b06e242016-05-05 03:54:29 +100097from bisect import bisect_left, bisect_right
Raymond Hettinger318d5372019-03-06 22:59:40 -080098from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
Raymond Hettingerfc06a192019-03-12 00:43:27 -070099from operator import itemgetter
100from collections import Counter
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700101
102# === Exceptions ===
103
class StatisticsError(ValueError):
    """Raised by statistics functions when given unusable input."""
106
107
108# === Private utilities ===
109
def _sum(data, start=0):
    """_sum(data [, start]) -> (type, sum, count)

    Return a high-precision sum of the given numeric data as a fraction,
    together with the type to be converted to and the count of items.

    If optional argument ``start`` is given, it is added to the total.
    If ``data`` is empty, ``start`` (defaulting to 0) is returned.

    Examples
    --------

    >>> _sum([3, 2.25, 4.5, -0.5, 1.0], 0.75)
    (<class 'float'>, Fraction(11, 1), 5)

    Some sources of round-off error will be avoided:

    # Built-in sum returns zero.
    >>> _sum([1e50, 1, -1e50] * 1000)
    (<class 'float'>, Fraction(1000, 1), 3000)

    Mixed types are currently treated as an error, except that int is
    allowed.
    """
    num_items = 0
    numer, denom = _exact_ratio(start)
    # Map each exact denominator to the running sum of numerators over it;
    # a None key marks a non-finite value (NAN/INF) seen in the data.
    partial_sums = {denom: numer}
    lookup = partial_sums.get
    T = _coerce(int, type(start))
    for item_type, group in groupby(data, type):
        T = _coerce(T, item_type)  # or raise TypeError
        for numer, denom in map(_exact_ratio, group):
            num_items += 1
            partial_sums[denom] = lookup(denom, 0) + numer
    if None in partial_sums:
        # The sum will be a NAN or INF.  All the finite partials are
        # irrelevant; report just this special entry.
        total = partial_sums[None]
        assert not _isfinite(total)
    else:
        # Combine the per-denominator partials exactly using Fractions.
        # FIXME is this faster if we sum them in order of the denominator?
        total = sum(Fraction(numer, denom)
                    for denom, numer in sorted(partial_sums.items()))
    return (T, total, num_items)
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700166
167
Steven D'Apranob28c3272015-12-01 19:59:53 +1100168def _isfinite(x):
169 try:
170 return x.is_finite() # Likely a Decimal.
171 except AttributeError:
172 return math.isfinite(x) # Coerces to float first.
173
174
175def _coerce(T, S):
176 """Coerce types T and S to a common type, or raise TypeError.
177
178 Coercion rules are currently an implementation detail. See the CoerceTest
179 test class in test_statistics for details.
180 """
181 # See http://bugs.python.org/issue24068.
182 assert T is not bool, "initial type T is bool"
183 # If the types are the same, no need to coerce anything. Put this
184 # first, so that the usual case (no coercion needed) happens as soon
185 # as possible.
186 if T is S: return T
187 # Mixed int & other coerce to the other type.
188 if S is int or S is bool: return T
189 if T is int: return S
190 # If one is a (strict) subclass of the other, coerce to the subclass.
191 if issubclass(S, T): return S
192 if issubclass(T, S): return T
193 # Ints coerce to the other type.
194 if issubclass(T, int): return S
195 if issubclass(S, int): return T
196 # Mixed fraction & float coerces to float (or float subclass).
197 if issubclass(T, Fraction) and issubclass(S, float):
198 return S
199 if issubclass(T, float) and issubclass(S, Fraction):
200 return T
201 # Any other combination is disallowed.
202 msg = "don't know how to coerce %s and %s"
203 raise TypeError(msg % (T.__name__, S.__name__))
Nick Coghlan73afe2a2014-02-08 19:58:04 +1000204
205
def _exact_ratio(x):
    """Return Real number x to exact (numerator, denominator) pair.

    >>> _exact_ratio(0.25)
    (1, 4)

    x is expected to be an int, Fraction, Decimal or float.

    A non-finite float or Decimal comes back as ``(x, None)``; the None
    denominator is the sentinel that _sum() uses to detect NANs/INFs.
    Unsupported types raise TypeError.
    """
    try:
        # Optimise the common case of floats. We expect that the most often
        # used numeric type will be builtin floats, so try to make this as
        # fast as possible.
        if type(x) is float or type(x) is Decimal:
            # Exact for finite values; raises OverflowError/ValueError on
            # INF/NAN, which the outer except turns into (x, None).
            return x.as_integer_ratio()
        try:
            # x may be an int, Fraction, or Integral ABC.
            return (x.numerator, x.denominator)
        except AttributeError:
            try:
                # x may be a float or Decimal subclass.
                return x.as_integer_ratio()
            except AttributeError:
                # Just give up?
                pass
    except (OverflowError, ValueError):
        # float NAN or INF.
        assert not _isfinite(x)
        return (x, None)
    msg = "can't convert type '{}' to numerator/denominator"
    raise TypeError(msg.format(type(x).__name__))
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700236
237
Steven D'Apranob28c3272015-12-01 19:59:53 +1100238def _convert(value, T):
239 """Convert value to given numeric type T."""
240 if type(value) is T:
241 # This covers the cases where T is Fraction, or where value is
242 # a NAN or INF (Decimal or float).
243 return value
244 if issubclass(T, int) and value.denominator != 1:
245 T = float
246 try:
247 # FIXME: what do we do if this overflows?
248 return T(value)
249 except TypeError:
250 if issubclass(T, Decimal):
251 return T(value.numerator)/T(value.denominator)
252 else:
253 raise
254
255
Steven D'Aprano3b06e242016-05-05 03:54:29 +1000256def _find_lteq(a, x):
257 'Locate the leftmost value exactly equal to x'
258 i = bisect_left(a, x)
259 if i != len(a) and a[i] == x:
260 return i
261 raise ValueError
262
263
264def _find_rteq(a, l, x):
265 'Locate the rightmost value exactly equal to x'
266 i = bisect_right(a, x, lo=l)
267 if i != (len(a)+1) and a[i-1] == x:
268 return i-1
269 raise ValueError
270
Steven D'Apranoa474afd2016-08-09 12:49:01 +1000271
272def _fail_neg(values, errmsg='negative value'):
273 """Iterate over values, failing if any are less than zero."""
274 for x in values:
275 if x < 0:
276 raise StatisticsError(errmsg)
277 yield x
278
279
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700280# === Measures of central tendency (averages) ===
281
def mean(data):
    """Return the sample arithmetic mean of data.

    >>> mean([1, 2, 3, 4, 4])
    2.8

    >>> from fractions import Fraction as F
    >>> mean([F(3, 7), F(1, 21), F(5, 3), F(1, 3)])
    Fraction(13, 21)

    >>> from decimal import Decimal as D
    >>> mean([D("0.5"), D("0.75"), D("0.625"), D("0.375")])
    Decimal('0.5625')

    If ``data`` is empty, StatisticsError will be raised.
    """
    # One-shot iterators must be materialized so len() works below.
    if iter(data) is data:
        data = list(data)
    size = len(data)
    if size < 1:
        raise StatisticsError('mean requires at least one data point')
    T, total, count = _sum(data)
    assert count == size
    # _sum returns an exact Fraction; convert back to the coerced type.
    return _convert(total / size, T)
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700306
def fmean(data):
    """Convert data to floats and compute the arithmetic mean.

    This runs faster than the mean() function and it always returns a float.
    The result is highly accurate but not as perfect as mean().
    If the input dataset is empty, it raises a StatisticsError.

    >>> fmean([3.5, 4.0, 5.25])
    4.25

    """
    try:
        n = len(data)
    except TypeError:
        # The input is an iterator without __len__(): count the items
        # while fsum() consumes them so only a single pass is needed.
        n = 0
        def tally(iterable):
            nonlocal n
            for item in iterable:
                n += 1
                yield item
        total = fsum(tally(data))
    else:
        total = fsum(data)
    try:
        return total / n
    except ZeroDivisionError:
        raise StatisticsError('fmean requires at least one data point') from None
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700334
def geometric_mean(data):
    """Convert data to floats and compute the geometric mean.

    Raises a StatisticsError if the input dataset is empty,
    if it contains a zero, or if it contains a negative value.

    No special efforts are made to achieve exact results.
    (However, this may change in the future.)

    >>> round(geometric_mean([54, 24, 36]), 9)
    36.0
    """
    try:
        # exp(mean(log(x))) == nth root of the product of the x values.
        # log() raises ValueError on zero/negative input; fmean() raises
        # StatisticsError (a ValueError subclass) on an empty dataset.
        return exp(fmean(map(log, data)))
    except ValueError:
        # Fix: the old message was split as '...dataset ' + ' containing...'
        # which rendered with a doubled space.
        raise StatisticsError('geometric mean requires a non-empty dataset '
                              'containing positive numbers') from None
352
def harmonic_mean(data):
    """Return the harmonic mean of data.

    The harmonic mean, sometimes called the subcontrary mean, is the
    reciprocal of the arithmetic mean of the reciprocals of the data,
    and is often appropriate when averaging quantities which are rates
    or ratios, for example speeds. Example:

    Suppose an investor purchases an equal value of shares in each of
    three companies, with P/E (price/earning) ratios of 2.5, 3 and 10.
    What is the average P/E ratio for the investor's portfolio?

    >>> harmonic_mean([2.5, 3, 10])  # For an equal investment portfolio.
    3.6

    Using the arithmetic mean would give an average of about 5.167, which
    is too high.

    If ``data`` is empty, or any element is less than zero,
    ``harmonic_mean`` will raise ``StatisticsError``.
    """
    # For a justification for using harmonic mean for P/E ratios, see
    # http://fixthepitch.pellucid.com/comps-analysis-the-missing-harmony-of-summary-statistics/
    # http://papers.ssrn.com/sol3/papers.cfm?abstract_id=2621087
    if iter(data) is data:
        data = list(data)
    errmsg = 'harmonic mean does not support negative values'
    size = len(data)
    if size < 1:
        raise StatisticsError('harmonic_mean requires at least one data point')
    if size == 1:
        # Single value: validate and return it unchanged.
        value = data[0]
        if not isinstance(value, (numbers.Real, Decimal)):
            raise TypeError('unsupported type')
        if value < 0:
            raise StatisticsError(errmsg)
        return value
    try:
        T, total, count = _sum(1 / x for x in _fail_neg(data, errmsg))
    except ZeroDivisionError:
        # Any zero in the data pins the harmonic mean at zero.
        return 0
    assert count == size
    return _convert(size / total, T)
397
398
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700399# FIXME: investigate ways to calculate medians without sorting? Quickselect?
def median(data):
    """Return the median (middle value) of numeric data.

    When the number of data points is odd, return the middle data point.
    When the number of data points is even, the median is interpolated by
    taking the average of the two middle values:

    >>> median([1, 3, 5])
    3
    >>> median([1, 3, 5, 7])
    4.0

    """
    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    mid = n // 2
    if n % 2:
        return data[mid]
    # Even count: average the two middle values.
    return (data[mid - 1] + data[mid]) / 2
422
423
def median_low(data):
    """Return the low median of numeric data.

    When the number of data points is odd, the middle value is returned.
    When it is even, the smaller of the two middle values is returned.

    >>> median_low([1, 3, 5])
    3
    >>> median_low([1, 3, 5, 7])
    3

    """
    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    # For an even count pick the lower of the two central values.
    return data[n // 2] if n % 2 else data[n // 2 - 1]
444
445
def median_high(data):
    """Return the high median of data.

    When the number of data points is odd, the middle value is returned.
    When it is even, the larger of the two middle values is returned.

    >>> median_high([1, 3, 5])
    3
    >>> median_high([1, 3, 5, 7])
    5

    """
    values = sorted(data)
    if not values:
        raise StatisticsError("no median for empty data")
    # Index n//2 is the middle element (odd n) or upper-middle (even n).
    return values[len(values) // 2]
463
464
def median_grouped(data, interval=1):
    """Return the 50th percentile (median) of grouped continuous data.

    >>> median_grouped([1, 2, 2, 3, 4, 4, 4, 4, 4, 5])
    3.7
    >>> median_grouped([52, 52, 53, 54])
    52.5

    This calculates the median as the 50th percentile, and should be
    used when your data is continuous and grouped. In the above example,
    the values 1, 2, 3, etc. actually represent the midpoint of classes
    0.5-1.5, 1.5-2.5, 2.5-3.5, etc. The middle value falls somewhere in
    class 3.5-4.5, and interpolation is used to estimate it.

    Optional argument ``interval`` represents the class interval, and
    defaults to 1. Changing the class interval naturally will change the
    interpolated 50th percentile value:

    >>> median_grouped([1, 3, 3, 5, 7], interval=1)
    3.25
    >>> median_grouped([1, 3, 3, 5, 7], interval=2)
    3.5

    This function does not check whether the data points are at least
    ``interval`` apart.
    """
    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    elif n == 1:
        return data[0]
    # Find the value at the midpoint. Remember this corresponds to the
    # centre of the class interval.
    x = data[n//2]
    # Reject str/bytes early: they would otherwise "subtract" badly below.
    for obj in (x, interval):
        if isinstance(obj, (str, bytes)):
            raise TypeError('expected number but got %r' % obj)
    try:
        L = x - interval/2 # The lower limit of the median interval.
    except TypeError:
        # Mixed type. For now we just coerce to float.
        L = float(x) - float(interval)/2

    # Uses bisection search to search for x in data with log(n) time complexity
    # Find the position of leftmost occurrence of x in data
    l1 = _find_lteq(data, x)
    # Find the position of rightmost occurrence of x in data[l1...len(data)]
    # Assuming always l1 <= l2
    l2 = _find_rteq(data, l1, x)
    # cf: cumulative frequency of the classes below the median class;
    # f: number of data points falling inside the median class itself.
    cf = l1
    f = l2 - l1 + 1
    # Standard grouped-median interpolation formula: L + h*(n/2 - cf)/f.
    return L + interval*(n/2 - cf)/f
518
519
def mode(data):
    """Return the most common data point from discrete or nominal data.

    ``mode`` assumes discrete data, and returns a single value. This is the
    standard treatment of the mode as commonly taught in schools:

    >>> mode([1, 1, 2, 3, 3, 3, 3, 4])
    3

    This also works with nominal (non-numeric) data:

    >>> mode(["red", "blue", "blue", "red", "green", "red", "red"])
    'red'

    If there are multiple modes, return the first one encountered.

    >>> mode(['red', 'red', 'green', 'blue', 'blue'])
    'red'

    If *data* is empty, ``mode``, raises StatisticsError.

    """
    # Counter preserves first-seen order among equal counts, which gives
    # the "first mode encountered" tie-breaking behavior.
    top = Counter(iter(data)).most_common(1)
    if not top:
        raise StatisticsError('no mode for empty data')
    return top[0][0]
547
548
def multimode(data):
    """ Return a list of the most frequently occurring values.

    Will return more than one result if there are multiple modes
    or an empty list if *data* is empty.

    >>> multimode('aabbbbbbbbcc')
    ['b']
    >>> multimode('aabbbbccddddeeffffgg')
    ['b', 'd', 'f']
    >>> multimode('')
    []

    """
    # most_common() sorts by descending count, keeping first-seen order
    # among ties, so all entries matching the leading count are the modes.
    ranked = Counter(iter(data)).most_common()
    if not ranked:
        return []
    best = ranked[0][1]
    return [value for value, count in ranked if count == best]
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700566
Raymond Hettingercba9f842019-06-02 21:07:43 -0700567# Notes on methods for computing quantiles
568# ----------------------------------------
569#
570# There is no one perfect way to compute quantiles. Here we offer
571# two methods that serve common needs. Most other packages
572# surveyed offered at least one or both of these two, making them
573# "standard" in the sense of "widely-adopted and reproducible".
574# They are also easy to explain, easy to compute manually, and have
575# straight-forward interpretations that aren't surprising.
576
577# The default method is known as "R6", "PERCENTILE.EXC", or "expected
578# value of rank order statistics". The alternative method is known as
579# "R7", "PERCENTILE.INC", or "mode of rank order statistics".
580
581# For sample data where there is a positive probability for values
582# beyond the range of the data, the R6 exclusive method is a
583# reasonable choice. Consider a random sample of nine values from a
584# population with a uniform distribution from 0.0 to 100.0. The
585# distribution of the third ranked sample point is described by
586# betavariate(alpha=3, beta=7) which has mode=0.250, median=0.286, and
587# mean=0.300. Only the latter (which corresponds with R6) gives the
588# desired cut point with 30% of the population falling below that
589# value, making it comparable to a result from an inv_cdf() function.
590
591# For describing population data where the end points are known to
592# be included in the data, the R7 inclusive method is a reasonable
593# choice. Instead of the mean, it uses the mode of the beta
594# distribution for the interior points. Per Hyndman & Fan, "One nice
595# property is that the vertices of Q7(p) divide the range into n - 1
596# intervals, and exactly 100p% of the intervals lie to the left of
597# Q7(p) and 100(1 - p)% of the intervals lie to the right of Q7(p)."
598
599# If the need arises, we could add method="median" for a median
600# unbiased, distribution-free alternative. Also if needed, the
601# distribution-free approaches could be augmented by adding
602# method='normal'. However, for now, the position is that fewer
603# options make for easier choices and that external packages can be
604# used for anything more advanced.
605
def quantiles(dist, *, n=4, method='exclusive'):
    '''Divide *dist* into *n* continuous intervals with equal probability.

    Returns a list of (n - 1) cut points separating the intervals.

    Set *n* to 4 for quartiles (the default). Set *n* to 10 for deciles.
    Set *n* to 100 for percentiles which gives the 99 cuts points that
    separate *dist* in to 100 equal sized groups.

    The *dist* can be any iterable containing sample data or it can be
    an instance of a class that defines an inv_cdf() method. For sample
    data, the cut points are linearly interpolated between data points.

    If *method* is set to *inclusive*, *dist* is treated as population
    data. The minimum value is treated as the 0th percentile and the
    maximum value is treated as the 100th percentile.
    '''
    # Possible future API extensions:
    #     quantiles(data, already_sorted=True)
    #     quantiles(data, cut_points=[0.02, 0.25, 0.50, 0.75, 0.98])
    if n < 1:
        raise StatisticsError('n must be at least 1')
    if hasattr(dist, 'inv_cdf'):
        # Analytic distribution object: ask it for the cut points directly.
        return [dist.inv_cdf(i / n) for i in range(1, n)]
    data = sorted(dist)
    ld = len(data)
    if ld < 2:
        raise StatisticsError('must have at least two data points')
    if method == 'inclusive':
        # R7 / PERCENTILE.INC: end points are the 0th and 100th percentiles.
        m = ld - 1
        cut_points = []
        for i in range(1, n):
            j = i * m // n                  # integer part of the position
            delta = i*m - j*n               # exact fractional remainder * n
            cut_points.append((data[j] * (n - delta) + data[j+1] * delta) / n)
        return cut_points
    if method == 'exclusive':
        # R6 / PERCENTILE.EXC: allows for values beyond the observed range.
        m = ld + 1
        cut_points = []
        for i in range(1, n):
            j = i * m // n                               # rescale i to m/n
            j = 1 if j < 1 else ld-1 if j > ld-1 else j  # clamp to 1 .. ld-1
            delta = i*m - j*n                            # exact integer math
            cut_points.append((data[j-1] * (n - delta) + data[j] * delta) / n)
        return cut_points
    raise ValueError(f'Unknown method: {method!r}')
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700654
655# === Measures of spread ===
656
657# See http://mathworld.wolfram.com/Variance.html
658# http://mathworld.wolfram.com/SampleVariance.html
659# http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
660#
661# Under no circumstances use the so-called "computational formula for
662# variance", as that is only suitable for hand calculations with a small
663# amount of low-precision data. It has terrible numeric properties.
664#
665# See a comparison of three computational methods here:
666# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
667
def _ss(data, c=None):
    """Return sum of square deviations of sequence data.

    If ``c`` is None, the mean is calculated in one pass, and the deviations
    from the mean are calculated in a second pass. Otherwise, deviations are
    calculated from ``c`` as given. Use the second case with care, as it can
    lead to garbage results.

    Returns a 2-tuple ``(T, total)`` where T is the coerced result type and
    total is an exact Fraction.  NOTE: ``data`` must be a sequence, not a
    one-shot iterator — it is iterated up to three times and passed to
    len(); callers (variance/pvariance) materialize iterators first.
    """
    if c is None:
        c = mean(data)
    T, total, count = _sum((x-c)**2 for x in data)
    # The following sum should mathematically equal zero, but due to rounding
    # error may not.
    U, total2, count2 = _sum((x-c) for x in data)
    assert T == U and count == count2
    # Correction term: subtracting (sum(x-c))**2/n compensates exactly for
    # c differing from the true mean by rounding (standard algebraic
    # identity for shifting the center of the sum of squares).
    total -= total2**2/len(data)
    assert not total < 0, 'negative sum of square deviations: %f' % total
    return (T, total)
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700686
687
def variance(data, xbar=None):
    """Return the sample variance of data.

    data should be an iterable of Real-valued numbers, with at least two
    values. The optional argument xbar, if given, should be the mean of
    the data. If it is missing or None, the mean is automatically calculated.

    Use this function when your data is a sample from a population. To
    calculate the variance from the entire population, see ``pvariance``.

    Examples:

    >>> data = [2.75, 1.75, 1.25, 0.25, 0.5, 1.25, 3.5]
    >>> variance(data)
    1.3720238095238095

    If you have already calculated the mean of your data, you can pass it as
    the optional second argument ``xbar`` to avoid recalculating it:

    >>> m = mean(data)
    >>> variance(data, m)
    1.3720238095238095

    This function does not check that ``xbar`` is actually the mean of
    ``data``. Giving arbitrary values for ``xbar`` may lead to invalid or
    impossible results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> variance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('31.01875')

    >>> from fractions import Fraction as F
    >>> variance([F(1, 6), F(1, 2), F(5, 3)])
    Fraction(67, 108)

    """
    # _ss() needs a sequence (it makes multiple passes), so materialize
    # one-shot iterators first.
    if iter(data) is data:
        data = list(data)
    size = len(data)
    if size < 2:
        raise StatisticsError('variance requires at least two data points')
    T, ss = _ss(data, xbar)
    # Bessel's correction: divide by n-1 for an unbiased sample estimate.
    return _convert(ss / (size - 1), T)
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700733
734
def pvariance(data, mu=None):
    """Return the population variance of ``data``.

    data should be an iterable of Real-valued numbers, with at least one
    value. The optional argument mu, if given, should be the mean of
    the data. If it is missing or None, the mean is automatically calculated.

    Use this function to calculate the variance from the entire population.
    To estimate the variance from a sample, the ``variance`` function is
    usually a better choice.

    Examples:

    >>> data = [0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
    >>> pvariance(data)
    1.25

    If you have already calculated the mean of the data, you can pass it as
    the optional second argument to avoid recalculating it:

    >>> mu = mean(data)
    >>> pvariance(data, mu)
    1.25

    This function does not check that ``mu`` is actually the mean of ``data``.
    Giving arbitrary values for ``mu`` may lead to invalid or impossible
    results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> pvariance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('24.815')

    >>> from fractions import Fraction as F
    >>> pvariance([F(1, 4), F(5, 4), F(1, 2)])
    Fraction(13, 72)

    """
    # _ss() needs a sequence (it makes multiple passes), so materialize
    # one-shot iterators first.
    if iter(data) is data:
        data = list(data)
    size = len(data)
    if size < 1:
        raise StatisticsError('pvariance requires at least one data point')
    T, ss = _ss(data, mu)
    # Population variance divides by n (no Bessel correction).
    return _convert(ss / size, T)
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700781
782
def stdev(data, xbar=None):
    """Return the square root of the sample variance.

    See ``variance`` for arguments and other details.

    >>> stdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    1.0810874155219827

    """
    var = variance(data, xbar)
    # Decimal provides its own sqrt() keeping exact context; everything
    # else goes through math.sqrt (which coerces to float).
    sqrt_method = getattr(var, 'sqrt', None)
    if sqrt_method is not None:
        return sqrt_method()
    return math.sqrt(var)
797
798
def pstdev(data, mu=None):
    """Return the square root of the population variance.

    See ``pvariance`` for arguments and other details.

    >>> pstdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    0.986893273527251

    """
    var = pvariance(data, mu)
    # Decimal provides its own sqrt() keeping exact context; everything
    # else goes through math.sqrt (which coerces to float).
    sqrt_method = getattr(var, 'sqrt', None)
    if sqrt_method is not None:
        return sqrt_method()
    return math.sqrt(var)
Raymond Hettinger11c79532019-02-23 14:44:07 -0800813
814## Normal Distribution #####################################################
815
class NormalDist:
    'Normal distribution of a random variable'
    # https://en.wikipedia.org/wiki/Normal_distribution
    # https://en.wikipedia.org/wiki/Variance#Properties

    # The dict form of __slots__ doubles as per-attribute documentation:
    # the values become readable descriptions while still suppressing the
    # per-instance __dict__.
    __slots__ = {'mu': 'Arithmetic mean of a normal distribution',
                 'sigma': 'Standard deviation of a normal distribution'}
Raymond Hettinger11c79532019-02-23 14:44:07 -0800823
824 def __init__(self, mu=0.0, sigma=1.0):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -0700825 'NormalDist where mu is the mean and sigma is the standard deviation.'
Raymond Hettinger11c79532019-02-23 14:44:07 -0800826 if sigma < 0.0:
827 raise StatisticsError('sigma must be non-negative')
828 self.mu = mu
829 self.sigma = sigma
830
831 @classmethod
832 def from_samples(cls, data):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -0700833 'Make a normal distribution instance from sample data.'
Raymond Hettinger11c79532019-02-23 14:44:07 -0800834 if not isinstance(data, (list, tuple)):
835 data = list(data)
836 xbar = fmean(data)
837 return cls(xbar, stdev(data, xbar))
838
Raymond Hettingerfb8c7d52019-04-23 01:46:18 -0700839 def samples(self, n, *, seed=None):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -0700840 'Generate *n* samples for a given mean and standard deviation.'
Raymond Hettinger11c79532019-02-23 14:44:07 -0800841 gauss = random.gauss if seed is None else random.Random(seed).gauss
842 mu, sigma = self.mu, self.sigma
843 return [gauss(mu, sigma) for i in range(n)]
844
845 def pdf(self, x):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -0700846 'Probability density function. P(x <= X < x+dx) / dx'
Raymond Hettinger11c79532019-02-23 14:44:07 -0800847 variance = self.sigma ** 2.0
848 if not variance:
849 raise StatisticsError('pdf() not defined when sigma is zero')
850 return exp((x - self.mu)**2.0 / (-2.0*variance)) / sqrt(tau * variance)
851
852 def cdf(self, x):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -0700853 'Cumulative distribution function. P(X <= x)'
Raymond Hettinger11c79532019-02-23 14:44:07 -0800854 if not self.sigma:
855 raise StatisticsError('cdf() not defined when sigma is zero')
856 return 0.5 * (1.0 + erf((x - self.mu) / (self.sigma * sqrt(2.0))))
857
Raymond Hettinger714c60d2019-03-18 20:17:14 -0700858 def inv_cdf(self, p):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -0700859 '''Inverse cumulative distribution function. x : P(X <= x) = p
Raymond Hettinger714c60d2019-03-18 20:17:14 -0700860
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -0700861 Finds the value of the random variable such that the probability of the
862 variable being less than or equal to that value equals the given probability.
Raymond Hettinger714c60d2019-03-18 20:17:14 -0700863
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -0700864 This function is also called the percent point function or quantile function.
Raymond Hettinger714c60d2019-03-18 20:17:14 -0700865 '''
866 if (p <= 0.0 or p >= 1.0):
867 raise StatisticsError('p must be in the range 0.0 < p < 1.0')
868 if self.sigma <= 0.0:
869 raise StatisticsError('cdf() not defined when sigma at or below zero')
870
871 # There is no closed-form solution to the inverse CDF for the normal
872 # distribution, so we use a rational approximation instead:
873 # Wichura, M.J. (1988). "Algorithm AS241: The Percentage Points of the
874 # Normal Distribution". Applied Statistics. Blackwell Publishing. 37
875 # (3): 477–484. doi:10.2307/2347330. JSTOR 2347330.
876
877 q = p - 0.5
878 if fabs(q) <= 0.425:
Raymond Hettinger714c60d2019-03-18 20:17:14 -0700879 r = 0.180625 - q * q
Raymond Hettingerfe138832019-03-19 14:29:13 -0700880 num = (((((((2.50908_09287_30122_6727e+3 * r +
881 3.34305_75583_58812_8105e+4) * r +
882 6.72657_70927_00870_0853e+4) * r +
883 4.59219_53931_54987_1457e+4) * r +
884 1.37316_93765_50946_1125e+4) * r +
885 1.97159_09503_06551_4427e+3) * r +
886 1.33141_66789_17843_7745e+2) * r +
887 3.38713_28727_96366_6080e+0) * q
888 den = (((((((5.22649_52788_52854_5610e+3 * r +
889 2.87290_85735_72194_2674e+4) * r +
890 3.93078_95800_09271_0610e+4) * r +
891 2.12137_94301_58659_5867e+4) * r +
892 5.39419_60214_24751_1077e+3) * r +
893 6.87187_00749_20579_0830e+2) * r +
894 4.23133_30701_60091_1252e+1) * r +
895 1.0)
Raymond Hettinger714c60d2019-03-18 20:17:14 -0700896 x = num / den
897 return self.mu + (x * self.sigma)
Raymond Hettinger714c60d2019-03-18 20:17:14 -0700898 r = p if q <= 0.0 else 1.0 - p
899 r = sqrt(-log(r))
900 if r <= 5.0:
Raymond Hettinger714c60d2019-03-18 20:17:14 -0700901 r = r - 1.6
Raymond Hettingerfe138832019-03-19 14:29:13 -0700902 num = (((((((7.74545_01427_83414_07640e-4 * r +
903 2.27238_44989_26918_45833e-2) * r +
904 2.41780_72517_74506_11770e-1) * r +
905 1.27045_82524_52368_38258e+0) * r +
906 3.64784_83247_63204_60504e+0) * r +
907 5.76949_72214_60691_40550e+0) * r +
908 4.63033_78461_56545_29590e+0) * r +
909 1.42343_71107_49683_57734e+0)
910 den = (((((((1.05075_00716_44416_84324e-9 * r +
911 5.47593_80849_95344_94600e-4) * r +
912 1.51986_66563_61645_71966e-2) * r +
913 1.48103_97642_74800_74590e-1) * r +
914 6.89767_33498_51000_04550e-1) * r +
915 1.67638_48301_83803_84940e+0) * r +
916 2.05319_16266_37758_82187e+0) * r +
917 1.0)
Raymond Hettinger52a594b2019-03-19 12:48:04 -0700918 else:
919 r = r - 5.0
Raymond Hettingerfe138832019-03-19 14:29:13 -0700920 num = (((((((2.01033_43992_92288_13265e-7 * r +
921 2.71155_55687_43487_57815e-5) * r +
922 1.24266_09473_88078_43860e-3) * r +
923 2.65321_89526_57612_30930e-2) * r +
924 2.96560_57182_85048_91230e-1) * r +
925 1.78482_65399_17291_33580e+0) * r +
926 5.46378_49111_64114_36990e+0) * r +
927 6.65790_46435_01103_77720e+0)
928 den = (((((((2.04426_31033_89939_78564e-15 * r +
929 1.42151_17583_16445_88870e-7) * r +
930 1.84631_83175_10054_68180e-5) * r +
931 7.86869_13114_56132_59100e-4) * r +
932 1.48753_61290_85061_48525e-2) * r +
933 1.36929_88092_27358_05310e-1) * r +
934 5.99832_20655_58879_37690e-1) * r +
935 1.0)
Raymond Hettinger714c60d2019-03-18 20:17:14 -0700936 x = num / den
937 if q < 0.0:
938 x = -x
939 return self.mu + (x * self.sigma)
940
Raymond Hettinger318d5372019-03-06 22:59:40 -0800941 def overlap(self, other):
942 '''Compute the overlapping coefficient (OVL) between two normal distributions.
943
944 Measures the agreement between two normal probability distributions.
945 Returns a value between 0.0 and 1.0 giving the overlapping area in
946 the two underlying probability density functions.
947
948 >>> N1 = NormalDist(2.4, 1.6)
949 >>> N2 = NormalDist(3.2, 2.0)
950 >>> N1.overlap(N2)
951 0.8035050657330205
Raymond Hettinger318d5372019-03-06 22:59:40 -0800952 '''
953 # See: "The overlapping coefficient as a measure of agreement between
954 # probability distributions and point estimation of the overlap of two
955 # normal densities" -- Henry F. Inman and Edwin L. Bradley Jr
956 # http://dx.doi.org/10.1080/03610928908830127
957 if not isinstance(other, NormalDist):
958 raise TypeError('Expected another NormalDist instance')
959 X, Y = self, other
960 if (Y.sigma, Y.mu) < (X.sigma, X.mu): # sort to assure commutativity
961 X, Y = Y, X
962 X_var, Y_var = X.variance, Y.variance
963 if not X_var or not Y_var:
964 raise StatisticsError('overlap() not defined when sigma is zero')
965 dv = Y_var - X_var
966 dm = fabs(Y.mu - X.mu)
967 if not dv:
Raymond Hettinger41f0b782019-03-14 02:25:26 -0700968 return 1.0 - erf(dm / (2.0 * X.sigma * sqrt(2.0)))
Raymond Hettinger318d5372019-03-06 22:59:40 -0800969 a = X.mu * Y_var - Y.mu * X_var
970 b = X.sigma * Y.sigma * sqrt(dm**2.0 + dv * log(Y_var / X_var))
971 x1 = (a + b) / dv
972 x2 = (a - b) / dv
973 return 1.0 - (fabs(Y.cdf(x1) - X.cdf(x1)) + fabs(Y.cdf(x2) - X.cdf(x2)))
974
Raymond Hettinger11c79532019-02-23 14:44:07 -0800975 @property
Raymond Hettinger9e456bc2019-02-24 11:44:55 -0800976 def mean(self):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -0700977 'Arithmetic mean of the normal distribution.'
Raymond Hettinger9e456bc2019-02-24 11:44:55 -0800978 return self.mu
979
980 @property
981 def stdev(self):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -0700982 'Standard deviation of the normal distribution.'
Raymond Hettinger9e456bc2019-02-24 11:44:55 -0800983 return self.sigma
984
985 @property
Raymond Hettinger11c79532019-02-23 14:44:07 -0800986 def variance(self):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -0700987 'Square of the standard deviation.'
Raymond Hettinger11c79532019-02-23 14:44:07 -0800988 return self.sigma ** 2.0
989
990 def __add__(x1, x2):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -0700991 '''Add a constant or another NormalDist instance.
992
993 If *other* is a constant, translate mu by the constant,
994 leaving sigma unchanged.
995
996 If *other* is a NormalDist, add both the means and the variances.
997 Mathematically, this works only if the two distributions are
998 independent or if they are jointly normally distributed.
999 '''
Raymond Hettinger11c79532019-02-23 14:44:07 -08001000 if isinstance(x2, NormalDist):
1001 return NormalDist(x1.mu + x2.mu, hypot(x1.sigma, x2.sigma))
1002 return NormalDist(x1.mu + x2, x1.sigma)
1003
1004 def __sub__(x1, x2):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -07001005 '''Subtract a constant or another NormalDist instance.
1006
1007 If *other* is a constant, translate by the constant mu,
1008 leaving sigma unchanged.
1009
1010 If *other* is a NormalDist, subtract the means and add the variances.
1011 Mathematically, this works only if the two distributions are
1012 independent or if they are jointly normally distributed.
1013 '''
Raymond Hettinger11c79532019-02-23 14:44:07 -08001014 if isinstance(x2, NormalDist):
1015 return NormalDist(x1.mu - x2.mu, hypot(x1.sigma, x2.sigma))
1016 return NormalDist(x1.mu - x2, x1.sigma)
1017
1018 def __mul__(x1, x2):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -07001019 '''Multiply both mu and sigma by a constant.
1020
1021 Used for rescaling, perhaps to change measurement units.
1022 Sigma is scaled with the absolute value of the constant.
1023 '''
Raymond Hettinger11c79532019-02-23 14:44:07 -08001024 return NormalDist(x1.mu * x2, x1.sigma * fabs(x2))
1025
1026 def __truediv__(x1, x2):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -07001027 '''Divide both mu and sigma by a constant.
1028
1029 Used for rescaling, perhaps to change measurement units.
1030 Sigma is scaled with the absolute value of the constant.
1031 '''
Raymond Hettinger11c79532019-02-23 14:44:07 -08001032 return NormalDist(x1.mu / x2, x1.sigma / fabs(x2))
1033
1034 def __pos__(x1):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -07001035 'Return a copy of the instance.'
Raymond Hettinger79fbcc52019-02-23 22:19:01 -08001036 return NormalDist(x1.mu, x1.sigma)
Raymond Hettinger11c79532019-02-23 14:44:07 -08001037
1038 def __neg__(x1):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -07001039 'Negates mu while keeping sigma the same.'
Raymond Hettinger11c79532019-02-23 14:44:07 -08001040 return NormalDist(-x1.mu, x1.sigma)
1041
1042 __radd__ = __add__
1043
1044 def __rsub__(x1, x2):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -07001045 'Subtract a NormalDist from a constant or another NormalDist.'
Raymond Hettinger11c79532019-02-23 14:44:07 -08001046 return -(x1 - x2)
1047
1048 __rmul__ = __mul__
1049
1050 def __eq__(x1, x2):
Raymond Hettinger5f1e8b42019-03-18 22:24:15 -07001051 'Two NormalDist objects are equal if their mu and sigma are both equal.'
Raymond Hettinger11c79532019-02-23 14:44:07 -08001052 if not isinstance(x2, NormalDist):
1053 return NotImplemented
1054 return (x1.mu, x2.sigma) == (x2.mu, x2.sigma)
1055
1056 def __repr__(self):
1057 return f'{type(self).__name__}(mu={self.mu!r}, sigma={self.sigma!r})'
1058
1059
if __name__ == '__main__':

    # Show math operations computed analytically in comparison
    # to a Monte Carlo simulation of the same operations.

    from math import isclose
    from operator import add, sub, mul, truediv
    from itertools import repeat
    import doctest

    g1 = NormalDist(10, 20)
    g2 = NormalDist(-5, 25)

    # Test scaling by a constant
    assert (g1 * 5 / 5).mu == g1.mu
    assert (g1 * 5 / 5).sigma == g1.sigma

    n = 100_000
    G1 = g1.samples(n)
    G2 = g2.samples(n)

    for func in (add, sub):
        print(f'\nTest {func.__name__} with another NormalDist:')
        print(func(g1, g2))
        print(NormalDist.from_samples(map(func, G1, G2)))

    const = 11
    for func in (add, sub, mul, truediv):
        print(f'\nTest {func.__name__} with a constant:')
        print(func(g1, const))
        print(NormalDist.from_samples(map(func, G1, repeat(const))))

    const = 19
    for func in (add, sub, mul):
        print(f'\nTest constant with {func.__name__}:')
        print(func(const, g1))
        print(NormalDist.from_samples(map(func, repeat(const), G1)))

    def assert_close(G1, G2):
        # Verify that the analytic and sampled distributions agree on
        # both parameters.  Bug fix: the original compared G1.mu with
        # itself, so the mean check always passed.
        assert isclose(G1.mu, G2.mu, rel_tol=0.01), (G1, G2)
        assert isclose(G1.sigma, G2.sigma, rel_tol=0.01), (G1, G2)

    X = NormalDist(-105, 73)
    Y = NormalDist(31, 47)
    s = 32.75
    n = 100_000

    S = NormalDist.from_samples([x + s for x in X.samples(n)])
    assert_close(X + s, S)

    S = NormalDist.from_samples([x - s for x in X.samples(n)])
    assert_close(X - s, S)

    S = NormalDist.from_samples([x * s for x in X.samples(n)])
    assert_close(X * s, S)

    S = NormalDist.from_samples([x / s for x in X.samples(n)])
    assert_close(X / s, S)

    S = NormalDist.from_samples([x + y for x, y in zip(X.samples(n),
                                                       Y.samples(n))])
    assert_close(X + Y, S)

    S = NormalDist.from_samples([x - y for x, y in zip(X.samples(n),
                                                       Y.samples(n))])
    assert_close(X - Y, S)

    print(doctest.testmod())