| // Copyright 2016 Google Inc. All rights reserved. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
#include "weighted_reservoir_sampler.h"

#include <numeric>
#include <random>
#include <tuple>
#include <vector>

#include "gtest/gtest.h"
| |
| using testing::TestWithParam; |
| using testing::ValuesIn; |
| using testing::Combine; |
| using testing::Range; |
| |
| namespace protobuf_mutator { |
| |
| class WeightedReservoirSamplerTest |
| : public TestWithParam<std::tuple<int, std::vector<int>>> {}; |
| |
// Number of independent sampling rounds per (seed, weights) combination.
constexpr int kRuns = 1000 * 1000;
| |
// Weight tables exercised by the test; entry j is the weight of item j.
const std::vector<int> kTests[] = {
    // Single candidate.
    {1},
    // Uniform weights.
    {1, 1, 1},
    // One candidate with zero weight.
    {1, 1, 0},
    // Increasing weights.
    {1, 10, 100},
    // Heaviest candidate first.
    {100, 1, 10},
    // One light candidate next to two heavy ones.
    {1, 10000, 10000},
    // Mixed small and large weights.
    {1, 3, 7, 100, 105},
    // Larger table of arbitrary weights.
    {93519, 52999, 354,   37837, 55285, 31787, 89096, 55695,
     1587,  18233, 77557, 67632, 59348, 51250, 17417, 96856,
     78568, 44296, 70170, 41328, 9206,  90187, 54086, 35602,
     53167, 33791, 60118, 52962, 10327, 80513, 49526, 18326,
     83662, 49644, 70903, 4910,  36309, 19196, 42982, 53316,
     14773, 86607, 60835},
};
| |
| INSTANTIATE_TEST_CASE_P(AllTest, WeightedReservoirSamplerTest, |
| Combine(Range(1, 10, 3), ValuesIn(kTests))); |
| |
| TEST_P(WeightedReservoirSamplerTest, Test) { |
| std::vector<int> weights = std::get<1>(GetParam()); |
| std::vector<int> counts(weights.size(), 0); |
| |
| using RandomEngine = std::mt19937; |
| RandomEngine rand(std::get<0>(GetParam())); |
| for (int i = 0; i < kRuns; ++i) { |
| WeightedReservoirSampler<int, RandomEngine> sampler(&rand); |
| for (size_t j = 0; j < weights.size(); ++j) sampler.Try(weights[j], j); |
| ++counts[sampler.selected()]; |
| } |
| |
| int sum = std::accumulate(weights.begin(), weights.end(), 0); |
| for (size_t j = 0; j < weights.size(); ++j) { |
| float expected = weights[j]; |
| expected /= sum; |
| |
| float actual = counts[j]; |
| actual /= kRuns; |
| |
| EXPECT_NEAR(expected, actual, 0.01); |
| } |
| } |
| |
| } // namespace protobuf_mutator |