blob: 9359ed71e51497c5b20e0125e0b4d35c4594af56 [file] [log] [blame]
Larry Hastingsf5e987b2013-10-19 11:50:09 -07001## Module statistics.py
2##
3## Copyright (c) 2013 Steven D'Aprano <steve+python@pearwood.info>.
4##
5## Licensed under the Apache License, Version 2.0 (the "License");
6## you may not use this file except in compliance with the License.
7## You may obtain a copy of the License at
8##
9## http://www.apache.org/licenses/LICENSE-2.0
10##
11## Unless required by applicable law or agreed to in writing, software
12## distributed under the License is distributed on an "AS IS" BASIS,
13## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14## See the License for the specific language governing permissions and
15## limitations under the License.
16
17
"""
Basic statistics module.

This module provides functions for calculating statistics of data, including
averages, variance, and standard deviation.

Calculating averages
--------------------

==================  =============================================
Function            Description
==================  =============================================
mean                Arithmetic mean (average) of data.
median              Median (middle value) of data.
median_low          Low median of data.
median_high         High median of data.
median_grouped      Median, or 50th percentile, of grouped data.
mode                Mode (most common value) of data.
==================  =============================================

Calculate the arithmetic mean ("the average") of data:

>>> mean([-1.0, 2.5, 3.25, 5.75])
2.625


Calculate the standard median of discrete data:

>>> median([2, 3, 4, 5])
3.5


Calculate the median, or 50th percentile, of data grouped into class intervals
centred on the data values provided. For example, if your data points are
rounded to the nearest whole number:

>>> median_grouped([2, 2, 3, 3, 3, 4])  #doctest: +ELLIPSIS
2.8333333333...

Read this result as follows: two data points fall in the class interval
1.5-2.5, three in the class interval 2.5-3.5, and one in the class interval
3.5-4.5. The median of these grouped data points is 2.8333...


Calculating variability or spread
---------------------------------

==================  =============================================
Function            Description
==================  =============================================
pvariance           Population variance of data.
variance            Sample variance of data.
pstdev              Population standard deviation of data.
stdev               Sample standard deviation of data.
==================  =============================================

Calculate the standard deviation of sample data:

>>> stdev([2.5, 3.25, 5.5, 11.25, 11.75])  #doctest: +ELLIPSIS
4.38961843444...

If you have already calculated the mean, you can pass it as the optional
second argument to any of the four "spread" functions to avoid recalculating
it:

>>> data = [1, 2, 2, 4, 4, 4, 5, 6]
>>> mu = mean(data)
>>> pvariance(data, mu)
2.5


Exceptions
----------

A single exception is defined: StatisticsError, a subclass of ValueError.

"""

__all__ = [
    'StatisticsError',
    'pstdev', 'pvariance', 'stdev', 'variance',
    'median', 'median_low', 'median_high', 'median_grouped',
    'mean', 'mode',
]
100
101
102import collections
103import math
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700104
105from fractions import Fraction
106from decimal import Decimal
107
108
109# === Exceptions ===
110
class StatisticsError(ValueError):
    """Error raised when a statistic cannot be computed from the data given.

    Examples are empty data, too few data points for a variance, or data
    without a unique mode. Because it derives from ValueError, an
    ``except ValueError`` clause also catches it.
    """
