blob: a67a6d11cd6c058c6e9e47b28af0447829f05cbd [file] [log] [blame]
Larry Hastingsf5e987b2013-10-19 11:50:09 -07001## Module statistics.py
2##
3## Copyright (c) 2013 Steven D'Aprano <steve+python@pearwood.info>.
4##
5## Licensed under the Apache License, Version 2.0 (the "License");
6## you may not use this file except in compliance with the License.
7## You may obtain a copy of the License at
8##
9## http://www.apache.org/licenses/LICENSE-2.0
10##
11## Unless required by applicable law or agreed to in writing, software
12## distributed under the License is distributed on an "AS IS" BASIS,
13## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14## See the License for the specific language governing permissions and
15## limitations under the License.
16
17
18"""
19Basic statistics module.
20
21This module provides functions for calculating statistics of data, including
22averages, variance, and standard deviation.
23
24Calculating averages
25--------------------
26
27================== =============================================
28Function Description
29================== =============================================
30mean Arithmetic mean (average) of data.
31median Median (middle value) of data.
32median_low Low median of data.
33median_high High median of data.
34median_grouped Median, or 50th percentile, of grouped data.
35mode Mode (most common value) of data.
36================== =============================================
37
38Calculate the arithmetic mean ("the average") of data:
39
40>>> mean([-1.0, 2.5, 3.25, 5.75])
412.625
42
43
44Calculate the standard median of discrete data:
45
46>>> median([2, 3, 4, 5])
473.5
48
49
50Calculate the median, or 50th percentile, of data grouped into class intervals
51centred on the data values provided. E.g. if your data points are rounded to
52the nearest whole number:
53
54>>> median_grouped([2, 2, 3, 3, 3, 4]) #doctest: +ELLIPSIS
552.8333333333...
56
57This should be interpreted in this way: you have two data points in the class
58interval 1.5-2.5, three data points in the class interval 2.5-3.5, and one in
59the class interval 3.5-4.5. The median of these data points is 2.8333...
60
61
62Calculating variability or spread
63---------------------------------
64
65================== =============================================
66Function Description
67================== =============================================
68pvariance Population variance of data.
69variance Sample variance of data.
70pstdev Population standard deviation of data.
71stdev Sample standard deviation of data.
72================== =============================================
73
74Calculate the standard deviation of sample data:
75
76>>> stdev([2.5, 3.25, 5.5, 11.25, 11.75]) #doctest: +ELLIPSIS
774.38961843444...
78
79If you have previously calculated the mean, you can pass it as the optional
80second argument to the four "spread" functions to avoid recalculating it:
81
82>>> data = [1, 2, 2, 4, 4, 4, 5, 6]
83>>> mu = mean(data)
84>>> pvariance(data, mu)
852.5
86
87
88Exceptions
89----------
90
91A single exception is defined: StatisticsError is a subclass of ValueError.
92
93"""
94
# Public API of the module, grouped by purpose.
__all__ = [
    # Exceptions
    'StatisticsError',
    # Measures of spread
    'pstdev',
    'pvariance',
    'stdev',
    'variance',
    # Medians
    'median',
    'median_low',
    'median_high',
    'median_grouped',
    # Other averages
    'mean',
    'mode',
]
100
101
102import collections
103import math
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700104
105from fractions import Fraction
106from decimal import Decimal
107
108
109# === Exceptions ===
110
class StatisticsError(ValueError):
    """Error raised by the statistics functions in this module.

    It subclasses ValueError, so handlers written for ValueError also
    catch it.
    """
