RNN VAD: clean up unit tests
- add a test that checks that the computed VAD probability is within
tolerance *1
- speed up some tests by reducing the input length and skipping frames
- remove unused code in test_utils
- fix some comments
*1: RnnVadTest::RnnBitExactness is replaced by
RnnVadTest::RnnVadProbabilityWithinTolerance
Bug: webrtc:10480
Change-Id: I19332d06eacffbbe671bf7749ff4c92798bdc55c
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/133910
Reviewed-by: Alex Loiko <aleloi@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#27803}
diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index 902082e..7cf6e3d 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -64,12 +64,10 @@
unittest_resources = [
"../../../../resources/audio_processing/agc2/rnn_vad/band_energies.dat",
- "../../../../resources/audio_processing/agc2/rnn_vad/fft.dat",
"../../../../resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat",
"../../../../resources/audio_processing/agc2/rnn_vad/pitch_search_int.dat",
"../../../../resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat",
"../../../../resources/audio_processing/agc2/rnn_vad/samples.pcm",
- "../../../../resources/audio_processing/agc2/rnn_vad/sil_features.dat",
"../../../../resources/audio_processing/agc2/rnn_vad/vad_prob.dat",
]
diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc b/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc
index a5e456a..f66c0b2 100644
--- a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc
@@ -10,6 +10,7 @@
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
+#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include "test/gtest.h"
@@ -18,6 +19,8 @@
namespace rnn_vad {
namespace test {
+// Checks that the auto correlation function produces output within tolerance
+// given test input data.
TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) {
PitchTestData test_data;
std::array<float, kBufSize12kHz> pitch_buf_decimated;
@@ -35,7 +38,7 @@
computed_output, 3e-3f);
}
-// Check that the auto correlation function computes the right thing for a
+// Checks that the auto correlation function computes the right thing for a
// simple use case.
TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) {
// Create constant signal with no pitch.
@@ -49,11 +52,12 @@
auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated,
computed_output);
}
- // The expected output is constantly the length of the fixed 'x'
- // array in ComputePitchAutoCorrelation.
+ // The expected output is a vector filled with the same expected
+ // auto-correlation value. The latter equals the length of a 20 ms frame.
+ constexpr size_t kFrameSize20ms12kHz = kFrameSize20ms24kHz / 2;
std::array<float, kNumPitchBufAutoCorrCoeffs> expected_output;
std::fill(expected_output.begin(), expected_output.end(),
- kBufSize12kHz - kMaxPitch12kHz);
+ static_cast<float>(kFrameSize20ms12kHz));
ExpectNearAbsolute(expected_output, computed_output, 4e-5f);
}
diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc
index 47d8bf5..1e80ee0 100644
--- a/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc
@@ -10,7 +10,9 @@
#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h"
+#include <algorithm>
#include <array>
+#include <vector>
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
@@ -22,6 +24,7 @@
namespace rnn_vad {
namespace test {
+// Checks that the LP residual can be computed on an empty frame.
TEST(RnnVadTest, LpResidualOfEmptyFrame) {
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
@@ -37,44 +40,44 @@
ComputeLpResidual(lpc_coeffs, empty_frame, lp_residual);
}
-// TODO(bugs.webrtc.org/9076): Remove when the issue is fixed.
+// Checks that the computed LP residual is bit-exact given test input data.
TEST(RnnVadTest, LpResidualPipelineBitExactness) {
- // Pitch buffer 24 kHz data reader.
+ // Input and expected output readers.
auto pitch_buf_24kHz_reader = CreatePitchBuffer24kHzReader();
- const size_t num_frames = pitch_buf_24kHz_reader.second;
- std::array<float, kBufSize24kHz> pitch_buf_data;
- // Read ground-truth.
auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader();
- ASSERT_EQ(num_frames, lp_residual_reader.second);
- std::array<float, kBufSize24kHz> expected_lp_residual;
- rtc::ArrayView<float, kBufSize24kHz> expected_lp_residual_view(
- expected_lp_residual.data(), expected_lp_residual.size());
- // Init pipeline.
+
+ // Buffers.
+ std::vector<float> pitch_buf_data(kBufSize24kHz);
std::array<float, kNumLpcCoefficients> lpc_coeffs;
- rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs_view(
- lpc_coeffs.data(), kNumLpcCoefficients);
- std::array<float, kBufSize24kHz> computed_lp_residual;
- rtc::ArrayView<float, kBufSize24kHz> computed_lp_residual_view(
- computed_lp_residual.data(), computed_lp_residual.size());
+ std::vector<float> computed_lp_residual(kBufSize24kHz);
+ std::vector<float> expected_lp_residual(kBufSize24kHz);
+
+ // Test length.
+ const size_t num_frames = std::min(pitch_buf_24kHz_reader.second,
+ static_cast<size_t>(300)); // Max 3 s.
+ ASSERT_GE(lp_residual_reader.second, num_frames);
+
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
for (size_t i = 0; i < num_frames; ++i) {
- SCOPED_TRACE(i);
- // Read input and expected output.
- pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data);
- lp_residual_reader.first->ReadChunk(expected_lp_residual_view);
- // Skip pitch gain and period.
+ // Read input.
+ ASSERT_TRUE(pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data));
+ // Read expected output (ignore pitch gain and period).
+ ASSERT_TRUE(lp_residual_reader.first->ReadChunk(expected_lp_residual));
float unused;
- lp_residual_reader.first->ReadValue(&unused);
- lp_residual_reader.first->ReadValue(&unused);
- // Run pipeline.
- ComputeAndPostProcessLpcCoefficients(pitch_buf_data, lpc_coeffs_view);
- ComputeLpResidual(lpc_coeffs_view, pitch_buf_data,
- computed_lp_residual_view);
- // Compare.
- ExpectNearAbsolute(expected_lp_residual_view, computed_lp_residual_view,
- kFloatMin);
+ ASSERT_TRUE(lp_residual_reader.first->ReadValue(&unused));
+ ASSERT_TRUE(lp_residual_reader.first->ReadValue(&unused));
+
+ // Check every 200 ms.
+ if (i % 20 != 0) {
+ continue;
+ }
+
+ SCOPED_TRACE(i);
+ ComputeAndPostProcessLpcCoefficients(pitch_buf_data, lpc_coeffs);
+ ComputeLpResidual(lpc_coeffs, pitch_buf_data, computed_lp_residual);
+ ExpectNearAbsolute(expected_lp_residual, computed_lp_residual, kFloatMin);
}
}
}
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
index 7e29417..23ff49a 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
@@ -23,39 +23,45 @@
namespace test {
namespace {
-constexpr std::array<int, 2> kTestPitchPeriods = {
- 3 * kMinPitch48kHz / 2,
- (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2,
-};
-constexpr std::array<float, 2> kTestPitchGains = {0.35f, 0.75f};
+constexpr int kTestPitchPeriodsLow = 3 * kMinPitch48kHz / 2;
+constexpr int kTestPitchPeriodsHigh = (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2;
+
+constexpr float kTestPitchGainsLow = 0.35f;
+constexpr float kTestPitchGainsHigh = 0.75f;
} // namespace
class ComputePitchGainThresholdTest
: public ::testing::Test,
- public ::testing::WithParamInterface<
- std::tuple<size_t, size_t, size_t, float, size_t, float, float>> {};
+ public ::testing::WithParamInterface<std::tuple<
+ /*candidate_pitch_period=*/size_t,
+ /*pitch_period_ratio=*/size_t,
+ /*initial_pitch_period=*/size_t,
+ /*initial_pitch_gain=*/float,
+ /*prev_pitch_period=*/size_t,
+ /*prev_pitch_gain=*/float,
+ /*threshold=*/float>> {};
-TEST_P(ComputePitchGainThresholdTest, BitExactness) {
+// Checks that the computed pitch gain is within tolerance given test input
+// data.
+TEST_P(ComputePitchGainThresholdTest, WithinTolerance) {
const auto params = GetParam();
const size_t candidate_pitch_period = std::get<0>(params);
const size_t pitch_period_ratio = std::get<1>(params);
const size_t initial_pitch_period = std::get<2>(params);
const float initial_pitch_gain = std::get<3>(params);
const size_t prev_pitch_period = std::get<4>(params);
- const size_t prev_pitch_gain = std::get<5>(params);
+ const float prev_pitch_gain = std::get<5>(params);
const float threshold = std::get<6>(params);
-
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
-
EXPECT_NEAR(
threshold,
ComputePitchGainThreshold(candidate_pitch_period, pitch_period_ratio,
initial_pitch_period, initial_pitch_gain,
prev_pitch_period, prev_pitch_gain),
- 3e-6f);
+ 5e-7f);
}
}
@@ -77,7 +83,9 @@
std::make_tuple(31, 5, 153, 0.85069299f, 150, 0.79073799f, 0.72308898f),
std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f)));
-TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) {
+// Checks that the frame-wise sliding square energy function produces output
+// within tolerance given test input data.
+TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesWithinTolerance) {
PitchTestData test_data;
std::array<float, kNumPitchBufSquareEnergies> computed_output;
{
@@ -91,6 +99,7 @@
computed_output, 3e-2f);
}
+// Checks that the estimated pitch period is bit-exact given test input data.
TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
PitchTestData test_data;
std::array<float, kBufSize12kHz> pitch_buf_decimated;
@@ -104,14 +113,13 @@
FindBestPitchPeriods({auto_corr_view.data(), auto_corr_view.size()},
pitch_buf_decimated, kMaxPitch12kHz);
}
- const std::array<size_t, 2> expected_output = {140, 142};
- EXPECT_EQ(expected_output, pitch_candidates_inv_lags);
+ EXPECT_EQ(pitch_candidates_inv_lags[0], static_cast<size_t>(140));
+ EXPECT_EQ(pitch_candidates_inv_lags[1], static_cast<size_t>(142));
}
+// Checks that the refined pitch period is bit-exact given test input data.
TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
PitchTestData test_data;
- std::array<float, kBufSize12kHz> pitch_buf_decimated;
- Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
size_t pitch_inv_lag;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
@@ -125,10 +133,17 @@
class CheckLowerPitchPeriodsAndComputePitchGainTest
: public ::testing::Test,
- public ::testing::WithParamInterface<
- std::tuple<int, int, float, int, float>> {};
+ public ::testing::WithParamInterface<std::tuple<
+ /*initial_pitch_period=*/int,
+ /*prev_pitch_period=*/int,
+ /*prev_pitch_gain=*/float,
+ /*expected_pitch_period=*/int,
+ /*expected_pitch_gain=*/float>> {};
-TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) {
+// Checks that the computed pitch period is bit-exact and that the computed
+// pitch gain is within tolerance given test input data.
+TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest,
+ PeriodBitExactnessGainWithinTolerance) {
const auto params = GetParam();
const int initial_pitch_period = std::get<0>(params);
const int prev_pitch_period = std::get<1>(params);
@@ -147,48 +162,49 @@
}
}
-INSTANTIATE_TEST_SUITE_P(RnnVadTest,
- CheckLowerPitchPeriodsAndComputePitchGainTest,
- ::testing::Values(std::make_tuple(kTestPitchPeriods[0],
- kTestPitchPeriods[0],
- kTestPitchGains[0],
- 91,
- -0.0188608f),
- std::make_tuple(kTestPitchPeriods[0],
- kTestPitchPeriods[0],
- kTestPitchGains[1],
- 91,
- -0.0188608f),
- std::make_tuple(kTestPitchPeriods[0],
- kTestPitchPeriods[1],
- kTestPitchGains[0],
- 91,
- -0.0188608f),
- std::make_tuple(kTestPitchPeriods[0],
- kTestPitchPeriods[1],
- kTestPitchGains[1],
- 91,
- -0.0188608f),
- std::make_tuple(kTestPitchPeriods[1],
- kTestPitchPeriods[0],
- kTestPitchGains[0],
- 475,
- -0.0904344f),
- std::make_tuple(kTestPitchPeriods[1],
- kTestPitchPeriods[0],
- kTestPitchGains[1],
- 475,
- -0.0904344f),
- std::make_tuple(kTestPitchPeriods[1],
- kTestPitchPeriods[1],
- kTestPitchGains[0],
- 475,
- -0.0904344f),
- std::make_tuple(kTestPitchPeriods[1],
- kTestPitchPeriods[1],
- kTestPitchGains[1],
- 475,
- -0.0904344f)));
+INSTANTIATE_TEST_SUITE_P(
+ RnnVadTest,
+ CheckLowerPitchPeriodsAndComputePitchGainTest,
+ ::testing::Values(std::make_tuple(kTestPitchPeriodsLow,
+ kTestPitchPeriodsLow,
+ kTestPitchGainsLow,
+ 91,
+ -0.0188608f),
+ std::make_tuple(kTestPitchPeriodsLow,
+ kTestPitchPeriodsLow,
+ kTestPitchGainsHigh,
+ 91,
+ -0.0188608f),
+ std::make_tuple(kTestPitchPeriodsLow,
+ kTestPitchPeriodsHigh,
+ kTestPitchGainsLow,
+ 91,
+ -0.0188608f),
+ std::make_tuple(kTestPitchPeriodsLow,
+ kTestPitchPeriodsHigh,
+ kTestPitchGainsHigh,
+ 91,
+ -0.0188608f),
+ std::make_tuple(kTestPitchPeriodsHigh,
+ kTestPitchPeriodsLow,
+ kTestPitchGainsLow,
+ 475,
+ -0.0904344f),
+ std::make_tuple(kTestPitchPeriodsHigh,
+ kTestPitchPeriodsLow,
+ kTestPitchGainsHigh,
+ 475,
+ -0.0904344f),
+ std::make_tuple(kTestPitchPeriodsHigh,
+ kTestPitchPeriodsHigh,
+ kTestPitchGainsLow,
+ 475,
+ -0.0904344f),
+ std::make_tuple(kTestPitchPeriodsHigh,
+ kTestPitchPeriodsHigh,
+ kTestPitchGainsHigh,
+ 475,
+ -0.0904344f)));
} // namespace test
} // namespace rnn_vad
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
index eac332e..494dfe7 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
@@ -12,7 +12,8 @@
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
-#include <array>
+#include <algorithm>
+#include <vector>
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
@@ -23,11 +24,13 @@
namespace rnn_vad {
namespace test {
-// TODO(bugs.webrtc.org/9076): Remove when the issue is fixed.
-TEST(RnnVadTest, PitchSearchBitExactness) {
+// Checks that the computed pitch period is bit-exact and that the computed
+// pitch gain is within tolerance given test input data.
+TEST(RnnVadTest, PitchSearchWithinTolerance) {
auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader();
- const size_t num_frames = lp_residual_reader.second;
- std::array<float, 864> lp_residual;
+ const size_t num_frames = std::min(lp_residual_reader.second,
+ static_cast<size_t>(300)); // Max 3 s.
+ std::vector<float> lp_residual(kBufSize24kHz);
float expected_pitch_period, expected_pitch_gain;
PitchEstimator pitch_estimator;
{
@@ -38,7 +41,8 @@
lp_residual_reader.first->ReadChunk(lp_residual);
lp_residual_reader.first->ReadValue(&expected_pitch_period);
lp_residual_reader.first->ReadValue(&expected_pitch_gain);
- PitchInfo pitch_info = pitch_estimator.Estimate(lp_residual);
+ PitchInfo pitch_info =
+ pitch_estimator.Estimate({lp_residual.data(), kBufSize24kHz});
EXPECT_EQ(static_cast<int>(expected_pitch_period), pitch_info.period);
EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f);
}
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
index 289ce8d..933b555 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
@@ -8,9 +8,7 @@
* be found in the AUTHORS file in the root of the source tree.
*/
-#include <algorithm>
#include <array>
-#include <vector>
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
@@ -63,7 +61,8 @@
} // namespace
-// Bit-exactness check for fully connected layers.
+// Checks that the output of a fully connected layer is within tolerance given
+// test input data.
TEST(RnnVadTest, CheckFullyConnectedLayerOutput) {
const std::array<int8_t, 1> bias = {-50};
const std::array<int8_t, 24> weights = {
@@ -106,6 +105,8 @@
}
}
+// Checks that the output of a GRU layer is within tolerance given test input
+// data.
TEST(RnnVadTest, CheckGatedRecurrentLayer) {
const std::array<int8_t, 12> bias = {96, -99, -81, -114, 49, 119,
-118, 68, -76, 91, 121, 125};
@@ -139,41 +140,6 @@
}
}
-// TODO(bugs.webrtc.org/9076): Remove when the issue is fixed.
-// Bit-exactness test checking that precomputed frame-wise features lead to the
-// expected VAD probabilities.
-TEST(RnnVadTest, RnnBitExactness) {
- // Init.
- auto features_reader = CreateSilenceFlagsFeatureMatrixReader();
- auto vad_probs_reader = CreateVadProbsReader();
- ASSERT_EQ(features_reader.second, vad_probs_reader.second);
- const size_t num_frames = features_reader.second;
- // Frame-wise buffers.
- float expected_vad_probability;
- float is_silence;
- std::array<float, kFeatureVectorSize> features;
-
- // Compute VAD probability using the precomputed features.
- RnnBasedVad vad;
- for (size_t i = 0; i < num_frames; ++i) {
- SCOPED_TRACE(i);
- // Read frame data.
- RTC_CHECK(vad_probs_reader.first->ReadValue(&expected_vad_probability));
- // The features file also includes a silence flag for each frame.
- RTC_CHECK(features_reader.first->ReadValue(&is_silence));
- RTC_CHECK(features_reader.first->ReadChunk(features));
- // Compute and check VAD probability.
- float vad_probability = vad.ComputeVadProbability(features, is_silence);
- ASSERT_TRUE(is_silence == 0.f || is_silence == 1.f);
- if (is_silence == 1.f) {
- ASSERT_EQ(0.f, expected_vad_probability);
- EXPECT_EQ(0.f, vad_probability);
- } else {
- EXPECT_NEAR(expected_vad_probability, vad_probability, 3e-6f);
- }
- }
-}
-
} // namespace test
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
index 4afe24b..8583d4b 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
@@ -9,6 +9,7 @@
*/
#include <array>
+#include <string>
#include <vector>
#include "common_audio/resampler/push_sinc_resampler.h"
@@ -43,8 +44,68 @@
RTC_LOG(LS_INFO) << "speed: " << speed << "x";
}
+// When the RNN VAD model is updated and the expected output changes, set the
+// constant below to true in order to write new expected output binary files.
+constexpr bool kWriteComputedOutputToFile = false;
+
} // namespace
+// Avoids that one forgets to set |kWriteComputedOutputToFile| back to false
+// when the expected output files are re-exported.
+TEST(RnnVadTest, CheckWriteComputedOutputIsFalse) {
+ ASSERT_FALSE(kWriteComputedOutputToFile)
+ << "Cannot land if kWriteComputedOutput is true.";
+}
+
+// Checks that the computed VAD probability for a test input sequence sampled at
+// 48 kHz is within tolerance.
+TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) {
+ // Init resampler, feature extractor and RNN.
+ PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
+ FeaturesExtractor features_extractor;
+ RnnBasedVad rnn_vad;
+
+ // Init input samples and expected output readers.
+ auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
+ auto expected_vad_prob_reader = CreateVadProbsReader();
+
+ // Input length.
+ const size_t num_frames = samples_reader.second;
+ ASSERT_GE(expected_vad_prob_reader.second, num_frames);
+
+ // Init buffers.
+ std::vector<float> samples_48k(kFrameSize10ms48kHz);
+ std::vector<float> samples_24k(kFrameSize10ms24kHz);
+ std::vector<float> feature_vector(kFeatureVectorSize);
+ std::vector<float> computed_vad_prob(num_frames);
+ std::vector<float> expected_vad_prob(num_frames);
+
+ // Read expected output.
+ ASSERT_TRUE(expected_vad_prob_reader.first->ReadChunk(expected_vad_prob));
+
+ // Compute VAD probabilities on the downsampled input.
+ float cumulative_error = 0.f;
+ for (size_t i = 0; i < num_frames; ++i) {
+ samples_reader.first->ReadChunk(samples_48k);
+ decimator.Resample(samples_48k.data(), samples_48k.size(),
+ samples_24k.data(), samples_24k.size());
+ bool is_silence = features_extractor.CheckSilenceComputeFeatures(
+ {samples_24k.data(), kFrameSize10ms24kHz},
+ {feature_vector.data(), kFeatureVectorSize});
+ computed_vad_prob[i] = rnn_vad.ComputeVadProbability(
+ {feature_vector.data(), kFeatureVectorSize}, is_silence);
+ EXPECT_NEAR(computed_vad_prob[i], expected_vad_prob[i], 1e-3f);
+ cumulative_error += std::abs(computed_vad_prob[i] - expected_vad_prob[i]);
+ }
+ // Check average error.
+ EXPECT_LT(cumulative_error / num_frames, 1e-4f);
+
+ if (kWriteComputedOutputToFile) {
+ BinaryFileWriter<float> vad_prob_writer("new_vad_prob.dat");
+ vad_prob_writer.WriteChunk(computed_vad_prob);
+ }
+}
+
// Performance test for the RNN VAD (pre-fetching and downsampling are
// excluded). Keep disabled and only enable locally to measure performance as
// follows:
diff --git a/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc
index d112eb7..ec81295 100644
--- a/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc
@@ -85,14 +85,18 @@
}
}
+// Checks that the computed band-wise auto-correlation is non-negative for a
+// simple input vector of FFT coefficients.
TEST(RnnVadTest, SpectralCorrelatorValidOutput) {
- SpectralCorrelator e;
+ // Input: vector of (1, 1j) values.
Pffft fft(kFrameSize20ms24kHz, Pffft::FftType::kReal);
auto in = fft.CreateBuffer();
std::array<float, kOpusBands24kHz> out;
auto in_view = in->GetView();
std::fill(in_view.begin(), in_view.end(), 1.f);
in_view[1] = 0.f; // Nyquist frequency.
+ // Compute and check output.
+ SpectralCorrelator e;
e.ComputeAutoCorrelation(in_view, out);
for (size_t i = 0; i < kOpusBands24kHz; ++i) {
SCOPED_TRACE(i);
@@ -100,6 +104,8 @@
}
}
+// Checks that the computed smoothed log magnitude spectrum is within tolerance
+// given hard-coded test input data.
TEST(RnnVadTest, ComputeSmoothedLogMagnitudeSpectrumWithinTolerance) {
constexpr std::array<float, kNumBands> input = {
{86.060539245605f, 275.668334960938f, 43.406528472900f, 6.541896820068f,
@@ -124,7 +130,9 @@
}
}
-TEST(RnnVadTest, ComputeDctBitExactness) {
+// Checks that the computed DCT is within tolerance given hard-coded test input
+// data.
+TEST(RnnVadTest, ComputeDctWithinTolerance) {
constexpr std::array<float, kNumBands> input = {
{0.232155621052f, 0.678957760334f, 0.220818966627f, -0.077363930643f,
-0.559227049351f, 0.432545185089f, 0.353900641203f, 0.398993015289f,
diff --git a/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc b/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc
index 39b9f93..bc00e2c 100644
--- a/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc
@@ -67,6 +67,8 @@
} // namespace
+// Checks that silence is detected when the input signal is 0 and that the
+// feature vector is written only if the input signal is not tagged as silence.
TEST(RnnVadTest, SpectralFeaturesWithAndWithoutSilence) {
// Initialize.
SpectralFeaturesExtractor sfe;
@@ -108,9 +110,10 @@
[](float x) { return x == kInitialFeatureVal; }));
}
-// When the input signal does not change, the cepstral coefficients average does
-// not change and the derivatives are zero. Similarly, the cepstral variability
-// score does not change either.
+// Feeds a constant input signal and checks that:
+// - the cepstral coefficients average does not change;
+// - the derivatives are zero;
+// - the cepstral variability score does not change.
TEST(RnnVadTest, CepstralFeaturesConstantAverageZeroDerivative) {
// Initialize.
SpectralFeaturesExtractor sfe;
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
index 14b84a4..8236d5f 100644
--- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
@@ -46,13 +46,6 @@
}
}
-std::unique_ptr<BinaryFileReader<float>> CreatePitchSearchTestDataReader() {
- constexpr size_t cols = 1396;
- return absl::make_unique<BinaryFileReader<float>>(
- ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
- cols);
-}
-
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t>
CreatePcmSamplesReader(const size_t frame_length) {
auto ptr = absl::make_unique<BinaryFileReader<int16_t, float>>(
@@ -78,25 +71,6 @@
rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs)};
}
-ReaderPairType CreateFftCoeffsReader() {
- constexpr size_t num_fft_points = 481;
- constexpr size_t row_size = 2 * num_fft_points; // Real and imaginary values.
- auto ptr = absl::make_unique<BinaryFileReader<float>>(
- test::ResourcePath("audio_processing/agc2/rnn_vad/fft", "dat"),
- num_fft_points);
- return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), row_size)};
-}
-
-ReaderPairType CreateSilenceFlagsFeatureMatrixReader() {
- constexpr size_t feature_vector_size = 42;
- auto ptr = absl::make_unique<BinaryFileReader<float>>(
- test::ResourcePath("audio_processing/agc2/rnn_vad/sil_features", "dat"),
- feature_vector_size);
- // Features and silence flag.
- return {std::move(ptr),
- rtc::CheckedDivExact(ptr->data_length(), feature_vector_size + 1)};
-}
-
ReaderPairType CreateVadProbsReader() {
auto ptr = absl::make_unique<BinaryFileReader<float>>(
test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob", "dat"));
@@ -104,23 +78,26 @@
}
PitchTestData::PitchTestData() {
- auto test_data_reader = CreatePitchSearchTestDataReader();
- test_data_reader->ReadChunk(test_data_);
+ BinaryFileReader<float> test_data_reader(
+ ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
+ static_cast<size_t>(1396));
+ test_data_reader.ReadChunk(test_data_);
}
PitchTestData::~PitchTestData() = default;
-rtc::ArrayView<const float, kBufSize24kHz> PitchTestData::GetPitchBufView() {
+rtc::ArrayView<const float, kBufSize24kHz> PitchTestData::GetPitchBufView()
+ const {
return {test_data_.data(), kBufSize24kHz};
}
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
-PitchTestData::GetPitchBufSquareEnergiesView() {
+PitchTestData::GetPitchBufSquareEnergiesView() const {
return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies};
}
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
-PitchTestData::GetPitchBufAutoCorrCoeffsView() {
+PitchTestData::GetPitchBufAutoCorrCoeffsView() const {
return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies,
kNumPitchBufAutoCorrCoeffs};
}
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h
index c11af7f..fbb270f 100644
--- a/modules/audio_processing/agc2/rnn_vad/test_utils.h
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h
@@ -17,6 +17,7 @@
#include <limits>
#include <memory>
#include <string>
+#include <type_traits>
#include <utility>
#include <vector>
@@ -46,11 +47,10 @@
template <typename T, typename D = T>
class BinaryFileReader {
public:
- explicit BinaryFileReader(const std::string& file_path, size_t chunk_size = 1)
+ explicit BinaryFileReader(const std::string& file_path, size_t chunk_size = 0)
: is_(file_path, std::ios::binary | std::ios::ate),
data_length_(is_.tellg() / sizeof(T)),
chunk_size_(chunk_size) {
- RTC_CHECK_LT(0, chunk_size_);
RTC_CHECK(is_);
SeekBeginning();
buf_.resize(chunk_size_);
@@ -69,9 +69,11 @@
}
return is_.gcount() == sizeof(T);
}
+ // If |chunk_size| was specified in the ctor, it will check that the size of
+ // |dst| equals |chunk_size|.
bool ReadChunk(rtc::ArrayView<D> dst) {
- RTC_DCHECK_EQ(chunk_size_, dst.size());
- const std::streamsize bytes_to_read = chunk_size_ * sizeof(T);
+ RTC_DCHECK((chunk_size_ == 0) || (chunk_size_ == dst.size()));
+ const std::streamsize bytes_to_read = dst.size() * sizeof(T);
if (std::is_same<T, D>::value) {
is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
} else {
@@ -91,9 +93,26 @@
std::vector<T> buf_;
};
+// Writer for binary files.
+template <typename T>
+class BinaryFileWriter {
+ public:
+ explicit BinaryFileWriter(const std::string& file_path)
+ : os_(file_path, std::ios::binary) {}
+ BinaryFileWriter(const BinaryFileWriter&) = delete;
+ BinaryFileWriter& operator=(const BinaryFileWriter&) = delete;
+ ~BinaryFileWriter() = default;
+ static_assert(std::is_arithmetic<T>::value, "");
+ void WriteChunk(rtc::ArrayView<const T> value) {
+ const std::streamsize bytes_to_write = value.size() * sizeof(T);
+ os_.write(reinterpret_cast<const char*>(value.data()), bytes_to_write);
+ }
+
+ private:
+ std::ofstream os_;
+};
+
// Factories for resource file readers.
-// Creates a reader for the pitch search test data.
-std::unique_ptr<BinaryFileReader<float>> CreatePitchSearchTestDataReader();
// The functions below return a pair where the first item is a reader unique
// pointer and the second the number of chunks that can be read from the file.
// Creates a reader for the PCM samples that casts from S16 to float and reads
@@ -107,12 +126,6 @@
// and gain values.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreateLpResidualAndPitchPeriodGainReader();
-// Creates a reader for the FFT coefficients.
-std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
-CreateFftCoeffsReader();
-// Creates a reader for the silence flags and the feature matrix.
-std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
-CreateSilenceFlagsFeatureMatrixReader();
// Creates a reader for the VAD probabilities.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreateVadProbsReader();
@@ -128,11 +141,11 @@
public:
PitchTestData();
~PitchTestData();
- rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView();
+ rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView() const;
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
- GetPitchBufSquareEnergiesView();
+ GetPitchBufSquareEnergiesView() const;
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
- GetPitchBufAutoCorrCoeffsView();
+ GetPitchBufAutoCorrCoeffsView() const;
private:
std::array<float, kPitchTestDataSize> test_data_;
diff --git a/resources/audio_processing/agc2/rnn_vad/fft.dat.sha1 b/resources/audio_processing/agc2/rnn_vad/fft.dat.sha1
deleted file mode 100644
index ebd5124..0000000
--- a/resources/audio_processing/agc2/rnn_vad/fft.dat.sha1
+++ /dev/null
@@ -1 +0,0 @@
-e62364d35abd123663bfc800fa233071d6d7fffd
\ No newline at end of file
diff --git a/resources/audio_processing/agc2/rnn_vad/sil_features.dat.sha1 b/resources/audio_processing/agc2/rnn_vad/sil_features.dat.sha1
deleted file mode 100644
index bc591e9..0000000
--- a/resources/audio_processing/agc2/rnn_vad/sil_features.dat.sha1
+++ /dev/null
@@ -1 +0,0 @@
-e0a92782c2903be9da10385d924d34e8bf212d5e
\ No newline at end of file
diff --git a/resources/audio_processing/agc2/rnn_vad/vad_prob.dat.sha1 b/resources/audio_processing/agc2/rnn_vad/vad_prob.dat.sha1
index 1aa3bd0..8ee78b1 100644
--- a/resources/audio_processing/agc2/rnn_vad/vad_prob.dat.sha1
+++ b/resources/audio_processing/agc2/rnn_vad/vad_prob.dat.sha1
@@ -1 +1 @@
-05735ede0b457318e307d12f5acfd11bbbbd0afd
\ No newline at end of file
+68640327266262c3fe047ec7f07a46a355ff90b9
\ No newline at end of file