blob: 3972ed2e5f7ec0278bd4c32863d07a1abf25f0a4 [file] [log] [blame]
Larry Hastingsf5e987b2013-10-19 11:50:09 -07001## Module statistics.py
2##
3## Copyright (c) 2013 Steven D'Aprano <steve+python@pearwood.info>.
4##
5## Licensed under the Apache License, Version 2.0 (the "License");
6## you may not use this file except in compliance with the License.
7## You may obtain a copy of the License at
8##
9## http://www.apache.org/licenses/LICENSE-2.0
10##
11## Unless required by applicable law or agreed to in writing, software
12## distributed under the License is distributed on an "AS IS" BASIS,
13## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14## See the License for the specific language governing permissions and
15## limitations under the License.
16
17
18"""
19Basic statistics module.
20
21This module provides functions for calculating statistics of data, including
22averages, variance, and standard deviation.
23
24Calculating averages
25--------------------
26
==================  =============================================
Function            Description
==================  =============================================
mean                Arithmetic mean (average) of data.
median              Median (middle value) of data.
median_low          Low median of data.
median_high         High median of data.
median_grouped      Median, or 50th percentile, of grouped data.
mode                Mode (most common value) of data.
==================  =============================================
37
38Calculate the arithmetic mean ("the average") of data:
39
40>>> mean([-1.0, 2.5, 3.25, 5.75])
412.625
42
43
44Calculate the standard median of discrete data:
45
46>>> median([2, 3, 4, 5])
473.5
48
49
50Calculate the median, or 50th percentile, of data grouped into class intervals
51centred on the data values provided. E.g. if your data points are rounded to
52the nearest whole number:
53
54>>> median_grouped([2, 2, 3, 3, 3, 4]) #doctest: +ELLIPSIS
552.8333333333...
56
57This should be interpreted in this way: you have two data points in the class
58interval 1.5-2.5, three data points in the class interval 2.5-3.5, and one in
59the class interval 3.5-4.5. The median of these data points is 2.8333...
60
61
62Calculating variability or spread
63---------------------------------
64
==================  =============================================
Function            Description
==================  =============================================
pvariance           Population variance of data.
variance            Sample variance of data.
pstdev              Population standard deviation of data.
stdev               Sample standard deviation of data.
==================  =============================================
73
74Calculate the standard deviation of sample data:
75
76>>> stdev([2.5, 3.25, 5.5, 11.25, 11.75]) #doctest: +ELLIPSIS
774.38961843444...
78
79If you have previously calculated the mean, you can pass it as the optional
80second argument to the four "spread" functions to avoid recalculating it:
81
82>>> data = [1, 2, 2, 4, 4, 4, 5, 6]
83>>> mu = mean(data)
84>>> pvariance(data, mu)
852.5
86
87
88Exceptions
89----------
90
91A single exception is defined: StatisticsError is a subclass of ValueError.
92
93"""
94
# Public names exported by ``from statistics import *``.
__all__ = [
    # Exceptions
    'StatisticsError',
    # Measures of spread
    'pstdev',
    'pvariance',
    'stdev',
    'variance',
    # Medians
    'median',
    'median_low',
    'median_high',
    'median_grouped',
    # Other averages
    'mean',
    'mode',
]
100
101
102import collections
103import math
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700104
105from fractions import Fraction
106from decimal import Decimal
107
108
109# === Exceptions ===
110
class StatisticsError(ValueError):
    """Exception raised by this module for statistics-related errors.

    It derives from ValueError, so handlers written for ValueError also
    catch it.
    """