113
114
115# === Private utilities ===
116
def _sum(data, start=0):
    """_sum(data [, start]) -> value

    Return a high-precision sum of the given numeric data. If optional
    argument ``start`` is given, it is added to the total. If ``data`` is
    empty, ``start`` (defaulting to 0) is returned.

    The values are converted to exact integer ratios and the numerators are
    accumulated per denominator, so no rounding happens until the final
    conversion back to the result type. The result type is obtained by
    coercing the type of ``start`` with the type of every item in turn.

    If any value is an INF or NAN, the non-finite total is returned in the
    result type instead.


    Examples
    --------

    >>> _sum([3, 2.25, 4.5, -0.5, 1.0], 0.75)
    11.0

    Some sources of round-off error will be avoided:

    >>> _sum([1e50, 1, -1e50] * 1000)  # Built-in sum returns zero.
    1000.0

    Fractions and Decimals are also supported:

    >>> from fractions import Fraction as F
    >>> _sum([F(2, 3), F(7, 5), F(1, 4), F(5, 6)])
    Fraction(63, 20)

    >>> from decimal import Decimal as D
    >>> data = [D("0.1375"), D("0.2108"), D("0.3061"), D("0.0419")]
    >>> _sum(data)
    Decimal('0.6963')

    """
    n, d = _exact_ratio(start)
    T = type(start)
    partials = {d: n}  # map {denominator: sum of numerators}
    # Micro-optimizations: bind the helpers to locals for the loop below.
    coerce_types = _coerce_types
    exact_ratio = _exact_ratio
    partials_get = partials.get
    # Add numerators for each denominator, and track the "current" type.
    for x in data:
        T = coerce_types(T, type(x))
        n, d = exact_ratio(x)
        partials[d] = partials_get(d, 0) + n
    if None in partials:
        # _exact_ratio reports INF/NAN with a denominator of None, so the
        # None bucket holds the sum of the non-finite values.
        assert issubclass(T, (float, Decimal))
        assert not math.isfinite(partials[None])
        return T(partials[None])
    total = Fraction()
    for d, n in sorted(partials.items()):
        total += Fraction(n, d)
    if issubclass(T, int):
        assert total.denominator == 1
        return T(total.numerator)
    if issubclass(T, Decimal):
        # Divide as Decimals rather than converting the Fraction directly.
        return T(total.numerator)/total.denominator
    return T(total)
173
174
def _exact_ratio(x):
    """Convert Real number x exactly to (numerator, denominator) pair.

    >>> _exact_ratio(0.25)
    (1, 4)

    x is expected to be an int, Fraction, Decimal or float.

    Three conversions are attempted in order, each one used only if the
    previous one failed with AttributeError. If all of them fail that way,
    TypeError is raised. When a conversion raises OverflowError or
    ValueError (INF or NAN), the pair ``(x, None)`` is returned instead.
    """
    try:
        # int, Fraction
        try:
            return (x.numerator, x.denominator)
        except AttributeError:
            pass
        # float
        try:
            return x.as_integer_ratio()
        except AttributeError:
            pass
        # Decimal
        try:
            return _decimal_to_ratio(x)
        except AttributeError:
            pass
        msg = "can't convert type '{}' to numerator/denominator"
        raise TypeError(msg.format(type(x).__name__)) from None
    except (OverflowError, ValueError):
        # INF or NAN
        if __debug__:
            # Decimal signalling NANs cannot be converted to float :-(
            if isinstance(x, Decimal):
                finite = x.is_finite()
            else:
                finite = math.isfinite(x)
            assert not finite
        return (x, None)
207
208
209# FIXME This is faster than Fraction.from_decimal, but still too slow.
def _decimal_to_ratio(d):
    """Convert Decimal d to exact integer ratio (numerator, denominator).

    >>> from decimal import Decimal
    >>> _decimal_to_ratio(Decimal("2.6"))
    (26, 10)

    Both members of the returned pair are always ints, including when the
    Decimal has a positive exponent:

    >>> _decimal_to_ratio(Decimal("1E+2"))
    (100, 1)

    The ratio is not reduced to lowest terms. ValueError is raised when
    ``d`` is an INF, NAN or signalling NAN.
    """
    sign, digits, exp = d.as_tuple()
    if exp in ('F', 'n', 'N'):  # INF, NAN, sNAN
        assert not d.is_finite()
        raise ValueError
    num = 0
    for digit in digits:
        num = num*10 + digit
    if exp < 0:
        den = 10**-exp
    else:
        # 10**-exp would be a float here, so scale the numerator instead
        # to keep the ratio exact.
        num *= 10**exp
        den = 1
    if sign:
        num = -num
    return (num, den)
