bpo-42944 Fix Random.sample when counts is not None (GH-24235)
diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py
index 41a26e3..35ae4e6 100644
--- a/Lib/test/test_random.py
+++ b/Lib/test/test_random.py
@@ -223,33 +223,6 @@ def test_sample_with_counts(self):
with self.assertRaises(ValueError):
sample(['red', 'green', 'blue'], counts=[1, 2, 3, 4], k=2) # too many counts
- def test_sample_counts_equivalence(self):
- # Test the documented strong equivalence to a sample with repeated elements.
- # We run this test on random.Random() which makes deterministic selections
- # for a given seed value.
- sample = random.sample
- seed = random.seed
-
- colors = ['red', 'green', 'blue', 'orange', 'black', 'amber']
- counts = [500, 200, 20, 10, 5, 1 ]
- k = 700
- seed(8675309)
- s1 = sample(colors, counts=counts, k=k)
- seed(8675309)
- expanded = [color for (color, count) in zip(colors, counts) for i in range(count)]
- self.assertEqual(len(expanded), sum(counts))
- s2 = sample(expanded, k=k)
- self.assertEqual(s1, s2)
-
- pop = 'abcdefghi'
- counts = [10, 9, 8, 7, 6, 5, 4, 3, 2]
- seed(8675309)
- s1 = ''.join(sample(pop, counts=counts, k=30))
- expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)])
- seed(8675309)
- s2 = ''.join(sample(expanded, k=30))
- self.assertEqual(s1, s2)
-
def test_choices(self):
choices = self.gen.choices
data = ['red', 'green', 'blue', 'yellow']
@@ -957,6 +930,33 @@ def test_randbytes_getrandbits(self):
self.assertEqual(self.gen.randbytes(n),
gen2.getrandbits(n * 8).to_bytes(n, 'little'))
+ def test_sample_counts_equivalence(self):
+ # Test the documented strong equivalence to a sample with repeated elements.
+ # Seed and sample through self.gen (the random.Random() instance under test)
+ # rather than the module-level functions, so selections repeat for a seed.
+ sample = self.gen.sample
+ seed = self.gen.seed
+
+ colors = ['red', 'green', 'blue', 'orange', 'black', 'amber']
+ counts = [500, 200, 20, 10, 5, 1 ]
+ k = 700
+ seed(8675309)
+ s1 = sample(colors, counts=counts, k=k)
+ seed(8675309)  # reset to the same state before sampling the expanded list
+ expanded = [color for (color, count) in zip(colors, counts) for i in range(count)]
+ self.assertEqual(len(expanded), sum(counts))
+ s2 = sample(expanded, k=k)
+ self.assertEqual(s1, s2)
+
+ pop = 'abcdefghi'
+ counts = [10, 9, 8, 7, 6, 5, 4, 3, 2]
+ seed(8675309)
+ s1 = ''.join(sample(pop, counts=counts, k=30))
+ expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)])
+ seed(8675309)  # reset to the same state before sampling the expanded string
+ s2 = ''.join(sample(expanded, k=30))
+ self.assertEqual(s1, s2)
+
def gamma(z, sqrt2pi=(2.0*pi)**0.5):
# Reflection to right half of complex plane