113
114
115# === Private utilities ===
116
def _sum(data, start=0):
    """_sum(data [, start]) -> value

    Add up the numeric values in ``data`` with high precision and return
    the total.  The optional ``start`` value (0 when omitted) is included
    in the total, and is what gets returned when ``data`` is empty.

    Every value is converted to an exact integer ratio and the ratios are
    accumulated exactly, so intermediate round-off is avoided.  The result
    is converted back to the type of the data at the very end.


    Examples
    --------

    >>> _sum([3, 2.25, 4.5, -0.5, 1.0], 0.75)
    11.0

    Round-off that defeats the built-in sum is avoided:

    >>> _sum([1e50, 1, -1e50] * 1000)  # Built-in sum returns zero.
    1000.0

    Fractions and Decimals work too:

    >>> from fractions import Fraction as F
    >>> _sum([F(2, 3), F(7, 5), F(1, 4), F(5, 6)])
    Fraction(63, 20)

    >>> from decimal import Decimal as D
    >>> data = [D("0.1375"), D("0.2108"), D("0.3061"), D("0.0419")]
    >>> _sum(data)
    Decimal('0.6963')

    Mixing types raises TypeError, except that int may be combined with
    one other type.
    """
    # The set of types seen so far.  It holds int plus at most one other
    # type; _check_type grows it on the first non-int type and raises on
    # any further new type.  E.g. [int, int, float, int] is accepted but
    # [int, int, float, Fraction] is not.
    seen_types = {int, type(start)}
    numerator, denominator = _exact_ratio(start)
    # Running numerator total for each distinct denominator.
    partials = {denominator: numerator}
    for value in data:
        _check_type(type(value), seen_types)
        numerator, denominator = _exact_ratio(value)
        partials[denominator] = partials.get(denominator, 0) + numerator
    # Work out the result type: int when nothing else was seen, otherwise
    # the single non-int type.
    assert len(seen_types) in (1, 2)
    if len(seen_types) == 1:
        assert seen_types == {int}
        result_type = int
    else:
        (result_type,) = seen_types - {int}
    # _exact_ratio reports INF and NAN with a denominator of None; such
    # values decide the result on their own.
    if None in partials:
        non_finite = partials[None]
        assert issubclass(result_type, (float, Decimal))
        assert not math.isfinite(non_finite)
        return result_type(non_finite)
    exact_total = sum(
        (Fraction(n, d) for d, n in sorted(partials.items())),
        Fraction(),
    )
    if issubclass(result_type, int):
        assert exact_total.denominator == 1
        return result_type(exact_total.numerator)
    if issubclass(result_type, Decimal):
        # Decimal cannot be built from a Fraction, so divide explicitly.
        return result_type(exact_total.numerator)/exact_total.denominator
    return result_type(exact_total)
185
186
Nick Coghlan73afe2a2014-02-08 19:58:04 +1000187def _check_type(T, allowed):
188 if T not in allowed:
189 if len(allowed) == 1:
190 allowed.add(T)
191 else:
192 types = ', '.join([t.__name__ for t in allowed] + [T.__name__])
193 raise TypeError("unsupported mixed types: %s" % types)
194
195
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700196def _exact_ratio(x):
197 """Convert Real number x exactly to (numerator, denominator) pair.
198
199 >>> _exact_ratio(0.25)
200 (1, 4)
201
202 x is expected to be an int, Fraction, Decimal or float.
203 """
204 try:
205 try:
206 # int, Fraction
207 return (x.numerator, x.denominator)
208 except AttributeError:
209 # float
210 try:
211 return x.as_integer_ratio()
212 except AttributeError:
213 # Decimal
214 try:
215 return _decimal_to_ratio(x)
216 except AttributeError:
217 msg = "can't convert type '{}' to numerator/denominator"
218 raise TypeError(msg.format(type(x).__name__)) from None
219 except (OverflowError, ValueError):
220 # INF or NAN
221 if __debug__:
222 # Decimal signalling NANs cannot be converted to float :-(
223 if isinstance(x, Decimal):
224 assert not x.is_finite()
225 else:
226 assert not math.isfinite(x)
227 return (x, None)
228
229
230# FIXME This is faster than Fraction.from_decimal, but still too slow.
231def _decimal_to_ratio(d):
232 """Convert Decimal d to exact integer ratio (numerator, denominator).
233
234 >>> from decimal import Decimal
235 >>> _decimal_to_ratio(Decimal("2.6"))
236 (26, 10)
237
238 """
239 sign, digits, exp = d.as_tuple()
240 if exp in ('F', 'n', 'N'): # INF, NAN, sNAN
241 assert not d.is_finite()
242 raise ValueError
243 num = 0
244 for digit in digits:
245 num = num*10 + digit
Nick Coghlan4a7668a2014-02-08 23:55:14 +1000246 if exp < 0:
247 den = 10**-exp
248 else:
249 num *= 10**exp
250 den = 1
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700251 if sign:
252 num = -num
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700253 return (num, den)
254
255
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700256def _counts(data):
257 # Generate a table of sorted (value, frequency) pairs.
Nick Coghlanbfd68bf2014-02-08 19:44:16 +1000258 table = collections.Counter(iter(data)).most_common()
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700259 if not table:
260 return table
261 # Extract the values with the highest frequency.
262 maxfreq = table[0][1]
263 for i in range(1, len(table)):
264 if table[i][1] != maxfreq:
265 table = table[:i]
266 break
267 return table
268
269
270# === Measures of central tendency (averages) ===
271
def mean(data):
    """Return the sample arithmetic mean of data.

    The mean is the sum of the data divided by the number of data points:

    >>> mean([1, 2, 3, 4, 4])
    2.8

    >>> from fractions import Fraction as F
    >>> mean([F(3, 7), F(1, 21), F(5, 3), F(1, 3)])
    Fraction(13, 21)

    >>> from decimal import Decimal as D
    >>> mean([D("0.5"), D("0.75"), D("0.625"), D("0.375")])
    Decimal('0.5625')

    StatisticsError is raised when ``data`` is empty.
    """
    # An iterator has no len() and can only be walked once, so take a copy.
    values = list(data) if iter(data) is data else data
    count = len(values)
    if count < 1:
        raise StatisticsError('mean requires at least one data point')
    return _sum(values)/count