113
114
115# === Private utilities ===
116
def _sum(data, start=0):
    """_sum(data [, start]) -> value

    Return a high-precision sum of the given numeric data. If optional
    argument ``start`` is given, it is added to the total. If ``data`` is
    empty, ``start`` (defaulting to 0) is returned.

    Every value is converted to an exact (numerator, denominator) pair, and
    numerators are accumulated per denominator. No rounding happens until the
    exact total is converted back to the common type of the inputs, as chosen
    by ``_coerce_types``. Infinities and NANs bypass the exact path and are
    summed directly.

    bool values are summed as plain ints, so the result of summing bools is
    an int count rather than True/False.


    Examples
    --------

    >>> _sum([3, 2.25, 4.5, -0.5, 1.0], 0.75)
    11.0

    Some sources of round-off error will be avoided:

    >>> _sum([1e50, 1, -1e50] * 1000)  # Built-in sum returns zero.
    1000.0

    Fractions and Decimals are also supported:

    >>> from fractions import Fraction as F
    >>> _sum([F(2, 3), F(7, 5), F(1, 4), F(5, 6)])
    Fraction(63, 20)

    >>> from decimal import Decimal as D
    >>> data = [D("0.1375"), D("0.2108"), D("0.3061"), D("0.0419")]
    >>> _sum(data)
    Decimal('0.6963')

    Booleans count as 0 and 1:

    >>> _sum([True, False, True])
    2

    Raises TypeError when the data mixes types that cannot be coerced
    (e.g. Fraction and Decimal) or contains values that cannot be
    converted to a ratio.
    """
    n, d = _exact_ratio(start)
    T = type(start)
    if T is bool:
        T = int
    partials = {d: n}  # map {denominator: sum of numerators}
    # Micro-optimizations: bind the per-item callables to locals.
    coerce_types = _coerce_types
    exact_ratio = _exact_ratio
    partials_get = partials.get
    # Add numerators for each denominator, and track the "current" type.
    for x in data:
        tx = type(x)
        # bool is an int subclass, and _coerce_types lets subclasses win.
        # Left alone, bool data would make T bool, and T(total.numerator)
        # below would collapse any count to True/False. bool would also
        # fail to coerce with Fraction, although int succeeds.
        if tx is bool:
            tx = int
        T = coerce_types(T, tx)
        n, d = exact_ratio(x)
        partials[d] = partials_get(d, 0) + n
    if None in partials:
        # _exact_ratio returns (x, None) for INF/NAN, so partials[None]
        # holds the (non-finite) sum of those values.
        assert issubclass(T, (float, Decimal))
        assert not math.isfinite(partials[None])
        return T(partials[None])
    total = Fraction()
    for d, n in sorted(partials.items()):
        total += Fraction(n, d)
    if issubclass(T, int):
        assert total.denominator == 1
        return T(total.numerator)
    if issubclass(T, Decimal):
        # Let Decimal division (and the current context) do the rounding.
        return T(total.numerator)/total.denominator
    return T(total)
