"""
Basic statistics module.

This module provides functions for calculating statistics of data, including
averages, variance, and standard deviation.

Calculating averages
--------------------

==================  ==================================================
Function            Description
==================  ==================================================
mean                Arithmetic mean (average) of data.
fmean               Fast, floating point arithmetic mean.
geometric_mean      Geometric mean of data.
harmonic_mean       Harmonic mean of data.
median              Median (middle value) of data.
median_low          Low median of data.
median_high         High median of data.
median_grouped      Median, or 50th percentile, of grouped data.
mode                Mode (most common value) of data.
multimode           List of modes (most common values of data).
quantiles           Divide data into intervals with equal probability.
==================  ==================================================

Calculate the arithmetic mean ("the average") of data:

>>> mean([-1.0, 2.5, 3.25, 5.75])
2.625


Calculate the standard median of discrete data:

>>> median([2, 3, 4, 5])
3.5


Calculate the median, or 50th percentile, of data grouped into class intervals
centred on the data values provided. E.g. if your data points are rounded to
the nearest whole number:

>>> median_grouped([2, 2, 3, 3, 3, 4]) #doctest: +ELLIPSIS
2.8333333333...

This should be interpreted in this way: you have two data points in the class
interval 1.5-2.5, three data points in the class interval 2.5-3.5, and one in
the class interval 3.5-4.5. The median of these data points is 2.8333...


Calculating variability or spread
---------------------------------

==================  =============================================
Function            Description
==================  =============================================
pvariance           Population variance of data.
variance            Sample variance of data.
pstdev              Population standard deviation of data.
stdev               Sample standard deviation of data.
==================  =============================================

Calculate the standard deviation of sample data:

>>> stdev([2.5, 3.25, 5.5, 11.25, 11.75]) #doctest: +ELLIPSIS
4.38961843444...

If you have previously calculated the mean, you can pass it as the optional
second argument to the four "spread" functions to avoid recalculating it:

>>> data = [1, 2, 2, 4, 4, 4, 5, 6]
>>> mu = mean(data)
>>> pvariance(data, mu)
2.5


Exceptions
----------

A single exception is defined: StatisticsError is a subclass of ValueError.

"""

__all__ = [ 'StatisticsError', 'NormalDist', 'quantiles',
            'pstdev', 'pvariance', 'stdev', 'variance',
            'median', 'median_low', 'median_high', 'median_grouped',
            'mean', 'mode', 'multimode', 'harmonic_mean', 'fmean',
            'geometric_mean',
          ]

import math
import numbers
import random

from fractions import Fraction
from decimal import Decimal
from itertools import groupby
from bisect import bisect_left, bisect_right
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
from operator import itemgetter
from collections import Counter

# === Exceptions ===

class StatisticsError(ValueError):
    pass


# === Private utilities ===

def _sum(data, start=0):
    """_sum(data [, start]) -> (type, sum, count)

    Return a high-precision sum of the given numeric data as a fraction,
    together with the type to be converted to and the count of items.

    If optional argument ``start`` is given, it is added to the total.
    If ``data`` is empty, ``start`` (defaulting to 0) is returned.


    Examples
    --------

    >>> _sum([3, 2.25, 4.5, -0.5, 1.0], 0.75)
    (<class 'float'>, Fraction(11, 1), 5)

    Some sources of round-off error will be avoided:

    # Built-in sum returns zero.
    >>> _sum([1e50, 1, -1e50] * 1000)
    (<class 'float'>, Fraction(1000, 1), 3000)

    Fractions and Decimals are also supported:

    >>> from fractions import Fraction as F
    >>> _sum([F(2, 3), F(7, 5), F(1, 4), F(5, 6)])
    (<class 'fractions.Fraction'>, Fraction(63, 20), 4)

    >>> from decimal import Decimal as D
    >>> data = [D("0.1375"), D("0.2108"), D("0.3061"), D("0.0419")]
    >>> _sum(data)
    (<class 'decimal.Decimal'>, Fraction(6963, 10000), 4)

    Mixed types are currently treated as an error, except that int is
    allowed.
    """
    count = 0
    n, d = _exact_ratio(start)
    partials = {d: n}
    partials_get = partials.get
    T = _coerce(int, type(start))
    for typ, values in groupby(data, type):
        T = _coerce(T, typ)  # or raise TypeError
        for n, d in map(_exact_ratio, values):
            count += 1
            partials[d] = partials_get(d, 0) + n
    if None in partials:
        # The sum will be a NAN or INF. We can ignore all the finite
        # partials, and just look at this special one.
        total = partials[None]
        assert not _isfinite(total)
    else:
        # Sum all the partial sums using builtin sum.
        # FIXME is this faster if we sum them in order of the denominator?
        total = sum(Fraction(n, d) for d, n in sorted(partials.items()))
    return (T, total, count)


def _isfinite(x):
    try:
        return x.is_finite()  # Likely a Decimal.
    except AttributeError:
        return math.isfinite(x)  # Coerces to float first.


def _coerce(T, S):
    """Coerce types T and S to a common type, or raise TypeError.

    Coercion rules are currently an implementation detail. See the CoerceTest
    test class in test_statistics for details.
    """
    # See http://bugs.python.org/issue24068.
    assert T is not bool, "initial type T is bool"
    # If the types are the same, no need to coerce anything. Put this
    # first, so that the usual case (no coercion needed) happens as soon
    # as possible.
    if T is S: return T
    # Mixed int & other coerce to the other type.
    if S is int or S is bool: return T
    if T is int: return S
    # If one is a (strict) subclass of the other, coerce to the subclass.
    if issubclass(S, T): return S
    if issubclass(T, S): return T
    # Ints coerce to the other type.
    if issubclass(T, int): return S
    if issubclass(S, int): return T
    # Mixed fraction & float coerces to float (or float subclass).
    if issubclass(T, Fraction) and issubclass(S, float):
        return S
    if issubclass(T, float) and issubclass(S, Fraction):
        return T
    # Any other combination is disallowed.
    msg = "don't know how to coerce %s and %s"
    raise TypeError(msg % (T.__name__, S.__name__))


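# A quick illustration of the coercion rules above (added editorial note,
# not part of the original module): int defers to the other operand's type,
# and a float/Fraction mix resolves to float.
#
#     _coerce(int, float)       -> float
#     _coerce(Fraction, int)    -> Fraction
#     _coerce(float, Fraction)  -> float
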
def _exact_ratio(x):
    """Return Real number x to exact (numerator, denominator) pair.

    >>> _exact_ratio(0.25)
    (1, 4)

    x is expected to be an int, Fraction, Decimal or float.
    """
    try:
        # Optimise the common case of floats. We expect that the most often
        # used numeric type will be builtin floats, so try to make this as
        # fast as possible.
        if type(x) is float or type(x) is Decimal:
            return x.as_integer_ratio()
        try:
            # x may be an int, Fraction, or Integral ABC.
            return (x.numerator, x.denominator)
        except AttributeError:
            try:
                # x may be a float or Decimal subclass.
                return x.as_integer_ratio()
            except AttributeError:
                # Just give up?
                pass
    except (OverflowError, ValueError):
        # float NAN or INF.
        assert not _isfinite(x)
        return (x, None)
    msg = "can't convert type '{}' to numerator/denominator"
    raise TypeError(msg.format(type(x).__name__))


def _convert(value, T):
    """Convert value to given numeric type T."""
    if type(value) is T:
        # This covers the cases where T is Fraction, or where value is
        # a NAN or INF (Decimal or float).
        return value
    if issubclass(T, int) and value.denominator != 1:
        T = float
    try:
        # FIXME: what do we do if this overflows?
        return T(value)
    except TypeError:
        if issubclass(T, Decimal):
            return T(value.numerator)/T(value.denominator)
        else:
            raise


def _find_lteq(a, x):
    'Locate the leftmost value exactly equal to x'
    i = bisect_left(a, x)
    if i != len(a) and a[i] == x:
        return i
    raise ValueError


def _find_rteq(a, l, x):
    'Locate the rightmost value exactly equal to x'
    i = bisect_right(a, x, lo=l)
    if i != (len(a)+1) and a[i-1] == x:
        return i-1
    raise ValueError


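# Added illustration (editorial, not part of the original module): for the
# sorted data used by median_grouped() below, these helpers bracket the run
# of values equal to the class midpoint.  With data = [2, 2, 3, 3, 3, 4]
# and x = 3:
#
#     _find_lteq(data, 3)     -> 2     # leftmost index of 3
#     _find_rteq(data, 2, 3)  -> 4     # rightmost index of 3
#
# so that class holds f = 4 - 2 + 1 = 3 points with cf = 2 values below it.
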
def _fail_neg(values, errmsg='negative value'):
    """Iterate over values, failing if any are less than zero."""
    for x in values:
        if x < 0:
            raise StatisticsError(errmsg)
        yield x


# === Measures of central tendency (averages) ===

def mean(data):
    """Return the sample arithmetic mean of data.

    >>> mean([1, 2, 3, 4, 4])
    2.8

    >>> from fractions import Fraction as F
    >>> mean([F(3, 7), F(1, 21), F(5, 3), F(1, 3)])
    Fraction(13, 21)

    >>> from decimal import Decimal as D
    >>> mean([D("0.5"), D("0.75"), D("0.625"), D("0.375")])
    Decimal('0.5625')

    If ``data`` is empty, StatisticsError will be raised.
    """
    if iter(data) is data:
        data = list(data)
    n = len(data)
    if n < 1:
        raise StatisticsError('mean requires at least one data point')
    T, total, count = _sum(data)
    assert count == n
    return _convert(total/n, T)

def fmean(data):
    """ Convert data to floats and compute the arithmetic mean.

    This runs faster than the mean() function and it always returns a float.
    The result is highly accurate but not as perfect as mean().
    If the input dataset is empty, it raises a StatisticsError.

    >>> fmean([3.5, 4.0, 5.25])
    4.25

    """
    try:
        n = len(data)
    except TypeError:
        # Handle iterators that do not define __len__().
        n = 0
        def count(iterable):
            nonlocal n
            for n, x in enumerate(iterable, start=1):
                yield x
        total = fsum(count(data))
    else:
        total = fsum(data)
    try:
        return total / n
    except ZeroDivisionError:
        raise StatisticsError('fmean requires at least one data point') from None

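# Added note (editorial, not part of the original module): the inner count()
# generator in fmean() above lets fsum() consume an iterator in a single
# pass while the enumerate() binding keeps the enclosing n up to date, so
# no second pass or list copy is needed.  For example:
#
#     fmean(x / 10 for x in range(11))   # -> 0.5
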
def geometric_mean(data):
    """Convert data to floats and compute the geometric mean.

    Raises a StatisticsError if the input dataset is empty,
    if it contains a zero, or if it contains a negative value.

    No special efforts are made to achieve exact results.
    (However, this may change in the future.)

    >>> round(geometric_mean([54, 24, 36]), 9)
    36.0
    """
    try:
        return exp(fmean(map(log, data)))
    except ValueError:
        raise StatisticsError('geometric mean requires a non-empty dataset '
                              'containing positive numbers') from None

def harmonic_mean(data):
    """Return the harmonic mean of data.

    The harmonic mean, sometimes called the subcontrary mean, is the
    reciprocal of the arithmetic mean of the reciprocals of the data,
    and is often appropriate when averaging quantities which are rates
    or ratios, for example speeds. Example:

    Suppose an investor purchases an equal value of shares in each of
    three companies, with P/E (price/earning) ratios of 2.5, 3 and 10.
    What is the average P/E ratio for the investor's portfolio?

    >>> harmonic_mean([2.5, 3, 10])  # For an equal investment portfolio.
    3.6

    Using the arithmetic mean would give an average of about 5.167, which
    is too high.

    If ``data`` is empty, or any element is less than zero,
    ``harmonic_mean`` will raise ``StatisticsError``.
    """
    # For a justification for using harmonic mean for P/E ratios, see
    # http://fixthepitch.pellucid.com/comps-analysis-the-missing-harmony-of-summary-statistics/
    # http://papers.ssrn.com/sol3/papers.cfm?abstract_id=2621087
    if iter(data) is data:
        data = list(data)
    errmsg = 'harmonic mean does not support negative values'
    n = len(data)
    if n < 1:
        raise StatisticsError('harmonic_mean requires at least one data point')
    elif n == 1:
        x = data[0]
        if isinstance(x, (numbers.Real, Decimal)):
            if x < 0:
                raise StatisticsError(errmsg)
            return x
        else:
            raise TypeError('unsupported type')
    try:
        T, total, count = _sum(1/x for x in _fail_neg(data, errmsg))
    except ZeroDivisionError:
        return 0
    assert count == n
    return _convert(n/total, T)


# FIXME: investigate ways to calculate medians without sorting? Quickselect?
def median(data):
    """Return the median (middle value) of numeric data.

    When the number of data points is odd, return the middle data point.
    When the number of data points is even, the median is interpolated by
    taking the average of the two middle values:

    >>> median([1, 3, 5])
    3
    >>> median([1, 3, 5, 7])
    4.0

    """
    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    if n%2 == 1:
        return data[n//2]
    else:
        i = n//2
        return (data[i - 1] + data[i])/2


def median_low(data):
    """Return the low median of numeric data.

    When the number of data points is odd, the middle value is returned.
    When it is even, the smaller of the two middle values is returned.

    >>> median_low([1, 3, 5])
    3
    >>> median_low([1, 3, 5, 7])
    3

    """
    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    if n%2 == 1:
        return data[n//2]
    else:
        return data[n//2 - 1]


def median_high(data):
    """Return the high median of data.

    When the number of data points is odd, the middle value is returned.
    When it is even, the larger of the two middle values is returned.

    >>> median_high([1, 3, 5])
    3
    >>> median_high([1, 3, 5, 7])
    5

    """
    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    return data[n//2]


def median_grouped(data, interval=1):
    """Return the 50th percentile (median) of grouped continuous data.

    >>> median_grouped([1, 2, 2, 3, 4, 4, 4, 4, 4, 5])
    3.7
    >>> median_grouped([52, 52, 53, 54])
    52.5

    This calculates the median as the 50th percentile, and should be
    used when your data is continuous and grouped. In the above example,
    the values 1, 2, 3, etc. actually represent the midpoint of classes
    0.5-1.5, 1.5-2.5, 2.5-3.5, etc. The middle value falls somewhere in
    class 3.5-4.5, and interpolation is used to estimate it.

    Optional argument ``interval`` represents the class interval, and
    defaults to 1. Changing the class interval naturally will change the
    interpolated 50th percentile value:

    >>> median_grouped([1, 3, 3, 5, 7], interval=1)
    3.25
    >>> median_grouped([1, 3, 3, 5, 7], interval=2)
    3.5

    This function does not check whether the data points are at least
    ``interval`` apart.
    """
    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    elif n == 1:
        return data[0]
    # Find the value at the midpoint. Remember this corresponds to the
    # centre of the class interval.
    x = data[n//2]
    for obj in (x, interval):
        if isinstance(obj, (str, bytes)):
            raise TypeError('expected number but got %r' % obj)
    try:
        L = x - interval/2  # The lower limit of the median interval.
    except TypeError:
        # Mixed type. For now we just coerce to float.
        L = float(x) - float(interval)/2

    # Uses bisection search to search for x in data with log(n) time complexity
    # Find the position of leftmost occurrence of x in data
    l1 = _find_lteq(data, x)
    # Find the position of rightmost occurrence of x in data[l1...len(data)]
    # Assuming always l1 <= l2
    l2 = _find_rteq(data, l1, x)
    cf = l1
    f = l2 - l1 + 1
    return L + interval*(n/2 - cf)/f

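# Added worked example (editorial, not part of the original module), tracing
# median_grouped()'s first doctest: for [1, 2, 2, 3, 4, 4, 4, 4, 4, 5] the
# midpoint value is x = 4, so L = 3.5, cf = 4 values fall below the 3.5-4.5
# class, and f = 5 values fall inside it, giving 3.5 + 1*(10/2 - 4)/5 = 3.7.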

def mode(data):
    """Return the most common data point from discrete or nominal data.

    ``mode`` assumes discrete data, and returns a single value. This is the
    standard treatment of the mode as commonly taught in schools:

    >>> mode([1, 1, 2, 3, 3, 3, 3, 4])
    3

    This also works with nominal (non-numeric) data:

    >>> mode(["red", "blue", "blue", "red", "green", "red", "red"])
    'red'

    If there are multiple modes, return the first one encountered.

    >>> mode(['red', 'red', 'green', 'blue', 'blue'])
    'red'

    If *data* is empty, ``mode`` raises StatisticsError.

    """
    data = iter(data)
    try:
        return Counter(data).most_common(1)[0][0]
    except IndexError:
        raise StatisticsError('no mode for empty data') from None


def multimode(data):
    """ Return a list of the most frequently occurring values.

    Will return more than one result if there are multiple modes
    or an empty list if *data* is empty.

    >>> multimode('aabbbbbbbbcc')
    ['b']
    >>> multimode('aabbbbccddddeeffffgg')
    ['b', 'd', 'f']
    >>> multimode('')
    []

    """
    counts = Counter(iter(data)).most_common()
    maxcount, mode_items = next(groupby(counts, key=itemgetter(1)), (0, []))
    return list(map(itemgetter(0), mode_items))

# Notes on methods for computing quantiles
# ----------------------------------------
#
# There is no one perfect way to compute quantiles. Here we offer
# two methods that serve common needs. Most other packages
# surveyed offered at least one or both of these two, making them
# "standard" in the sense of "widely-adopted and reproducible".
# They are also easy to explain, easy to compute manually, and have
# straight-forward interpretations that aren't surprising.

# The default method is known as "R6", "PERCENTILE.EXC", or "expected
# value of rank order statistics". The alternative method is known as
# "R7", "PERCENTILE.INC", or "mode of rank order statistics".

# For sample data where there is a positive probability for values
# beyond the range of the data, the R6 exclusive method is a
# reasonable choice. Consider a random sample of nine values from a
# population with a uniform distribution from 0.0 to 100.0. The
# distribution of the third ranked sample point is described by
# betavariate(alpha=3, beta=7) which has mode=0.250, median=0.286, and
# mean=0.300. Only the latter (which corresponds with R6) gives the
# desired cut point with 30% of the population falling below that
# value, making it comparable to a result from an inv_cdf() function.

# For describing population data where the end points are known to
# be included in the data, the R7 inclusive method is a reasonable
# choice. Instead of the mean, it uses the mode of the beta
# distribution for the interior points. Per Hyndman & Fan, "One nice
# property is that the vertices of Q7(p) divide the range into n - 1
# intervals, and exactly 100p% of the intervals lie to the left of
# Q7(p) and 100(1 - p)% of the intervals lie to the right of Q7(p)."

# If needed, other methods could be added. However, for now, the
# position is that fewer options make for easier choices and that
# external packages can be used for anything more advanced.

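# Added worked example (editorial, not part of the original module): for the
# small sample [1, 2, 3, 4, 5] the two methods give different cut points:
#
#     quantiles([1, 2, 3, 4, 5], n=4)                       # R6 exclusive
#     -> [1.5, 3.0, 4.5]
#     quantiles([1, 2, 3, 4, 5], n=4, method='inclusive')   # R7 inclusive
#     -> [2.0, 3.0, 4.0]
#
# The exclusive cut points are spread wider because that method allows for
# population values beyond the observed minimum and maximum.
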
def quantiles(dist, /, *, n=4, method='exclusive'):
    '''Divide *dist* into *n* continuous intervals with equal probability.

    Returns a list of (n - 1) cut points separating the intervals.

    Set *n* to 4 for quartiles (the default). Set *n* to 10 for deciles.
    Set *n* to 100 for percentiles which gives the 99 cut points that
    separate *dist* into 100 equal-sized groups.

    The *dist* can be any iterable containing sample data or it can be
    an instance of a class that defines an inv_cdf() method. For sample
    data, the cut points are linearly interpolated between data points.

    If *method* is set to *inclusive*, *dist* is treated as population
    data. The minimum value is treated as the 0th percentile and the
    maximum value is treated as the 100th percentile.
    '''
    if n < 1:
        raise StatisticsError('n must be at least 1')
    if hasattr(dist, 'inv_cdf'):
        return [dist.inv_cdf(i / n) for i in range(1, n)]
    data = sorted(dist)
    ld = len(data)
    if ld < 2:
        raise StatisticsError('must have at least two data points')
    if method == 'inclusive':
        m = ld - 1
        result = []
        for i in range(1, n):
            j = i * m // n
            delta = i*m - j*n
            interpolated = (data[j] * (n - delta) + data[j+1] * delta) / n
            result.append(interpolated)
        return result
    if method == 'exclusive':
        m = ld + 1
        result = []
        for i in range(1, n):
            j = i * m // n                               # rescale i to m/n
            j = 1 if j < 1 else ld-1 if j > ld-1 else j  # clamp to 1 .. ld-1
            delta = i*m - j*n                            # exact integer math
            interpolated = (data[j-1] * (n - delta) + data[j] * delta) / n
            result.append(interpolated)
        return result
    raise ValueError(f'Unknown method: {method!r}')

# === Measures of spread ===

# See http://mathworld.wolfram.com/Variance.html
# http://mathworld.wolfram.com/SampleVariance.html
# http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
#
# Under no circumstances use the so-called "computational formula for
# variance", as that is only suitable for hand calculations with a small
# amount of low-precision data. It has terrible numeric properties.
#
# See a comparison of three computational methods here:
# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/

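# Added illustration (editorial, not part of the original module): the
# two-pass approach in _ss() below recovers the exact answer for data far
# from zero, whereas the naive formula sum(x*x) - n*mean**2 cancels
# catastrophically in ordinary float arithmetic.
#
#     data = [1e9 + x for x in (4, 7, 13, 16)]
#     variance(data)                               # -> 30.0 (exact)
#     sum(x*x for x in data) - 4 * fmean(data)**2  # hopelessly inaccurate
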
def _ss(data, c=None):
    """Return sum of square deviations of sequence data.

    If ``c`` is None, the mean is calculated in one pass, and the deviations
    from the mean are calculated in a second pass. Otherwise, deviations are
    calculated from ``c`` as given. Use the second case with care, as it can
    lead to garbage results.
    """
    if c is None:
        c = mean(data)
    T, total, count = _sum((x-c)**2 for x in data)
    # The following sum should mathematically equal zero, but due to rounding
    # error may not.
    U, total2, count2 = _sum((x-c) for x in data)
    assert T == U and count == count2
    total -= total2**2/len(data)
    assert not total < 0, 'negative sum of square deviations: %f' % total
    return (T, total)


def variance(data, xbar=None):
    """Return the sample variance of data.

    data should be an iterable of Real-valued numbers, with at least two
    values. The optional argument xbar, if given, should be the mean of
    the data. If it is missing or None, the mean is automatically calculated.

    Use this function when your data is a sample from a population. To
    calculate the variance from the entire population, see ``pvariance``.

    Examples:

    >>> data = [2.75, 1.75, 1.25, 0.25, 0.5, 1.25, 3.5]
    >>> variance(data)
    1.3720238095238095

    If you have already calculated the mean of your data, you can pass it as
    the optional second argument ``xbar`` to avoid recalculating it:

    >>> m = mean(data)
    >>> variance(data, m)
    1.3720238095238095

    This function does not check that ``xbar`` is actually the mean of
    ``data``. Giving arbitrary values for ``xbar`` may lead to invalid or
    impossible results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> variance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('31.01875')

    >>> from fractions import Fraction as F
    >>> variance([F(1, 6), F(1, 2), F(5, 3)])
    Fraction(67, 108)

    """
    if iter(data) is data:
        data = list(data)
    n = len(data)
    if n < 2:
        raise StatisticsError('variance requires at least two data points')
    T, ss = _ss(data, xbar)
    return _convert(ss/(n-1), T)


def pvariance(data, mu=None):
    """Return the population variance of ``data``.

    data should be an iterable of Real-valued numbers, with at least one
    value. The optional argument mu, if given, should be the mean of
    the data. If it is missing or None, the mean is automatically calculated.

    Use this function to calculate the variance from the entire population.
    To estimate the variance from a sample, the ``variance`` function is
    usually a better choice.

    Examples:

    >>> data = [0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
    >>> pvariance(data)
    1.25

    If you have already calculated the mean of the data, you can pass it as
    the optional second argument to avoid recalculating it:

    >>> mu = mean(data)
    >>> pvariance(data, mu)
    1.25

    This function does not check that ``mu`` is actually the mean of ``data``.
    Giving arbitrary values for ``mu`` may lead to invalid or impossible
    results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> pvariance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('24.815')

    >>> from fractions import Fraction as F
    >>> pvariance([F(1, 4), F(5, 4), F(1, 2)])
    Fraction(13, 72)

    """
    if iter(data) is data:
        data = list(data)
    n = len(data)
    if n < 1:
        raise StatisticsError('pvariance requires at least one data point')
    T, ss = _ss(data, mu)
    return _convert(ss/n, T)


def stdev(data, xbar=None):
    """Return the square root of the sample variance.

    See ``variance`` for arguments and other details.

    >>> stdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    1.0810874155219827

    """
    var = variance(data, xbar)
    try:
        return var.sqrt()
    except AttributeError:
        return math.sqrt(var)


def pstdev(data, mu=None):
    """Return the square root of the population variance.

    See ``pvariance`` for arguments and other details.

    >>> pstdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    0.986893273527251

    """
    var = pvariance(data, mu)
    try:
        return var.sqrt()
    except AttributeError:
        return math.sqrt(var)

## Normal Distribution #####################################################

class NormalDist:
    'Normal distribution of a random variable'
    # https://en.wikipedia.org/wiki/Normal_distribution
    # https://en.wikipedia.org/wiki/Variance#Properties

    __slots__ = {'mu': 'Arithmetic mean of a normal distribution',
                 'sigma': 'Standard deviation of a normal distribution'}

    def __init__(self, mu=0.0, sigma=1.0):
        'NormalDist where mu is the mean and sigma is the standard deviation.'
        if sigma < 0.0:
            raise StatisticsError('sigma must be non-negative')
        self.mu = mu
        self.sigma = sigma

    @classmethod
    def from_samples(cls, data):
        'Make a normal distribution instance from sample data.'
        if not isinstance(data, (list, tuple)):
            data = list(data)
        xbar = fmean(data)
        return cls(xbar, stdev(data, xbar))

    def samples(self, n, *, seed=None):
        'Generate *n* samples for a given mean and standard deviation.'
        gauss = random.gauss if seed is None else random.Random(seed).gauss
        mu, sigma = self.mu, self.sigma
        return [gauss(mu, sigma) for i in range(n)]

    def pdf(self, x):
        'Probability density function.  P(x <= X < x+dx) / dx'
        variance = self.sigma ** 2.0
        if not variance:
            raise StatisticsError('pdf() not defined when sigma is zero')
        return exp((x - self.mu)**2.0 / (-2.0*variance)) / sqrt(tau * variance)

    def cdf(self, x):
        'Cumulative distribution function.  P(X <= x)'
        if not self.sigma:
            raise StatisticsError('cdf() not defined when sigma is zero')
        return 0.5 * (1.0 + erf((x - self.mu) / (self.sigma * sqrt(2.0))))

    def inv_cdf(self, p):
        '''Inverse cumulative distribution function.  x : P(X <= x) = p

        Finds the value of the random variable such that the probability of the
        variable being less than or equal to that value equals the given probability.

        This function is also called the percent point function or quantile function.
        '''
        if (p <= 0.0 or p >= 1.0):
            raise StatisticsError('p must be in the range 0.0 < p < 1.0')
        if self.sigma <= 0.0:
            raise StatisticsError('inv_cdf() not defined when sigma is at or below zero')

        # There is no closed-form solution to the inverse CDF for the normal
        # distribution, so we use a rational approximation instead:
        # Wichura, M.J. (1988). "Algorithm AS241: The Percentage Points of the
        # Normal Distribution".  Applied Statistics. Blackwell Publishing. 37
        # (3): 477–484. doi:10.2307/2347330. JSTOR 2347330.

        q = p - 0.5
        if fabs(q) <= 0.425:
            r = 0.180625 - q * q
            num = (((((((2.50908_09287_30122_6727e+3 * r +
                         3.34305_75583_58812_8105e+4) * r +
                         6.72657_70927_00870_0853e+4) * r +
                         4.59219_53931_54987_1457e+4) * r +
                         1.37316_93765_50946_1125e+4) * r +
                         1.97159_09503_06551_4427e+3) * r +
                         1.33141_66789_17843_7745e+2) * r +
                         3.38713_28727_96366_6080e+0) * q
            den = (((((((5.22649_52788_52854_5610e+3 * r +
                         2.87290_85735_72194_2674e+4) * r +
                         3.93078_95800_09271_0610e+4) * r +
                         2.12137_94301_58659_5867e+4) * r +
                         5.39419_60214_24751_1077e+3) * r +
                         6.87187_00749_20579_0830e+2) * r +
                         4.23133_30701_60091_1252e+1) * r +
                         1.0)
            x = num / den
            return self.mu + (x * self.sigma)
        r = p if q <= 0.0 else 1.0 - p
        r = sqrt(-log(r))
        if r <= 5.0:
            r = r - 1.6
            num = (((((((7.74545_01427_83414_07640e-4 * r +
                         2.27238_44989_26918_45833e-2) * r +
                         2.41780_72517_74506_11770e-1) * r +
                         1.27045_82524_52368_38258e+0) * r +
                         3.64784_83247_63204_60504e+0) * r +
                         5.76949_72214_60691_40550e+0) * r +
                         4.63033_78461_56545_29590e+0) * r +
                         1.42343_71107_49683_57734e+0)
            den = (((((((1.05075_00716_44416_84324e-9 * r +
                         5.47593_80849_95344_94600e-4) * r +
                         1.51986_66563_61645_71966e-2) * r +
                         1.48103_97642_74800_74590e-1) * r +
                         6.89767_33498_51000_04550e-1) * r +
                         1.67638_48301_83803_84940e+0) * r +
                         2.05319_16266_37758_82187e+0) * r +
                         1.0)
        else:
            r = r - 5.0
            num = (((((((2.01033_43992_92288_13265e-7 * r +
                         2.71155_55687_43487_57815e-5) * r +
                         1.24266_09473_88078_43860e-3) * r +
                         2.65321_89526_57612_30930e-2) * r +
                         2.96560_57182_85048_91230e-1) * r +
                         1.78482_65399_17291_33580e+0) * r +
                         5.46378_49111_64114_36990e+0) * r +
                         6.65790_46435_01103_77720e+0)
            den = (((((((2.04426_31033_89939_78564e-15 * r +
                         1.42151_17583_16445_88870e-7) * r +
                         1.84631_83175_10054_68180e-5) * r +
                         7.86869_13114_56132_59100e-4) * r +
                         1.48753_61290_85061_48525e-2) * r +
                         1.36929_88092_27358_05310e-1) * r +
                         5.99832_20655_58879_37690e-1) * r +
                         1.0)
        x = num / den
        if q < 0.0:
            x = -x
        return self.mu + (x * self.sigma)

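    # Added sanity note (editorial, not part of the original module): the
    # central branch above maps p=0.5 to q=0.0 and hence x=0.0, so
    # inv_cdf(0.5) returns mu exactly, for example:
    #
    #     NormalDist().inv_cdf(0.5)         # -> 0.0
    #     NormalDist(100, 15).inv_cdf(0.5)  # -> 100.0
    #
    # and cdf()/inv_cdf() round-trip closely for p away from the extreme
    # tails.
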
    def overlap(self, other):
        '''Compute the overlapping coefficient (OVL) between two normal distributions.

        Measures the agreement between two normal probability distributions.
        Returns a value between 0.0 and 1.0 giving the overlapping area in
        the two underlying probability density functions.

            >>> N1 = NormalDist(2.4, 1.6)
            >>> N2 = NormalDist(3.2, 2.0)
            >>> N1.overlap(N2)
            0.8035050657330205
        '''
        # See: "The overlapping coefficient as a measure of agreement between
        # probability distributions and point estimation of the overlap of two
        # normal densities" -- Henry F. Inman and Edwin L. Bradley Jr
        # http://dx.doi.org/10.1080/03610928908830127
        if not isinstance(other, NormalDist):
            raise TypeError('Expected another NormalDist instance')
        X, Y = self, other
        if (Y.sigma, Y.mu) < (X.sigma, X.mu):  # sort to assure commutativity
            X, Y = Y, X
        X_var, Y_var = X.variance, Y.variance
        if not X_var or not Y_var:
            raise StatisticsError('overlap() not defined when sigma is zero')
        dv = Y_var - X_var
        dm = fabs(Y.mu - X.mu)
        if not dv:
            return 1.0 - erf(dm / (2.0 * X.sigma * sqrt(2.0)))
        a = X.mu * Y_var - Y.mu * X_var
        b = X.sigma * Y.sigma * sqrt(dm**2.0 + dv * log(Y_var / X_var))
        x1 = (a + b) / dv
        x2 = (a - b) / dv
        return 1.0 - (fabs(Y.cdf(x1) - X.cdf(x1)) + fabs(Y.cdf(x2) - X.cdf(x2)))

    @property
    def mean(self):
        'Arithmetic mean of the normal distribution.'
        return self.mu

    @property
    def stdev(self):
        'Standard deviation of the normal distribution.'
        return self.sigma

    @property
    def variance(self):
        'Square of the standard deviation.'
        return self.sigma ** 2.0

    def __add__(x1, x2):
        '''Add a constant or another NormalDist instance.

        If *other* is a constant, translate mu by the constant,
        leaving sigma unchanged.

        If *other* is a NormalDist, add both the means and the variances.
        Mathematically, this works only if the two distributions are
        independent or if they are jointly normally distributed.
        '''
        if isinstance(x2, NormalDist):
            return NormalDist(x1.mu + x2.mu, hypot(x1.sigma, x2.sigma))
        return NormalDist(x1.mu + x2, x1.sigma)

    def __sub__(x1, x2):
        '''Subtract a constant or another NormalDist instance.

        If *other* is a constant, translate mu by the constant,
        leaving sigma unchanged.

        If *other* is a NormalDist, subtract the means and add the variances.
        Mathematically, this works only if the two distributions are
        independent or if they are jointly normally distributed.
        '''
        if isinstance(x2, NormalDist):
            return NormalDist(x1.mu - x2.mu, hypot(x1.sigma, x2.sigma))
        return NormalDist(x1.mu - x2, x1.sigma)

    def __mul__(x1, x2):
        '''Multiply both mu and sigma by a constant.

        Used for rescaling, perhaps to change measurement units.
        Sigma is scaled with the absolute value of the constant.
        '''
        return NormalDist(x1.mu * x2, x1.sigma * fabs(x2))

    def __truediv__(x1, x2):
        '''Divide both mu and sigma by a constant.

        Used for rescaling, perhaps to change measurement units.
        Sigma is scaled with the absolute value of the constant.
        '''
        return NormalDist(x1.mu / x2, x1.sigma / fabs(x2))

    def __pos__(x1):
        'Return a copy of the instance.'
        return NormalDist(x1.mu, x1.sigma)

    def __neg__(x1):
        'Negates mu while keeping sigma the same.'
        return NormalDist(-x1.mu, x1.sigma)

    __radd__ = __add__

    def __rsub__(x1, x2):
        'Subtract a NormalDist from a constant or another NormalDist.'
        return -(x1 - x2)

    __rmul__ = __mul__

    def __eq__(x1, x2):
        'Two NormalDist objects are equal if their mu and sigma are both equal.'
        if not isinstance(x2, NormalDist):
            return NotImplemented
        return (x1.mu, x1.sigma) == (x2.mu, x2.sigma)

    def __repr__(self):
        return f'{type(self).__name__}(mu={self.mu!r}, sigma={self.sigma!r})'


if __name__ == '__main__':

    # Show math operations computed analytically in comparison
    # to a Monte Carlo simulation of the same operations

    from math import isclose
    from operator import add, sub, mul, truediv
    from itertools import repeat
    import doctest

    g1 = NormalDist(10, 20)
    g2 = NormalDist(-5, 25)

    # Test scaling by a constant
    assert (g1 * 5 / 5).mu == g1.mu
    assert (g1 * 5 / 5).sigma == g1.sigma

    n = 100_000
    G1 = g1.samples(n)
    G2 = g2.samples(n)

    for func in (add, sub):
        print(f'\nTest {func.__name__} with another NormalDist:')
        print(func(g1, g2))
        print(NormalDist.from_samples(map(func, G1, G2)))

    const = 11
    for func in (add, sub, mul, truediv):
        print(f'\nTest {func.__name__} with a constant:')
        print(func(g1, const))
        print(NormalDist.from_samples(map(func, G1, repeat(const))))

    const = 19
    for func in (add, sub, mul):
        print(f'\nTest constant with {func.__name__}:')
        print(func(const, g1))
        print(NormalDist.from_samples(map(func, repeat(const), G1)))

    def assert_close(G1, G2):
        assert isclose(G1.mu, G2.mu, rel_tol=0.01), (G1, G2)
        assert isclose(G1.sigma, G2.sigma, rel_tol=0.01), (G1, G2)

    X = NormalDist(-105, 73)
    Y = NormalDist(31, 47)
    s = 32.75
    n = 100_000

    S = NormalDist.from_samples([x + s for x in X.samples(n)])
    assert_close(X + s, S)

    S = NormalDist.from_samples([x - s for x in X.samples(n)])
    assert_close(X - s, S)

    S = NormalDist.from_samples([x * s for x in X.samples(n)])
    assert_close(X * s, S)

    S = NormalDist.from_samples([x / s for x in X.samples(n)])
    assert_close(X / s, S)

    S = NormalDist.from_samples([x + y for x, y in zip(X.samples(n),
                                                       Y.samples(n))])
    assert_close(X + Y, S)

    S = NormalDist.from_samples([x - y for x, y in zip(X.samples(n),
                                                       Y.samples(n))])
    assert_close(X - Y, S)

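    # Added demo (editorial, not part of the original module): quartile cut
    # points estimated from a large sample should track the analytic
    # inv_cdf() cut points of the underlying distribution.
    Z = NormalDist(100, 15)
    print('\nQuartile cut points (analytic):', quantiles(Z, n=4))
    print('Quartile cut points (sampled): ',
          quantiles(Z.samples(n, seed=8675309), n=4))
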
    print(doctest.testmod())