blob: e1dfbd49317bddc373f5d0eb0fb75ffeb4eb4cea [file] [log] [blame]
Larry Hastingsf5e987b2013-10-19 11:50:09 -07001## Module statistics.py
2##
3## Copyright (c) 2013 Steven D'Aprano <steve+python@pearwood.info>.
4##
5## Licensed under the Apache License, Version 2.0 (the "License");
6## you may not use this file except in compliance with the License.
7## You may obtain a copy of the License at
8##
9## http://www.apache.org/licenses/LICENSE-2.0
10##
11## Unless required by applicable law or agreed to in writing, software
12## distributed under the License is distributed on an "AS IS" BASIS,
13## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14## See the License for the specific language governing permissions and
15## limitations under the License.
16
17
18"""
19Basic statistics module.
20
21This module provides functions for calculating statistics of data, including
22averages, variance, and standard deviation.
23
24Calculating averages
25--------------------
26
27================== =============================================
28Function Description
29================== =============================================
30mean Arithmetic mean (average) of data.
31median Median (middle value) of data.
32median_low Low median of data.
33median_high High median of data.
34median_grouped Median, or 50th percentile, of grouped data.
35mode Mode (most common value) of data.
36================== =============================================
37
38Calculate the arithmetic mean ("the average") of data:
39
40>>> mean([-1.0, 2.5, 3.25, 5.75])
412.625
42
43
44Calculate the standard median of discrete data:
45
46>>> median([2, 3, 4, 5])
473.5
48
49
50Calculate the median, or 50th percentile, of data grouped into class intervals
51centred on the data values provided. E.g. if your data points are rounded to
52the nearest whole number:
53
54>>> median_grouped([2, 2, 3, 3, 3, 4]) #doctest: +ELLIPSIS
552.8333333333...
56
57This should be interpreted in this way: you have two data points in the class
58interval 1.5-2.5, three data points in the class interval 2.5-3.5, and one in
59the class interval 3.5-4.5. The median of these data points is 2.8333...
60
61
62Calculating variability or spread
63---------------------------------
64
65================== =============================================
66Function Description
67================== =============================================
68pvariance Population variance of data.
69variance Sample variance of data.
70pstdev Population standard deviation of data.
71stdev Sample standard deviation of data.
72================== =============================================
73
74Calculate the standard deviation of sample data:
75
76>>> stdev([2.5, 3.25, 5.5, 11.25, 11.75]) #doctest: +ELLIPSIS
774.38961843444...
78
79If you have previously calculated the mean, you can pass it as the optional
80second argument to the four "spread" functions to avoid recalculating it:
81
82>>> data = [1, 2, 2, 4, 4, 4, 5, 6]
83>>> mu = mean(data)
84>>> pvariance(data, mu)
852.5
86
87
88Exceptions
89----------
90
91A single exception is defined: StatisticsError is a subclass of ValueError.
92
93"""
94
# Public API of this module: the exception class, the four "spread"
# functions, the median family, and mean/mode.
__all__ = [ 'StatisticsError',
            'pstdev', 'pvariance', 'stdev', 'variance',
            'median', 'median_low', 'median_high', 'median_grouped',
            'mean', 'mode',
          ]
100
101
102import collections
103import math
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700104
105from fractions import Fraction
106from decimal import Decimal
107
108
109# === Exceptions ===
110
class StatisticsError(ValueError):
    """Error raised by this module's functions; a ValueError subclass."""
113
114
115# === Private utilities ===
116
def _sum(data, start=0):
    """_sum(data [, start]) -> value

    Add up the numbers in ``data`` using exact rational arithmetic and
    return the total, including ``start`` (which defaults to 0). For empty
    ``data`` the result is simply ``start``.


    Examples
    --------

    >>> _sum([3, 2.25, 4.5, -0.5, 1.0], 0.75)
    11.0

    Because the accumulation is exact, cancellation does not lose digits:

    >>> _sum([1e50, 1, -1e50] * 1000)  # Built-in sum returns zero.
    1000.0

    Fractions and Decimals work too:

    >>> from fractions import Fraction as F
    >>> _sum([F(2, 3), F(7, 5), F(1, 4), F(5, 6)])
    Fraction(63, 20)

    >>> from decimal import Decimal as D
    >>> data = [D("0.1375"), D("0.2108"), D("0.3061"), D("0.0419")]
    >>> _sum(data)
    Decimal('0.6963')

    Mixing types raises TypeError, except that ints may be combined with
    one other type.
    """
    # The set starts as {int, type(start)} and may grow to hold at most one
    # non-int type; _check_type raises as soon as a second one shows up.
    # E.g. [int, int, float, int] is accepted, [int, float, Fraction] is not.
    seen_types = {int, type(start)}
    num, den = _exact_ratio(start)
    # Maps each denominator to the running sum of numerators seen for it.
    numerators = {den: num}
    for value in data:
        _check_type(type(value), seen_types)
        num, den = _exact_ratio(value)
        numerators[den] = numerators.get(den, 0) + num
    # Decide the result type: int when nothing else was seen, otherwise the
    # single non-int type.
    assert len(seen_types) in (1, 2)
    if len(seen_types) == 1:
        assert seen_types == {int}
        result_type = int
    else:
        result_type = (seen_types - {int}).pop()
    # A None denominator is how _exact_ratio flags INF/NAN inputs.
    if None in numerators:
        non_finite = numerators[None]
        assert issubclass(result_type, (float, Decimal))
        assert not math.isfinite(non_finite)
        return result_type(non_finite)
    exact = sum(
        (Fraction(n, d) for d, n in sorted(numerators.items())),
        Fraction(),
    )
    if issubclass(result_type, int):
        assert exact.denominator == 1
        return result_type(exact.numerator)
    if issubclass(result_type, Decimal):
        # Decimal cannot be built from a Fraction directly; divide instead.
        return result_type(exact.numerator)/exact.denominator
    return result_type(exact)