173
174
175def _exact_ratio(x):
176 """Convert Real number x exactly to (numerator, denominator) pair.
177
178 >>> _exact_ratio(0.25)
179 (1, 4)
180
181 x is expected to be an int, Fraction, Decimal or float.
182 """
183 try:
184 try:
185 # int, Fraction
186 return (x.numerator, x.denominator)
187 except AttributeError:
188 # float
189 try:
190 return x.as_integer_ratio()
191 except AttributeError:
192 # Decimal
193 try:
194 return _decimal_to_ratio(x)
195 except AttributeError:
196 msg = "can't convert type '{}' to numerator/denominator"
197 raise TypeError(msg.format(type(x).__name__)) from None
198 except (OverflowError, ValueError):
199 # INF or NAN
200 if __debug__:
201 # Decimal signalling NANs cannot be converted to float :-(
202 if isinstance(x, Decimal):
203 assert not x.is_finite()
204 else:
205 assert not math.isfinite(x)
206 return (x, None)
207
208
209# FIXME This is faster than Fraction.from_decimal, but still too slow.
210def _decimal_to_ratio(d):
211 """Convert Decimal d to exact integer ratio (numerator, denominator).
212
213 >>> from decimal import Decimal
214 >>> _decimal_to_ratio(Decimal("2.6"))
215 (26, 10)
216
217 """
218 sign, digits, exp = d.as_tuple()
219 if exp in ('F', 'n', 'N'): # INF, NAN, sNAN
220 assert not d.is_finite()
221 raise ValueError
222 num = 0
223 for digit in digits:
224 num = num*10 + digit
225 if sign:
226 num = -num
227 den = 10**-exp
228 return (num, den)
229
230
231def _coerce_types(T1, T2):
232 """Coerce types T1 and T2 to a common type.
233
234 >>> _coerce_types(int, float)
235 <class 'float'>
236
237 Coercion is performed according to this table, where "N/A" means
238 that a TypeError exception is raised.
239
240 +----------+-----------+-----------+-----------+----------+
241 | | int | Fraction | Decimal | float |
242 +----------+-----------+-----------+-----------+----------+
243 | int | int | Fraction | Decimal | float |
244 | Fraction | Fraction | Fraction | N/A | float |
245 | Decimal | Decimal | N/A | Decimal | float |
246 | float | float | float | float | float |
247 +----------+-----------+-----------+-----------+----------+
248
249 Subclasses trump their parent class; two subclasses of the same
250 base class will be coerced to the second of the two.
251
252 """
253 # Get the common/fast cases out of the way first.
254 if T1 is T2: return T1
255 if T1 is int: return T2
256 if T2 is int: return T1
257 # Subclasses trump their parent class.
258 if issubclass(T2, T1): return T2
259 if issubclass(T1, T2): return T1
260 # Floats trump everything else.
261 if issubclass(T2, float): return T2
262 if issubclass(T1, float): return T1
263 # Subclasses of the same base class give priority to the second.
264 if T1.__base__ is T2.__base__: return T2
265 # Otherwise, just give up.
266 raise TypeError('cannot coerce types %r and %r' % (T1, T2))
267
268
269def _counts(data):
270 # Generate a table of sorted (value, frequency) pairs.
Nick Coghlanbfd68bf2014-02-08 19:44:16 +1000271 table = collections.Counter(iter(data)).most_common()
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700272 if not table:
273 return table
274 # Extract the values with the highest frequency.
275 maxfreq = table[0][1]
276 for i in range(1, len(table)):
277 if table[i][1] != maxfreq:
278 table = table[:i]
279 break
280 return table
281
282
283# === Measures of central tendency (averages) ===
284
def mean(data):
    """Return the sample arithmetic mean of data.

    >>> mean([1, 2, 3, 4, 4])
    2.8

    >>> from fractions import Fraction as F
    >>> mean([F(3, 7), F(1, 21), F(5, 3), F(1, 3)])
    Fraction(13, 21)

    >>> from decimal import Decimal as D
    >>> mean([D("0.5"), D("0.75"), D("0.625"), D("0.375")])
    Decimal('0.5625')

    ``data`` may be any iterable; a one-shot iterator is copied into a list
    first. StatisticsError is raised when ``data`` is empty.
    """
    # An iterator would be exhausted by len()/summing, so materialise it.
    points = list(data) if iter(data) is data else data
    count = len(points)
    if not count:
        raise StatisticsError('mean requires at least one data point')
    total = _sum(points)
    return total/count
307
308
# FIXME: investigate ways to calculate medians without sorting? Quickselect?
def median(data):
    """Return the median (middle value) of numeric data.

    With an odd number of data points the middle one is returned as is.
    With an even number, the median is interpolated as the average of the
    two middle values:

    >>> median([1, 3, 5])
    3
    >>> median([1, 3, 5, 7])
    4.0

    Raises StatisticsError when ``data`` is empty.
    """
    ordered = sorted(data)
    size = len(ordered)
    if size == 0:
        raise StatisticsError("no median for empty data")
    half, odd = divmod(size, 2)
    if odd:
        return ordered[half]
    # Even count: average the two central values.
    return (ordered[half - 1] + ordered[half])/2
