blob: 033ea3e77fdcc11aecc19dfec07a296fc57258af [file] [log] [blame]
Alessio Bazzicaf2255012018-04-27 16:44:11 +02001/*
2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
Alex Loiko0520b0e2018-05-08 13:11:12 +020012#include "common_audio/real_fourier.h"
Alessio Bazzicaf2255012018-04-27 16:44:11 +020013
14#include <array>
15#include <tuple>
16
17#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
18// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
19// #include "test/fpe_observer.h"
20#include "test/gtest.h"
21
22namespace webrtc {
23namespace rnn_vad {
24namespace test {
25namespace {
26
Alessio Bazzicac25fa892019-01-14 13:54:57 +010027constexpr std::array<int, 2> kTestPitchPeriods = {
28 3 * kMinPitch48kHz / 2,
29 (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2,
Alessio Bazzicaf2255012018-04-27 16:44:11 +020030};
31constexpr std::array<float, 2> kTestPitchGains = {0.35f, 0.75f};
32
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +020033constexpr size_t kNumPitchBufSquareEnergies = 385;
34constexpr size_t kNumPitchBufAutoCorrCoeffs = 147;
35constexpr size_t kTestDataSize =
36 kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs;
37
38class TestData {
39 public:
40 TestData() {
41 auto test_data_reader = CreatePitchSearchTestDataReader();
42 test_data_reader->ReadChunk(test_data_);
43 }
44 rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView() {
45 return {test_data_.data(), kBufSize24kHz};
46 }
47 rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
48 GetPitchBufSquareEnergiesView() {
49 return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies};
50 }
51 rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
52 GetPitchBufAutoCorrCoeffsView() {
53 return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies,
54 kNumPitchBufAutoCorrCoeffs};
55 }
56
57 private:
58 std::array<float, kTestDataSize> test_data_;
59};
60
Alessio Bazzicaf2255012018-04-27 16:44:11 +020061} // namespace
62
63class ComputePitchGainThresholdTest
64 : public testing::Test,
65 public ::testing::WithParamInterface<
66 std::tuple<size_t, size_t, size_t, float, size_t, float, float>> {};
67
68TEST_P(ComputePitchGainThresholdTest, BitExactness) {
69 const auto params = GetParam();
70 const size_t candidate_pitch_period = std::get<0>(params);
71 const size_t pitch_period_ratio = std::get<1>(params);
72 const size_t initial_pitch_period = std::get<2>(params);
73 const float initial_pitch_gain = std::get<3>(params);
74 const size_t prev_pitch_period = std::get<4>(params);
75 const size_t prev_pitch_gain = std::get<5>(params);
76 const float threshold = std::get<6>(params);
77
78 {
79 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
80 // FloatingPointExceptionObserver fpe_observer;
81
82 EXPECT_NEAR(
83 threshold,
84 ComputePitchGainThreshold(candidate_pitch_period, pitch_period_ratio,
85 initial_pitch_period, initial_pitch_gain,
86 prev_pitch_period, prev_pitch_gain),
87 3e-6f);
88 }
89}
90
91INSTANTIATE_TEST_CASE_P(
92 RnnVadTest,
93 ComputePitchGainThresholdTest,
94 ::testing::Values(
95 std::make_tuple(31, 7, 219, 0.45649201f, 199, 0.604747f, 0.40000001f),
96 std::make_tuple(113,
97 2,
98 226,
99 0.20967799f,
100 219,
101 0.40392199f,
102 0.30000001f),
103 std::make_tuple(63, 2, 126, 0.210788f, 364, 0.098519f, 0.40000001f),
104 std::make_tuple(30, 5, 152, 0.82356697f, 149, 0.55535901f, 0.700032f),
105 std::make_tuple(76, 2, 151, 0.79522997f, 151, 0.82356697f, 0.675946f),
106 std::make_tuple(31, 5, 153, 0.85069299f, 150, 0.79073799f, 0.72308898f),
107 std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f)));
108
109TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) {
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200110 TestData test_data;
111 std::array<float, kNumPitchBufSquareEnergies> computed_output;
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200112 {
113 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
114 // FloatingPointExceptionObserver fpe_observer;
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200115 ComputeSlidingFrameSquareEnergies(test_data.GetPitchBufView(),
116 computed_output);
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200117 }
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200118 auto square_energies_view = test_data.GetPitchBufSquareEnergiesView();
119 ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()},
120 computed_output, 3e-2f);
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200121}
122
123TEST(RnnVadTest, ComputePitchAutoCorrelationBitExactness) {
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200124 TestData test_data;
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200125 std::array<float, kBufSize12kHz> pitch_buf_decimated;
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200126 Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
127 std::array<float, kNumPitchBufAutoCorrCoeffs> computed_output;
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200128 {
129 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
130 // FloatingPointExceptionObserver fpe_observer;
Alex Loiko0520b0e2018-05-08 13:11:12 +0200131 std::unique_ptr<RealFourier> fft =
132 RealFourier::Create(kAutoCorrelationFftOrder);
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200133 ComputePitchAutoCorrelation(pitch_buf_decimated, kMaxPitch12kHz,
134 computed_output, fft.get());
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200135 }
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200136 auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
137 ExpectNearAbsolute({auto_corr_view.data(), auto_corr_view.size()},
138 computed_output, 3e-3f);
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200139}
140
Alex Loiko0520b0e2018-05-08 13:11:12 +0200141// Check that the auto correlation function computes the right thing for a
142// simple use case.
143TEST(RnnVadTest, ComputePitchAutoCorrelationConstantBuffer) {
144 // Create constant signal with no pitch.
145 std::array<float, kBufSize12kHz> pitch_buf_decimated;
146 std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f);
147
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200148 std::array<float, kNumPitchBufAutoCorrCoeffs> computed_output;
Alex Loiko0520b0e2018-05-08 13:11:12 +0200149 {
150 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
151 // FloatingPointExceptionObserver fpe_observer;
Alex Loiko0520b0e2018-05-08 13:11:12 +0200152 std::unique_ptr<RealFourier> fft =
153 RealFourier::Create(kAutoCorrelationFftOrder);
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200154 ComputePitchAutoCorrelation(pitch_buf_decimated, kMaxPitch12kHz,
155 computed_output, fft.get());
Alex Loiko0520b0e2018-05-08 13:11:12 +0200156 }
157
158 // The expected output is constantly the length of the fixed 'x'
159 // array in ComputePitchAutoCorrelation.
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200160 std::array<float, kNumPitchBufAutoCorrCoeffs> expected_output;
Alex Loiko0520b0e2018-05-08 13:11:12 +0200161 std::fill(expected_output.begin(), expected_output.end(),
162 kBufSize12kHz - kMaxPitch12kHz);
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200163 ExpectNearAbsolute(expected_output, computed_output, 4e-5f);
Alex Loiko0520b0e2018-05-08 13:11:12 +0200164}
165
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200166TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200167 TestData test_data;
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200168 std::array<float, kBufSize12kHz> pitch_buf_decimated;
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200169 Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200170 std::array<size_t, 2> pitch_candidates_inv_lags;
171 {
172 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
173 // FloatingPointExceptionObserver fpe_observer;
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200174 auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
175 pitch_candidates_inv_lags =
176 FindBestPitchPeriods({auto_corr_view.data(), auto_corr_view.size()},
177 pitch_buf_decimated, kMaxPitch12kHz);
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200178 }
179 const std::array<size_t, 2> expected_output = {140, 142};
180 EXPECT_EQ(expected_output, pitch_candidates_inv_lags);
181}
182
183TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200184 TestData test_data;
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200185 std::array<float, kBufSize12kHz> pitch_buf_decimated;
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200186 Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200187 size_t pitch_inv_lag;
188 {
189 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
190 // FloatingPointExceptionObserver fpe_observer;
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200191 const std::array<size_t, 2> pitch_candidates_inv_lags = {280, 284};
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200192 pitch_inv_lag = RefinePitchPeriod48kHz(test_data.GetPitchBufView(),
193 pitch_candidates_inv_lags);
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200194 }
195 EXPECT_EQ(560u, pitch_inv_lag);
196}
197
198class CheckLowerPitchPeriodsAndComputePitchGainTest
199 : public testing::Test,
200 public ::testing::WithParamInterface<
Alessio Bazzicac25fa892019-01-14 13:54:57 +0100201 std::tuple<int, int, float, int, float>> {};
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200202
203TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) {
204 const auto params = GetParam();
Alessio Bazzicac25fa892019-01-14 13:54:57 +0100205 const int initial_pitch_period = std::get<0>(params);
206 const int prev_pitch_period = std::get<1>(params);
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200207 const float prev_pitch_gain = std::get<2>(params);
Alessio Bazzicac25fa892019-01-14 13:54:57 +0100208 const int expected_pitch_period = std::get<3>(params);
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200209 const float expected_pitch_gain = std::get<4>(params);
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200210 TestData test_data;
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200211 {
212 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
213 // FloatingPointExceptionObserver fpe_observer;
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200214 const auto computed_output = CheckLowerPitchPeriodsAndComputePitchGain(
Alessio Bazzica2f1e6d42018-05-15 15:52:38 +0200215 test_data.GetPitchBufView(), initial_pitch_period,
216 {prev_pitch_period, prev_pitch_gain});
Alessio Bazzicaf2255012018-04-27 16:44:11 +0200217 EXPECT_EQ(expected_pitch_period, computed_output.period);
218 EXPECT_NEAR(expected_pitch_gain, computed_output.gain, 1e-6f);
219 }
220}
221
222INSTANTIATE_TEST_CASE_P(RnnVadTest,
223 CheckLowerPitchPeriodsAndComputePitchGainTest,
224 ::testing::Values(std::make_tuple(kTestPitchPeriods[0],
225 kTestPitchPeriods[0],
226 kTestPitchGains[0],
227 91,
228 -0.0188608f),
229 std::make_tuple(kTestPitchPeriods[0],
230 kTestPitchPeriods[0],
231 kTestPitchGains[1],
232 91,
233 -0.0188608f),
234 std::make_tuple(kTestPitchPeriods[0],
235 kTestPitchPeriods[1],
236 kTestPitchGains[0],
237 91,
238 -0.0188608f),
239 std::make_tuple(kTestPitchPeriods[0],
240 kTestPitchPeriods[1],
241 kTestPitchGains[1],
242 91,
243 -0.0188608f),
244 std::make_tuple(kTestPitchPeriods[1],
245 kTestPitchPeriods[0],
246 kTestPitchGains[0],
247 475,
248 -0.0904344f),
249 std::make_tuple(kTestPitchPeriods[1],
250 kTestPitchPeriods[0],
251 kTestPitchGains[1],
252 475,
253 -0.0904344f),
254 std::make_tuple(kTestPitchPeriods[1],
255 kTestPitchPeriods[1],
256 kTestPitchGains[0],
257 475,
258 -0.0904344f),
259 std::make_tuple(kTestPitchPeriods[1],
260 kTestPitchPeriods[1],
261 kTestPitchGains[1],
262 475,
263 -0.0904344f)));
264
265} // namespace test
266} // namespace rnn_vad
267} // namespace webrtc