229
230
def _coerce_types(T1, T2):
    """Coerce types T1 and T2 to a common type.

    >>> _coerce_types(int, float)
    <class 'float'>

    Coercion is performed according to this table, where "N/A" means
    that a TypeError exception is raised.

    +----------+-----------+-----------+-----------+----------+
    |          | int       | Fraction  | Decimal   | float    |
    +----------+-----------+-----------+-----------+----------+
    | int      | int       | Fraction  | Decimal   | float    |
    | Fraction | Fraction  | Fraction  | N/A       | float    |
    | Decimal  | Decimal   | N/A       | Decimal   | float    |
    | float    | float     | float     | float     | float    |
    +----------+-----------+-----------+-----------+----------+

    Subclasses trump their parent class; two subclasses of the same
    base class will be coerced to the second of the two.

    """
    # Fast paths: identical types, or one side is exactly int.
    if T1 is T2 or T2 is int:
        return T1
    if T1 is int:
        return T2
    # A subclass wins over its parent; the second argument is tried first.
    for candidate, other in ((T2, T1), (T1, T2)):
        if issubclass(candidate, other):
            return candidate
    # Floats win over everything else; again the second argument first.
    for candidate in (T2, T1):
        if issubclass(candidate, float):
            return candidate
    # Siblings sharing a direct base class: prefer the second.
    if T1.__base__ is T2.__base__:
        return T2
    # No rule applies.
    raise TypeError('cannot coerce types %r and %r' % (T1, T2))
267
268
def _counts(data):
    """Return the (value, frequency) pairs that share the top frequency.

    The pairs come from ``collections.Counter(data).most_common()``, cut
    down to the entries whose count equals that of the first entry. An
    empty list is returned for empty data, and TypeError is raised when
    ``data`` is None.
    """
    if data is None:
        raise TypeError('None is not iterable')
    ranked = collections.Counter(data).most_common()
    if not ranked:
        return ranked
    # most_common() lists entries from most to least frequent, so the
    # entries tied for first place are exactly those matching this count.
    top = ranked[0][1]
    return [pair for pair in ranked if pair[1] == top]
283
284
285# === Measures of central tendency (averages) ===
286
def mean(data):
    """Return the sample arithmetic mean of data.

    >>> mean([1, 2, 3, 4, 4])
    2.8

    >>> from fractions import Fraction as F
    >>> mean([F(3, 7), F(1, 21), F(5, 3), F(1, 3)])
    Fraction(13, 21)

    >>> from decimal import Decimal as D
    >>> mean([D("0.5"), D("0.75"), D("0.625"), D("0.375")])
    Decimal('0.5625')

    If ``data`` is empty, StatisticsError will be raised.
    """
    if iter(data) is data:
        # An iterator can only be consumed once; copy it into a list so
        # it can be both counted and summed.
        data = list(data)
    size = len(data)
    if size < 1:
        raise StatisticsError('mean requires at least one data point')
    total = _sum(data)
    return total/size
309
310
311# FIXME: investigate ways to calculate medians without sorting? Quickselect?
def median(data):
    """Return the median (middle value) of numeric data.

    With an odd number of data points the middle one is returned. With an
    even number, the result is the average of the two middle values:

    >>> median([1, 3, 5])
    3
    >>> median([1, 3, 5, 7])
    4.0

    StatisticsError is raised if ``data`` is empty.
    """
    ordered = sorted(data)
    size = len(ordered)
    if size == 0:
        raise StatisticsError("no median for empty data")
    mid, is_odd = divmod(size, 2)
    if is_odd:
        return ordered[mid]
    return (ordered[mid - 1] + ordered[mid])/2