185
186
Nick Coghlan73afe2a2014-02-08 19:58:04 +1000187def _check_type(T, allowed):
188 if T not in allowed:
189 if len(allowed) == 1:
190 allowed.add(T)
191 else:
192 types = ', '.join([t.__name__ for t in allowed] + [T.__name__])
193 raise TypeError("unsupported mixed types: %s" % types)
194
195
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700196def _exact_ratio(x):
197 """Convert Real number x exactly to (numerator, denominator) pair.
198
199 >>> _exact_ratio(0.25)
200 (1, 4)
201
202 x is expected to be an int, Fraction, Decimal or float.
203 """
204 try:
205 try:
206 # int, Fraction
207 return (x.numerator, x.denominator)
208 except AttributeError:
209 # float
210 try:
211 return x.as_integer_ratio()
212 except AttributeError:
213 # Decimal
214 try:
215 return _decimal_to_ratio(x)
216 except AttributeError:
217 msg = "can't convert type '{}' to numerator/denominator"
218 raise TypeError(msg.format(type(x).__name__)) from None
219 except (OverflowError, ValueError):
220 # INF or NAN
221 if __debug__:
222 # Decimal signalling NANs cannot be converted to float :-(
223 if isinstance(x, Decimal):
224 assert not x.is_finite()
225 else:
226 assert not math.isfinite(x)
227 return (x, None)
228
229
230# FIXME This is faster than Fraction.from_decimal, but still too slow.
231def _decimal_to_ratio(d):
232 """Convert Decimal d to exact integer ratio (numerator, denominator).
233
234 >>> from decimal import Decimal
235 >>> _decimal_to_ratio(Decimal("2.6"))
236 (26, 10)
237
238 """
239 sign, digits, exp = d.as_tuple()
240 if exp in ('F', 'n', 'N'): # INF, NAN, sNAN
241 assert not d.is_finite()
242 raise ValueError
243 num = 0
244 for digit in digits:
245 num = num*10 + digit
246 if sign:
247 num = -num
248 den = 10**-exp
249 return (num, den)
250
251
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700252def _counts(data):
253 # Generate a table of sorted (value, frequency) pairs.
Nick Coghlanbfd68bf2014-02-08 19:44:16 +1000254 table = collections.Counter(iter(data)).most_common()
Larry Hastingsf5e987b2013-10-19 11:50:09 -0700255 if not table:
256 return table
257 # Extract the values with the highest frequency.
258 maxfreq = table[0][1]
259 for i in range(1, len(table)):
260 if table[i][1] != maxfreq:
261 table = table[:i]
262 break
263 return table
264
265
266# === Measures of central tendency (averages) ===
267
def mean(data):
    """Return the sample arithmetic mean of data.

    >>> mean([1, 2, 3, 4, 4])
    2.8

    >>> from fractions import Fraction as F
    >>> mean([F(3, 7), F(1, 21), F(5, 3), F(1, 3)])
    Fraction(13, 21)

    >>> from decimal import Decimal as D
    >>> mean([D("0.5"), D("0.75"), D("0.625"), D("0.375")])
    Decimal('0.5625')

    Raises StatisticsError when ``data`` contains no values.
    """
    # An iterator can be consumed only once and has no len(), so take a
    # snapshot of it first.
    values = list(data) if iter(data) is data else data
    count = len(values)
    if count < 1:
        raise StatisticsError('mean requires at least one data point')
    return _sum(values)/count
