bpo-40541: Add optional *counts* parameter to random.sample() (GH-19970)

diff --git a/Lib/random.py b/Lib/random.py
index f2c4f39..75f70d5 100644
--- a/Lib/random.py
+++ b/Lib/random.py
@@ -331,7 +331,7 @@
                 j = _int(random() * (i+1))
                 x[i], x[j] = x[j], x[i]
 
-    def sample(self, population, k):
+    def sample(self, population, k, *, counts=None):
         """Chooses k unique random elements from a population sequence or set.
 
         Returns a new list containing elements from the population while
@@ -344,9 +344,21 @@
         population contains repeats, then each occurrence is a possible
         selection in the sample.
 
-        To choose a sample in a range of integers, use range as an argument.
-        This is especially fast and space efficient for sampling from a
-        large population:   sample(range(10000000), 60)
+        Repeated elements can be specified one at a time or with the optional
+        counts parameter.  For example:
+
+            sample(['red', 'blue'], counts=[4, 2], k=5)
+
+        is equivalent to:
+
+            sample(['red', 'red', 'red', 'red', 'blue', 'blue'], k=5)
+
+        To choose a sample from a range of integers, use range() for the
+        population argument.  This is especially fast and space efficient
+        for sampling from a large population:
+
+            sample(range(10000000), 60)
+
         """
 
         # Sampling without replacement entails tracking either potential
@@ -379,8 +391,20 @@
             population = tuple(population)
         if not isinstance(population, _Sequence):
             raise TypeError("Population must be a sequence.  For dicts or sets, use sorted(d).")
-        randbelow = self._randbelow
         n = len(population)
+        if counts is not None:  # sample by index over the virtual, repeat-expanded population
+            cum_counts = list(_accumulate(counts))
+            if len(cum_counts) != n:
+                raise ValueError('The number of counts does not match the population')
+            total = cum_counts.pop() if cum_counts else 0  # empty input: ValueError below, not IndexError
+            if not isinstance(total, int):
+                raise TypeError('Counts must be integers')
+            if total <= 0:
+                raise ValueError('Total of counts must be greater than zero')
+            selections = self.sample(range(total), k=k)  # self.: draw from this instance, not the module-level sample()
+            bisect = _bisect
+            return [population[bisect(cum_counts, s)] for s in selections]  # map each index to its element
+        randbelow = self._randbelow
         if not 0 <= k <= n:
             raise ValueError("Sample larger than population or is negative")
         result = [None] * k
diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py
index bb95ca0..a3710f4 100644
--- a/Lib/test/test_random.py
+++ b/Lib/test/test_random.py
@@ -9,7 +9,7 @@
 from math import log, exp, pi, fsum, sin, factorial
 from test import support
 from fractions import Fraction
-
+from collections import Counter
 
 class TestBasicOps:
     # Superclass with tests common to all generators.
@@ -161,6 +161,77 @@
             population = {10, 20, 30, 40, 50, 60, 70}
             self.gen.sample(population, k=5)
 
+    def test_sample_with_counts(self):
+        sample = self.gen.sample
+
+        # General case: k (700) is smaller than the total of the counts (736)
+        colors = ['red', 'green', 'blue', 'orange', 'black', 'brown', 'amber']
+        counts = [500,      200,     20,       10,       5,       0,       1 ]
+        k = 700
+        summary = Counter(sample(colors, counts=counts, k=k))
+        self.assertEqual(sum(summary.values()), k)          # exactly k items were drawn
+        for color, weight in zip(colors, counts):
+            self.assertLessEqual(summary[color], weight)    # never more than the copies available
+        self.assertNotIn('brown', summary)                  # 'brown' has a count of zero
+
+        # Case that exhausts the population: k equals the total of the counts
+        k = sum(counts)
+        summary = Counter(sample(colors, counts=counts, k=k))
+        self.assertEqual(sum(summary.values()), k)
+        for color, weight in zip(colors, counts):
+            self.assertLessEqual(summary[color], weight)
+        self.assertNotIn('brown', summary)
+
+        # Case with population size of 1: every draw must be the single element
+        summary = Counter(sample(['x'], counts=[10], k=8))
+        self.assertEqual(summary, Counter(x=8))
+
+        # Case with all counts equal; k is the full total, so each color appears 10 times.
+        nc = len(colors)
+        summary = Counter(sample(colors, counts=[10]*nc, k=10*nc))
+        self.assertEqual(summary, Counter(10*colors))
+
+        # Test error handling: each call below must raise the listed exception
+        with self.assertRaises(TypeError):
+            sample(['red', 'green', 'blue'], counts=10, k=10)               # counts not iterable
+        with self.assertRaises(ValueError):
+            sample(['red', 'green', 'blue'], counts=[-3, -7, -8], k=2)      # counts are negative
+        with self.assertRaises(ValueError):
+            sample(['red', 'green', 'blue'], counts=[0, 0, 0], k=2)         # counts are zero
+        with self.assertRaises(ValueError):
+            sample(['red', 'green'], counts=[10, 10], k=21)                 # population too small
+        with self.assertRaises(ValueError):
+            sample(['red', 'green', 'blue'], counts=[1, 2], k=2)            # too few counts
+        with self.assertRaises(ValueError):
+            sample(['red', 'green', 'blue'], counts=[1, 2, 3, 4], k=2)      # too many counts
+
+    def test_sample_counts_equivalence(self):
+        # Test the documented strong equivalence to a sample with repeated elements.
+        # This uses the module-level random.sample() and random.seed() functions and
+        # reseeds with the same value before each draw so both start from the same state.
+        sample = random.sample  # NOTE(review): module-level only; Random instances are not exercised here
+        seed = random.seed
+
+        colors =  ['red', 'green', 'blue', 'orange', 'black', 'amber']
+        counts = [500,      200,     20,       10,       5,       1 ]
+        k = 700
+        seed(8675309)
+        s1 = sample(colors, counts=counts, k=k)             # draw using the counts parameter
+        seed(8675309)
+        expanded = [color for (color, count) in zip(colors, counts) for i in range(count)]
+        self.assertEqual(len(expanded), sum(counts))        # sanity check on the expansion itself
+        s2 = sample(expanded, k=k)                          # draw from the explicitly repeated list
+        self.assertEqual(s1, s2)                            # same elements in the same order
+
+        pop = 'abcdefghi'                                   # same check with a string population
+        counts = [10, 9, 8, 7, 6, 5, 4, 3, 2]
+        seed(8675309)
+        s1 = ''.join(sample(pop, counts=counts, k=30))
+        expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)])
+        seed(8675309)
+        s2 = ''.join(sample(expanded, k=30))
+        self.assertEqual(s1, s2)
+
     def test_choices(self):
         choices = self.gen.choices
         data = ['red', 'green', 'blue', 'yellow']