334
335
def median_low(data):
    """Return the low median of numeric data.

    For an odd number of data points this is the middle value; for an even
    number it is the smaller of the two middle values.

    >>> median_low([1, 3, 5])
    3
    >>> median_low([1, 3, 5, 7])
    3

    StatisticsError is raised if ``data`` is empty.
    """
    ordered = sorted(data)
    size = len(ordered)
    if size == 0:
        raise StatisticsError("no median for empty data")
    # (size - 1)//2 is the middle index for odd sizes and the lower of the
    # two middle indexes for even sizes.
    return ordered[(size - 1)//2]
356
357
def median_high(data):
    """Return the high median of data.

    For an odd number of data points this is the middle value; for an even
    number it is the larger of the two middle values.

    >>> median_high([1, 3, 5])
    3
    >>> median_high([1, 3, 5, 7])
    5

    StatisticsError is raised if ``data`` is empty.
    """
    ordered = list(data)
    ordered.sort()
    if not ordered:
        raise StatisticsError("no median for empty data")
    # len//2 is the middle index for odd sizes and the upper of the two
    # middle indexes for even sizes.
    return ordered[len(ordered)//2]
375
376
def median_grouped(data, interval=1):
    """Return the 50th percentile (median) of grouped continuous data.

    >>> median_grouped([1, 2, 2, 3, 4, 4, 4, 4, 4, 5])
    3.7
    >>> median_grouped([52, 52, 53, 54])
    52.5

    This calculates the median as the 50th percentile, and should be
    used when your data is continuous and grouped. In the above example,
    the values 1, 2, 3, etc. actually represent the midpoint of classes
    0.5-1.5, 1.5-2.5, 2.5-3.5, etc. The middle value falls somewhere in
    class 3.5-4.5, and interpolation is used to estimate it.

    Optional argument ``interval`` represents the class interval, and
    defaults to 1. Changing the class interval naturally will change the
    interpolated 50th percentile value:

    >>> median_grouped([1, 3, 3, 5, 7], interval=1)
    3.25
    >>> median_grouped([1, 3, 3, 5, 7], interval=2)
    3.5

    This function does not check whether the data points are at least
    ``interval`` apart.

    StatisticsError is raised if ``data`` is empty. With exactly one data
    point, that point is returned unchanged. TypeError is raised if the
    value at the midpoint, or ``interval``, is a str or bytes object.
    """
    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    elif n == 1:
        return data[0]
    # Find the value at the midpoint. Remember this corresponds to the
    # centre of the class interval.
    x = data[n//2]
    for obj in (x, interval):
        if isinstance(obj, (str, bytes)):
            raise TypeError('expected number but got %r' % obj)
    try:
        L = x - interval/2  # The lower limit of the median interval.
    except TypeError:
        # Mixed type. For now we just coerce to float.
        L = float(x) - float(interval)/2
    # data is sorted, so the index of the first occurrence of x equals the
    # number of values below the median interval.
    cf = data.index(x)
    # FIXME The following line could be more efficient for big lists.
    f = data.count(x)  # Number of data points in the median interval.
    # Interpolate within the median interval.
    return L + interval*(n/2 - cf)/f
424
425
def mode(data):
    """Return the most common data point from discrete or nominal data.

    ``mode`` assumes discrete data, and returns a single value. This is the
    standard treatment of the mode as commonly taught in schools:

    >>> mode([1, 1, 2, 3, 3, 3, 3, 4])
    3

    This also works with nominal (non-numeric) data:

    >>> mode(["red", "blue", "blue", "red", "green", "red", "red"])
    'red'

    If there is not exactly one most common value, ``mode`` will raise
    StatisticsError.
    """
    # (value, frequency) pairs for the values tied at the highest frequency.
    leaders = _counts(data)
    if not leaders:
        raise StatisticsError('no mode for empty data')
    if len(leaders) != 1:
        raise StatisticsError(
            'no unique mode; found %d equally common values' % len(leaders)
        )
    return leaders[0][0]
453
454
455# === Measures of spread ===
456
457# See http://mathworld.wolfram.com/Variance.html
458# http://mathworld.wolfram.com/SampleVariance.html
459# http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
460#
461# Under no circumstances use the so-called "computational formula for
462# variance", as that is only suitable for hand calculations with a small
463# amount of low-precision data. It has terrible numeric properties.
464#
465# See a comparison of three computational methods here:
466# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
467
def _ss(data, c=None):
    """Return sum of square deviations of sequence data.

    When ``c`` is None, the mean of ``data`` is computed first and the
    deviations are measured from it in a second pass. Otherwise the
    deviations are measured from ``c`` as given. Use the second case with
    care, as it can lead to garbage results.

    ``data`` is iterated more than once and passed to ``len``.
    """
    if c is None:
        c = mean(data)
    squares = _sum((x-c)**2 for x in data)
    # The sum of the plain deviations should mathematically equal zero,
    # but due to rounding error may not; subtract the correction term.
    residual = _sum((x-c) for x in data)
    squares -= residual**2/len(data)
    assert not squares < 0, 'negative sum of square deviations: %f' % squares
    return squares
484
485
def variance(data, xbar=None):
    """Return the sample variance of data.

    data should be an iterable of Real-valued numbers, with at least two
    values. The optional argument xbar, if given, should be the mean of
    the data. If it is missing or None, the mean is automatically calculated.

    Use this function when your data is a sample from a population. To
    calculate the variance from the entire population, see ``pvariance``.

    Examples:

    >>> data = [2.75, 1.75, 1.25, 0.25, 0.5, 1.25, 3.5]
    >>> variance(data)
    1.3720238095238095

    If you have already calculated the mean of your data, you can pass it as
    the optional second argument ``xbar`` to avoid recalculating it:

    >>> m = mean(data)
    >>> variance(data, m)
    1.3720238095238095

    This function does not check that ``xbar`` is actually the mean of
    ``data``. Giving arbitrary values for ``xbar`` may lead to invalid or
    impossible results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> variance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('31.01875')

    >>> from fractions import Fraction as F
    >>> variance([F(1, 6), F(1, 2), F(5, 3)])
    Fraction(67, 108)

    StatisticsError is raised if there are fewer than two data points.
    """
    if iter(data) is data:
        # An iterator can only be consumed once; copy it into a list so
        # it can be counted and traversed again.
        data = list(data)
    size = len(data)
    if size < 2:
        raise StatisticsError('variance requires at least two data points')
    squares = _ss(data, xbar)
    # Divide by size - 1 for the sample variance.
    return squares/(size - 1)
531
532
def pvariance(data, mu=None):
    """Return the population variance of ``data``.

    data should be an iterable of Real-valued numbers, with at least one
    value. The optional argument mu, if given, should be the mean of
    the data. If it is missing or None, the mean is automatically calculated.

    Use this function to calculate the variance from the entire population.
    To estimate the variance from a sample, the ``variance`` function is
    usually a better choice.

    Examples:

    >>> data = [0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
    >>> pvariance(data)
    1.25

    If you have already calculated the mean of the data, you can pass it as
    the optional second argument to avoid recalculating it:

    >>> mu = mean(data)
    >>> pvariance(data, mu)
    1.25

    This function does not check that ``mu`` is actually the mean of ``data``.
    Giving arbitrary values for ``mu`` may lead to invalid or impossible
    results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> pvariance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('24.815')

    >>> from fractions import Fraction as F
    >>> pvariance([F(1, 4), F(5, 4), F(1, 2)])
    Fraction(13, 72)

    StatisticsError is raised if ``data`` is empty.
    """
    if iter(data) is data:
        # An iterator can only be consumed once; copy it into a list so
        # it can be counted and traversed again.
        data = list(data)
    size = len(data)
    if size < 1:
        raise StatisticsError('pvariance requires at least one data point')
    squares = _ss(data, mu)
    # Divide by the full count for the population variance.
    return squares/size
579
580
def stdev(data, xbar=None):
    """Return the square root of the sample variance.

    See ``variance`` for arguments and other details.

    >>> stdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    1.0810874155219827

    """
    spread = variance(data, xbar)
    try:
        # Prefer the value's own sqrt() method when it has one.
        root = spread.sqrt()
    except AttributeError:
        root = math.sqrt(spread)
    return root
595
596
def pstdev(data, mu=None):
    """Return the square root of the population variance.

    See ``pvariance`` for arguments and other details.

    >>> pstdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    0.986893273527251

    """
    spread = pvariance(data, mu)
    try:
        # Prefer the value's own sqrt() method when it has one.
        root = spread.sqrt()
    except AttributeError:
        root = math.sqrt(spread)
    return root