290
291
292# FIXME: investigate ways to calculate medians without sorting? Quickselect?
def median(data):
    """Return the median (middle value) of numeric data.

    With an odd number of data points the middle one is returned. With an
    even number, the result is the average of the two middle values:

    >>> median([1, 3, 5])
    3
    >>> median([1, 3, 5, 7])
    4.0

    Raises StatisticsError when ``data`` is empty.
    """
    ordered = sorted(data)
    count = len(ordered)
    if not count:
        raise StatisticsError("no median for empty data")
    half, is_odd = divmod(count, 2)
    if is_odd:
        return ordered[half]
    return (ordered[half - 1] + ordered[half])/2
315
316
def median_low(data):
    """Return the low median of numeric data.

    With an odd number of data points this is the middle value; with an
    even number it is the smaller of the two middle values.

    >>> median_low([1, 3, 5])
    3
    >>> median_low([1, 3, 5, 7])
    3

    Raises StatisticsError when ``data`` is empty.
    """
    ordered = sorted(data)
    count = len(ordered)
    if not count:
        raise StatisticsError("no median for empty data")
    # (count - 1)//2 is the middle index for odd counts and the lower of
    # the two middle indices for even counts.
    return ordered[(count - 1)//2]
337
338
def median_high(data):
    """Return the high median of data.

    With an odd number of data points this is the middle value; with an
    even number it is the larger of the two middle values.

    >>> median_high([1, 3, 5])
    3
    >>> median_high([1, 3, 5, 7])
    5

    Raises StatisticsError when ``data`` is empty.
    """
    ordered = sorted(data)
    if not ordered:
        raise StatisticsError("no median for empty data")
    return ordered[len(ordered)//2]
356
357
def median_grouped(data, interval=1):
    """Return the 50th percentile (median) of grouped continuous data.

    >>> median_grouped([1, 2, 2, 3, 4, 4, 4, 4, 4, 5])
    3.7
    >>> median_grouped([52, 52, 53, 54])
    52.5

    This calculates the median as the 50th percentile, and should be
    used when your data is continuous and grouped. In the above example,
    the values 1, 2, 3, etc. actually represent the midpoint of classes
    0.5-1.5, 1.5-2.5, 2.5-3.5, etc. The middle value falls somewhere in
    class 3.5-4.5, and interpolation is used to estimate it.

    Optional argument ``interval`` represents the class interval, and
    defaults to 1. Changing the class interval naturally will change the
    interpolated 50th percentile value:

    >>> median_grouped([1, 3, 3, 5, 7], interval=1)
    3.25
    >>> median_grouped([1, 3, 3, 5, 7], interval=2)
    3.5

    This function does not check whether the data points are at least
    ``interval`` apart.

    Raises StatisticsError if ``data`` is empty, and TypeError if the
    middle value or ``interval`` is a str or bytes object. A single data
    point is returned unchanged.
    """
    data = sorted(data)
    n = len(data)
    if n == 0:
        raise StatisticsError("no median for empty data")
    elif n == 1:
        return data[0]
    # Find the value at the midpoint. Remember this corresponds to the
    # centre of the class interval.
    x = data[n//2]
    # Reject text explicitly; otherwise the float() fallback below could
    # silently parse a numeric-looking string.
    for obj in (x, interval):
        if isinstance(obj, (str, bytes)):
            raise TypeError('expected number but got %r' % obj)
    try:
        L = x - interval/2  # The lower limit of the median interval.
    except TypeError:
        # Mixed type. For now we just coerce to float.
        L = float(x) - float(interval)/2
    cf = data.index(x)  # Number of values below the median interval.
    # FIXME The following line could be more efficient for big lists.
    f = data.count(x)  # Number of data points in the median interval.
    # Linear interpolation within the median interval.
    return L + interval*(n/2 - cf)/f
405
406
def mode(data):
    """Return the most common data point from discrete or nominal data.

    The data is treated as discrete and a single value is returned, which
    is the usual school-book definition of the mode:

    >>> mode([1, 1, 2, 3, 3, 3, 3, 4])
    3

    Nominal (non-numeric) data is handled as well:

    >>> mode(["red", "blue", "blue", "red", "green", "red", "red"])
    'red'

    StatisticsError is raised when ``data`` is empty or when several
    values tie for most common.
    """
    # (value, frequency) pairs of the values tied for the top frequency.
    winners = _counts(data)
    if not winners:
        raise StatisticsError('no mode for empty data')
    if len(winners) != 1:
        raise StatisticsError(
                'no unique mode; found %d equally common values' % len(winners)
                )
    return winners[0][0]
434
435
436# === Measures of spread ===
437
438# See http://mathworld.wolfram.com/Variance.html
439# http://mathworld.wolfram.com/SampleVariance.html
440# http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
441#
442# Under no circumstances use the so-called "computational formula for
443# variance", as that is only suitable for hand calculations with a small
444# amount of low-precision data. It has terrible numeric properties.
445#
446# See a comparison of three computational methods here:
447# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
448
def _ss(data, c=None):
    """Return the sum of squared deviations of the sequence ``data``.

    When ``c`` is None, the deviations are measured from the mean of
    ``data``, computed in a first pass. Otherwise they are measured from
    ``c`` exactly as supplied; be careful with that form, since a ``c``
    that is not the mean can produce garbage.
    """
    centre = mean(data) if c is None else c
    squares = _sum((x - centre)**2 for x in data)
    # The sum of the plain deviations is mathematically zero, but rounding
    # error may leave a residue; subtract its contribution as a correction.
    residue = _sum((x - centre) for x in data)
    result = squares - residue**2/len(data)
    assert not result < 0, 'negative sum of square deviations: %f' % result
    return result
465
466
def variance(data, xbar=None):
    """Return the sample variance of data.

    ``data`` is an iterable of Real-valued numbers holding at least two
    values. ``xbar``, when supplied, should be the mean of the data; if
    it is omitted or None, the mean is computed for you.

    This is the right function when the data is a sample drawn from a
    larger population. For the variance of a complete population use
    ``pvariance`` instead.

    Examples:

    >>> data = [2.75, 1.75, 1.25, 0.25, 0.5, 1.25, 3.5]
    >>> variance(data)
    1.3720238095238095

    A mean that has already been computed can be passed as ``xbar`` so it
    is not calculated again:

    >>> m = mean(data)
    >>> variance(data, m)
    1.3720238095238095

    No check is made that ``xbar`` really is the mean of ``data``;
    arbitrary values may give invalid or impossible results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> variance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('31.01875')

    >>> from fractions import Fraction as F
    >>> variance([F(1, 6), F(1, 2), F(5, 3)])
    Fraction(67, 108)

    Raises StatisticsError when fewer than two values are given.
    """
    # Snapshot iterators so the data can be measured and traversed again.
    values = list(data) if iter(data) is data else data
    count = len(values)
    if count < 2:
        raise StatisticsError('variance requires at least two data points')
    return _ss(values, xbar)/(count - 1)
512
513
def pvariance(data, mu=None):
    """Return the population variance of ``data``.

    ``data`` is an iterable of Real-valued numbers holding at least one
    value. ``mu``, when supplied, should be the mean of the data; if it
    is omitted or None, the mean is computed for you.

    This is the right function when the data is the entire population.
    When estimating the variance from a sample, ``variance`` is usually
    the better choice.

    Examples:

    >>> data = [0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
    >>> pvariance(data)
    1.25

    A mean that has already been computed can be passed as the second
    argument so it is not calculated again:

    >>> mu = mean(data)
    >>> pvariance(data, mu)
    1.25

    No check is made that ``mu`` really is the mean of ``data``;
    arbitrary values may give invalid or impossible results.

    Decimals and Fractions are supported:

    >>> from decimal import Decimal as D
    >>> pvariance([D("27.5"), D("30.25"), D("30.25"), D("34.5"), D("41.75")])
    Decimal('24.815')

    >>> from fractions import Fraction as F
    >>> pvariance([F(1, 4), F(5, 4), F(1, 2)])
    Fraction(13, 72)

    Raises StatisticsError when ``data`` is empty.
    """
    # Snapshot iterators so the data can be measured and traversed again.
    values = list(data) if iter(data) is data else data
    count = len(values)
    if count < 1:
        raise StatisticsError('pvariance requires at least one data point')
    return _ss(values, mu)/count
560
561
def stdev(data, xbar=None):
    """Return the square root of the sample variance.

    The arguments are passed straight to ``variance``; see that function
    for their meaning and for other details.

    >>> stdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    1.0810874155219827

    """
    sample_var = variance(data, xbar)
    # Use the value's own sqrt() method when it has one, falling back to
    # math.sqrt for types without it.
    try:
        root = sample_var.sqrt()
    except AttributeError:
        root = math.sqrt(sample_var)
    return root
576
577
def pstdev(data, mu=None):
    """Return the square root of the population variance.

    The arguments are passed straight to ``pvariance``; see that function
    for their meaning and for other details.

    >>> pstdev([1.5, 2.5, 2.5, 2.75, 3.25, 4.75])
    0.986893273527251

    """
    population_var = pvariance(data, mu)
    # Use the value's own sqrt() method when it has one, falling back to
    # math.sqrt for types without it.
    try:
        root = population_var.sqrt()
    except AttributeError:
        root = math.sqrt(population_var)
    return root