bpo-38490: statistics: Add covariance, Pearson's correlation, and simple linear regression (#16813)
Co-authored-by: Tymoteusz Wołodźko <twolodzko+gitkraken@gmail.com>
diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py
index 4b8686b..70d269d 100644
--- a/Lib/test/test_statistics.py
+++ b/Lib/test/test_statistics.py
@@ -2407,6 +2407,84 @@ def test_error_cases(self):
quantiles([10, None, 30], n=4) # data is non-numeric
+class TestBivariateStatistics(unittest.TestCase):
+    """Input validation shared by covariance(), correlation() and
+    linear_regression()."""
+
+    def _assert_all_raise(self, x, y):
+        # Every bivariate function must reject the (x, y) pair.  subTest
+        # reports exactly which function and which inputs failed and keeps
+        # checking the remaining combinations instead of stopping early.
+        for func in (statistics.covariance, statistics.correlation,
+                     statistics.linear_regression):
+            with self.subTest(func=func.__name__, x=x, y=y):
+                with self.assertRaises(statistics.StatisticsError):
+                    func(x, y)
+
+    def test_unequal_size_error(self):
+        # Both inputs must contain the same number of data points.
+        for x, y in [
+            ([1, 2, 3], [1, 2]),
+            ([1, 2], [1, 2, 3]),
+        ]:
+            self._assert_all_raise(x, y)
+
+    def test_small_sample_error(self):
+        # At least two data points are required.  Some of these pairs are
+        # also unequal in length; either violation must raise.
+        for x, y in [
+            ([], []),
+            ([], [1, 2]),
+            ([1, 2], []),
+            ([1], [1]),
+            ([1], [1, 2]),
+            ([1, 2], [1]),
+        ]:
+            self._assert_all_raise(x, y)
+
+
+class TestCorrelationAndCovariance(unittest.TestCase):
+    """Results of correlation() (Pearson's r) and covariance()."""
+
+    def test_results(self):
+        # x = [1, 2, 3] has sample variance 1, and every y below either
+        # also has sample variance 1 or is uncorrelated with x.  Hence the
+        # sample covariance equals Pearson's r for each case, and a single
+        # expected value checks both functions.
+        for x, y, result in [
+            ([1, 2, 3], [1, 2, 3], 1),
+            ([1, 2, 3], [-1, -2, -3], -1),
+            ([1, 2, 3], [3, 2, 1], -1),
+            ([1, 2, 3], [1, 2, 1], 0),
+            ([1, 2, 3], [1, 3, 2], 0.5),
+        ]:
+            with self.subTest(x=x, y=y):
+                self.assertAlmostEqual(statistics.correlation(x, y), result)
+                self.assertAlmostEqual(statistics.covariance(x, y), result)
+
+    def test_different_scales(self):
+        # Correlation is unchanged by rescaling ([10, 30, 20] is [1, 3, 2]
+        # times 10), while covariance scales with the data.
+        x = [1, 2, 3]
+        y = [10, 30, 20]
+        self.assertAlmostEqual(statistics.correlation(x, y), 0.5)
+        self.assertAlmostEqual(statistics.covariance(x, y), 5)
+
+        y = [.1, .2, .3]
+        self.assertAlmostEqual(statistics.correlation(x, y), 1)
+        self.assertAlmostEqual(statistics.covariance(x, y), 0.1)
+
+    def test_constant_input_error(self):
+        # Pearson's r divides by the standard deviation of each input, so
+        # it is undefined when either input is constant.
+        for x, y in [
+            ([1, 1, 1], [1, 2, 3]),
+            ([1, 2, 3], [1, 1, 1]),
+        ]:
+            with self.subTest(x=x, y=y):
+                with self.assertRaises(statistics.StatisticsError):
+                    statistics.correlation(x, y)
+
+
+class TestLinearRegression(unittest.TestCase):
+    """Results and input validation of linear_regression()."""
+
+    def test_constant_input_error(self):
+        # x is constant, so the least-squares slope (which divides by the
+        # spread of x) is undefined.
+        x = [1, 1, 1]
+        y = [1, 2, 3]
+        with self.assertRaises(statistics.StatisticsError):
+            statistics.linear_regression(x, y)
+
+    def test_results(self):
+        # Read the result's fields by name instead of unpacking by
+        # position, so the test does not depend on the field order of the
+        # returned named tuple.
+        for x, y, true_intercept, true_slope in [
+            ([1, 2, 3], [0, 0, 0], 0, 0),
+            ([1, 2, 3], [1, 2, 3], 0, 1),
+            ([1, 2, 3], [100, 100, 100], 100, 0),
+            ([1, 2, 3], [12, 14, 16], 10, 2),
+            ([1, 2, 3], [-1, -2, -3], 0, -1),
+            ([1, 2, 3], [21, 22, 23], 20, 1),
+            ([1, 2, 3], [5.1, 5.2, 5.3], 5, 0.1),
+        ]:
+            with self.subTest(x=x, y=y):
+                result = statistics.linear_regression(x, y)
+                self.assertAlmostEqual(result.intercept, true_intercept)
+                self.assertAlmostEqual(result.slope, true_slope)
+
+
class TestNormalDist:
# General note on precision: The pdf(), cdf(), and overlap() methods