294
295
296# FIXME: investigate ways to calculate medians without sorting? Quickselect?
def median(data):
    """Return the median (middle value) of numeric data.

    With an odd number of data points the middle one is returned.  With an
    even number, the result is the average of the two middle values:

    >>> median([1, 3, 5])
    3
    >>> median([1, 3, 5, 7])
    4.0

    StatisticsError is raised when ``data`` is empty.
    """
    ordered = sorted(data)
    size = len(ordered)
    if not size:
        raise StatisticsError("no median for empty data")
    middle, is_odd = divmod(size, 2)
    if is_odd:
        return ordered[middle]
    return (ordered[middle - 1] + ordered[middle])/2
319
320
def median_low(data):
    """Return the low median of numeric data.

    With an odd number of data points the middle one is returned.  With an
    even number, the smaller of the two middle values is returned, so the
    result is always one of the data points:

    >>> median_low([1, 3, 5])
    3
    >>> median_low([1, 3, 5, 7])
    3

    StatisticsError is raised when ``data`` is empty.
    """
    ordered = sorted(data)
    size = len(ordered)
    if size == 0:
        raise StatisticsError("no median for empty data")
    # (size - 1)//2 is the middle index for odd sizes and the lower of the
    # two middle indexes for even sizes.
    return ordered[(size - 1)//2]
341
342
def median_high(data):
    """Return the high median of data.

    With an odd number of data points the middle one is returned.  With an
    even number, the larger of the two middle values is returned, so the
    result is always one of the data points:

    >>> median_high([1, 3, 5])
    3
    >>> median_high([1, 3, 5, 7])
    5

    StatisticsError is raised when ``data`` is empty.
    """
    ordered = sorted(data)
    if len(ordered) == 0:
        raise StatisticsError("no median for empty data")
    upper_middle = len(ordered)//2
    return ordered[upper_middle]
360
361
def median_grouped(data, interval=1):
    """Return the 50th percentile (median) of grouped continuous data.

    >>> median_grouped([1, 2, 2, 3, 4, 4, 4, 4, 4, 5])
    3.7
    >>> median_grouped([52, 52, 53, 54])
    52.5

    Use this when the data is continuous but has been grouped into
    classes.  In the first example the values 1, 2, 3, etc. stand for the
    midpoints of the classes 0.5-1.5, 1.5-2.5, 2.5-3.5, etc.  The middle
    value lies somewhere in class 3.5-4.5, and its position is estimated
    by interpolation.

    The optional ``interval`` argument is the class width (1 by default).
    A different width gives a different interpolated result:

    >>> median_grouped([1, 3, 3, 5, 7], interval=1)
    3.25
    >>> median_grouped([1, 3, 3, 5, 7], interval=2)
    3.5

    No check is made that the data points are at least ``interval`` apart.
    StatisticsError is raised when ``data`` is empty.
    """
    ordered = sorted(data)
    size = len(ordered)
    if size == 0:
        raise StatisticsError("no median for empty data")
    if size == 1:
        return ordered[0]
    # The data point at the midpoint; it marks the centre of the class
    # interval that contains the median.
    centre = ordered[size//2]
    for candidate in (centre, interval):
        if isinstance(candidate, (str, bytes)):
            raise TypeError('expected number but got %r' % candidate)
    try:
        lower_limit = centre - interval/2
    except TypeError:
        # Mixed types that will not combine directly are coerced to float.
        lower_limit = float(centre) - float(interval)/2
    # How many values lie below the median interval, and how many inside.
    below = ordered.index(centre)
    inside = ordered.count(centre)
    return lower_limit + interval*(size/2 - below)/inside
409
410
def mode(data):
    """Return the most common data point from discrete or nominal data.

    The data is treated as discrete and one value is returned, which is
    the treatment of the mode commonly taught in schools:

    >>> mode([1, 1, 2, 3, 3, 3, 3, 4])
    3

    Nominal (non-numeric) data is handled as well:

    >>> mode(["red", "blue", "blue", "red", "green", "red", "red"])
    'red'

    StatisticsError is raised unless exactly one value is the most
    common, which includes the case of empty data.
    """
    # (value, frequency) pairs for the values tied at the top frequency.
    candidates = _counts(data)
    if not candidates:
        raise StatisticsError('no mode for empty data')
    if len(candidates) != 1:
        raise StatisticsError(
            'no unique mode; found %d equally common values' % len(candidates)
        )
    return candidates[0][0]
438
439
440# === Measures of spread ===
441
442# See http://mathworld.wolfram.com/Variance.html
443# http://mathworld.wolfram.com/SampleVariance.html
444# http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
445#
446# Under no circumstances use the so-called "computational formula for
447# variance", as that is only suitable for hand calculations with a small
448# amount of low-precision data. It has terrible numeric properties.
449#
450# See a comparison of three computational methods here:
451# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
452
def _ss(data, c=None):
    """Return the sum of squared deviations of the sequence ``data``.

    The deviations are measured from ``c``.  When ``c`` is None, the mean
    of the data is computed in a first pass and used as the centre for a
    second pass.  Supplying some other ``c`` should be done with care, as
    the result can be garbage.
    """
    centre = mean(data) if c is None else c
    total = _sum((x - centre)**2 for x in data)
    # Mathematically the sum of the deviations is zero, but rounding error
    # can leave a residue; subtract its contribution as a correction.
    residue = _sum((x - centre) for x in data)
    total -= residue**2/len(data)
    assert not total < 0, 'negative sum of square deviations: %f' % total
    return total
469
470
def variance(data, xbar=None):
    """Return the sample variance of data.

    ``data`` should be an iterable of Real-valued numbers holding at least
    two values.  ``xbar``, when supplied, should be the mean of the data;
    when it is omitted or None the mean is calculated automatically.

    This is the function to use when the data is a sample drawn from a
    population.  For the variance of an entire population, see
    ``pvariance``.

    Examples:

    >>> data = [2.75, 1.75, 1.25, 0.25, 0.5, 1.25, 3.5]
    >>> variance(data)
    1.3720238095238095

    A mean that has already been calculated can be passed as the optional
    second argument ``xbar`` to save recalculating it:

    >>> m = mean(data)
    >>> variance(data, m)
    1.3720238095238095

    No check is made that ``xbar`` really is the mean of ``data``; an
    arbitrary value may produce invalid or impossible results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> variance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('31.01875')

    >>> from fractions import Fraction as F
    >>> variance([F(1, 6), F(1, 2), F(5, 3)])
    Fraction(67, 108)

    StatisticsError is raised when there are fewer than two data points.
    """
    # An iterator has no len() and can only be walked once, so take a copy.
    values = list(data) if iter(data) is data else data
    count = len(values)
    if count < 2:
        raise StatisticsError('variance requires at least two data points')
    squared_deviations = _ss(values, xbar)
    return squared_deviations/(count - 1)
516
517
def pvariance(data, mu=None):
    """Return the population variance of ``data``.

    ``data`` should be an iterable of Real-valued numbers holding at least
    one value.  ``mu``, when supplied, should be the mean of the data;
    when it is omitted or None the mean is calculated automatically.

    This is the function to use for the variance of an entire population.
    When estimating the variance from a sample, ``variance`` is usually
    the better choice.

    Examples:

    >>> data = [0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
    >>> pvariance(data)
    1.25

    A mean that has already been calculated can be passed as the optional
    second argument to save recalculating it:

    >>> mu = mean(data)
    >>> pvariance(data, mu)
    1.25

    No check is made that ``mu`` really is the mean of ``data``; an
    arbitrary value may produce invalid or impossible results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> pvariance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('24.815')

    >>> from fractions import Fraction as F
    >>> pvariance([F(1, 4), F(5, 4), F(1, 2)])
    Fraction(13, 72)

    StatisticsError is raised when ``data`` is empty.
    """
    # An iterator has no len() and can only be walked once, so take a copy.
    values = list(data) if iter(data) is data else data
    count = len(values)
    if count < 1:
        raise StatisticsError('pvariance requires at least one data point')
    squared_deviations = _ss(values, mu)
    return squared_deviations/count
564
565
def stdev(data, xbar=None):
    """Return the square root of the sample variance.

    The arguments are passed on to ``variance``; see that function for
    their meaning and for other details.

    >>> stdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    1.0810874155219827

    """
    sample_variance = variance(data, xbar)
    try:
        # Prefer the value's own sqrt() method when it has one.
        root = sample_variance.sqrt()
    except AttributeError:
        root = math.sqrt(sample_variance)
    return root
580
581
def pstdev(data, mu=None):
    """Return the square root of the population variance.

    The arguments are passed on to ``pvariance``; see that function for
    their meaning and for other details.

    >>> pstdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    0.986893273527251

    """
    population_variance = pvariance(data, mu)
    try:
        # Prefer the value's own sqrt() method when it has one.
        root = population_variance.sqrt()
    except AttributeError:
        root = math.sqrt(population_variance)
    return root
595 return math.sqrt(var)