332
333
def median_low(data):
    """Return the low median of numeric data.

    For an odd number of data points this is the middle value. For an even
    number it is the smaller of the two middle values, so the result is
    always an actual member of ``data``.

    >>> median_low([1, 3, 5])
    3
    >>> median_low([1, 3, 5, 7])
    3

    Raises StatisticsError when ``data`` is empty.
    """
    ordered = sorted(data)
    if not ordered:
        raise StatisticsError("no median for empty data")
    # (size - 1) // 2 is the middle index for odd sizes and the lower of the
    # two middle indices for even sizes.
    return ordered[(len(ordered) - 1)//2]
354
355
def median_high(data):
    """Return the high median of data.

    For an odd number of data points this is the middle value. For an even
    number it is the larger of the two middle values, so the result is
    always an actual member of ``data``.

    >>> median_high([1, 3, 5])
    3
    >>> median_high([1, 3, 5, 7])
    5

    Raises StatisticsError when ``data`` is empty.
    """
    ordered = sorted(data)
    if not ordered:
        raise StatisticsError("no median for empty data")
    upper_middle = len(ordered)//2
    return ordered[upper_middle]
373
374
def median_grouped(data, interval=1):
    """Return the 50th percentile (median) of grouped continuous data.

    >>> median_grouped([1, 2, 2, 3, 4, 4, 4, 4, 4, 5])
    3.7
    >>> median_grouped([52, 52, 53, 54])
    52.5

    This calculates the median as the 50th percentile, and should be
    used when your data is continuous and grouped. In the above example,
    the values 1, 2, 3, etc. actually represent the midpoint of classes
    0.5-1.5, 1.5-2.5, 2.5-3.5, etc. The middle value falls somewhere in
    class 3.5-4.5, and interpolation is used to estimate it.

    Optional argument ``interval`` represents the class interval, and
    defaults to 1. Changing the class interval naturally will change the
    interpolated 50th percentile value:

    >>> median_grouped([1, 3, 3, 5, 7], interval=1)
    3.25
    >>> median_grouped([1, 3, 3, 5, 7], interval=2)
    3.5

    This function does not check whether the data points are at least
    ``interval`` apart.

    Raises StatisticsError for empty data. With two or more data points,
    raises TypeError if the median class midpoint or ``interval`` is a
    str or bytes. A single data point is returned unchanged.
    """
    from bisect import bisect_left, bisect_right

    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    elif n == 1:
        return data[0]
    # Find the value at the midpoint. Remember this corresponds to the
    # centre of the class interval.
    x = data[n//2]
    for obj in (x, interval):
        if isinstance(obj, (str, bytes)):
            raise TypeError('expected number but got %r' % obj)
    try:
        L = x - interval/2  # The lower limit of the median interval.
    except TypeError:
        # Mixed type. For now we just coerce to float.
        L = float(x) - float(interval)/2
    # data is sorted, so every copy of x sits in one contiguous run; binary
    # search finds its bounds in O(log n) instead of two O(n) list scans.
    cf = bisect_left(data, x)  # Number of values below the median interval.
    # Number of data points in the median interval.
    f = bisect_right(data, x, cf) - cf
    return L + interval*(n/2 - cf)/f
422
423
def mode(data):
    """Return the single most common data point from discrete or nominal data.

    ``mode`` treats the data as discrete and returns one value, which is the
    usual schoolbook definition of the mode:

    >>> mode([1, 1, 2, 3, 3, 3, 3, 4])
    3

    Non-numeric (nominal) data works too:

    >>> mode(["red", "blue", "blue", "red", "green", "red", "red"])
    'red'

    StatisticsError is raised when ``data`` is empty, or when two or more
    values tie for most common.
    """
    winners = _counts(data)
    if not winners:
        raise StatisticsError('no mode for empty data')
    if len(winners) > 1:
        raise StatisticsError(
            'no unique mode; found %d equally common values' % len(winners)
        )
    value, _frequency = winners[0]
    return value
451
452
453# === Measures of spread ===
454
455# See http://mathworld.wolfram.com/Variance.html
456# http://mathworld.wolfram.com/SampleVariance.html
457# http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
458#
459# Under no circumstances use the so-called "computational formula for
460# variance", as that is only suitable for hand calculations with a small
461# amount of low-precision data. It has terrible numeric properties.
462#
463# See a comparison of three computational methods here:
464# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
465
def _ss(data, c=None):
    """Return the sum of squared deviations of the sequence ``data``.

    When ``c`` is None, the mean of ``data`` is computed first and used as
    the centre. Otherwise deviations are measured from ``c`` as given. A
    poorly chosen ``c`` can lead to garbage results, so use it with care.

    ``data`` is traversed several times and passed to ``len()``, so it must
    be a sized sequence rather than an iterator.
    """
    centre = mean(data) if c is None else c
    squared = _sum((x - centre)**2 for x in data)
    # sum(x - centre) is mathematically zero when centre is the mean, but
    # rounding can make it nonzero. Subtracting its square / n compensates.
    correction = _sum((x - centre) for x in data)**2/len(data)
    ss = squared - correction
    assert not ss < 0, 'negative sum of square deviations: %f' % ss
    return ss
482
483
def variance(data, xbar=None):
    """Return the sample variance of data.

    ``data`` should be an iterable of Real-valued numbers with at least two
    values. The optional argument ``xbar``, if given, should be the mean of
    the data; if it is missing or None, the mean is calculated automatically.

    Use this function when your data is a sample from a population. To
    calculate the variance of an entire population, see ``pvariance``.

    Examples:

    >>> data = [2.75, 1.75, 1.25, 0.25, 0.5, 1.25, 3.5]
    >>> variance(data)
    1.3720238095238095

    If you have already calculated the mean of your data, pass it as the
    optional second argument ``xbar`` to avoid recalculating it:

    >>> m = mean(data)
    >>> variance(data, m)
    1.3720238095238095

    ``xbar`` is not checked against the actual mean of ``data``; arbitrary
    values may lead to invalid or impossible results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> variance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('31.01875')

    >>> from fractions import Fraction as F
    >>> variance([F(1, 6), F(1, 2), F(5, 3)])
    Fraction(67, 108)

    Raises StatisticsError if fewer than two data points are given.
    """
    # _ss needs a re-iterable sized sequence; copy one-shot iterators.
    points = list(data) if iter(data) is data else data
    count = len(points)
    if count < 2:
        raise StatisticsError('variance requires at least two data points')
    # Bessel's correction: divide by n - 1 for a sample.
    return _ss(points, xbar)/(count - 1)
529
530
def pvariance(data, mu=None):
    """Return the population variance of ``data``.

    ``data`` should be an iterable of Real-valued numbers with at least one
    value. The optional argument ``mu``, if given, should be the mean of the
    data; if it is missing or None, the mean is calculated automatically.

    Use this function to calculate the variance of an entire population.
    To estimate the variance from a sample, ``variance`` is usually the
    better choice.

    Examples:

    >>> data = [0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
    >>> pvariance(data)
    1.25

    If you have already calculated the mean of the data, pass it as the
    optional second argument to avoid recalculating it:

    >>> mu = mean(data)
    >>> pvariance(data, mu)
    1.25

    ``mu`` is not checked against the actual mean of ``data``; arbitrary
    values may lead to invalid or impossible results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> pvariance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('24.815')

    >>> from fractions import Fraction as F
    >>> pvariance([F(1, 4), F(5, 4), F(1, 2)])
    Fraction(13, 72)

    Raises StatisticsError if ``data`` is empty.
    """
    # _ss needs a re-iterable sized sequence; copy one-shot iterators.
    points = list(data) if iter(data) is data else data
    count = len(points)
    if count < 1:
        raise StatisticsError('pvariance requires at least one data point')
    # Whole population: divide by n, not n - 1.
    return _ss(points, mu)/count
577
578
def stdev(data, xbar=None):
    """Return the sample standard deviation, i.e. the square root of ``variance``.

    See ``variance`` for the arguments and other details.

    >>> stdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    1.0810874155219827

    A variance that has its own ``sqrt()`` method (such as a Decimal) uses
    that method; any other value goes through ``math.sqrt``.
    """
    var = variance(data, xbar)
    try:
        root = var.sqrt
    except AttributeError:
        return math.sqrt(var)
    return root()
593
594
def pstdev(data, mu=None):
    """Return the population standard deviation, i.e. the square root of ``pvariance``.

    See ``pvariance`` for the arguments and other details.

    >>> pstdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    0.986893273527251

    A variance that has its own ``sqrt()`` method (such as a Decimal) uses
    that method; any other value goes through ``math.sqrt``.
    """
    var = pvariance(data, mu)
    try:
        root = var.sqrt
    except AttributeError:
        return math.sqrt(var)
    return root()