RNN VAD: clean-up unit tests

- add test that checks that the computed VAD probability is within
  tolerance *1
- speed-up some tests by reducing the input length and skipping frames
- remove unused code in test_utils
- fix some comments

*1: RnnVadTest::RnnBitExactness is replaced by
    RnnVadTest::RnnVadProbabilityWithinTolerance

Bug: webrtc:10480
Change-Id: I19332d06eacffbbe671bf7749ff4c92798bdc55c
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/133910
Reviewed-by: Alex Loiko <aleloi@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#27803}
diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index 902082e..7cf6e3d 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -64,12 +64,10 @@
 
   unittest_resources = [
     "../../../../resources/audio_processing/agc2/rnn_vad/band_energies.dat",
-    "../../../../resources/audio_processing/agc2/rnn_vad/fft.dat",
     "../../../../resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat",
     "../../../../resources/audio_processing/agc2/rnn_vad/pitch_search_int.dat",
     "../../../../resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat",
     "../../../../resources/audio_processing/agc2/rnn_vad/samples.pcm",
-    "../../../../resources/audio_processing/agc2/rnn_vad/sil_features.dat",
     "../../../../resources/audio_processing/agc2/rnn_vad/vad_prob.dat",
   ]
 
diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc b/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc
index a5e456a..f66c0b2 100644
--- a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc
@@ -10,6 +10,7 @@
 
 #include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
 
+#include "modules/audio_processing/agc2/rnn_vad/common.h"
 #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
 #include "test/gtest.h"
@@ -18,6 +19,8 @@
 namespace rnn_vad {
 namespace test {
 
+// Checks that the auto correlation function produces output within tolerance
+// given test input data.
 TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) {
   PitchTestData test_data;
   std::array<float, kBufSize12kHz> pitch_buf_decimated;
@@ -35,7 +38,7 @@
                      computed_output, 3e-3f);
 }
 
-// Check that the auto correlation function computes the right thing for a
+// Checks that the auto correlation function computes the right thing for a
 // simple use case.
 TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) {
   // Create constant signal with no pitch.
@@ -49,11 +52,12 @@
     auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated,
                                               computed_output);
   }
-  // The expected output is constantly the length of the fixed 'x'
-  // array in ComputePitchAutoCorrelation.
+  // The expected output is a vector filled with the same expected
+  // auto-correlation value. The latter equals the length of a 20 ms frame.
+  constexpr size_t kFrameSize20ms12kHz = kFrameSize20ms24kHz / 2;
   std::array<float, kNumPitchBufAutoCorrCoeffs> expected_output;
   std::fill(expected_output.begin(), expected_output.end(),
-            kBufSize12kHz - kMaxPitch12kHz);
+            static_cast<float>(kFrameSize20ms12kHz));
   ExpectNearAbsolute(expected_output, computed_output, 4e-5f);
 }
 
diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc
index 47d8bf5..1e80ee0 100644
--- a/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc
@@ -10,7 +10,9 @@
 
 #include "modules/audio_processing/agc2/rnn_vad/lp_residual.h"
 
+#include <algorithm>
 #include <array>
+#include <vector>
 
 #include "modules/audio_processing/agc2/rnn_vad/common.h"
 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
@@ -22,6 +24,7 @@
 namespace rnn_vad {
 namespace test {
 
+// Checks that the LP residual can be computed on an empty frame.
 TEST(RnnVadTest, LpResidualOfEmptyFrame) {
   // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
   // FloatingPointExceptionObserver fpe_observer;
@@ -37,44 +40,44 @@
   ComputeLpResidual(lpc_coeffs, empty_frame, lp_residual);
 }
 
-// TODO(bugs.webrtc.org/9076): Remove when the issue is fixed.
+// Checks that the computed LP residual is bit-exact given test input data.
 TEST(RnnVadTest, LpResidualPipelineBitExactness) {
-  // Pitch buffer 24 kHz data reader.
+  // Input and expected output readers.
   auto pitch_buf_24kHz_reader = CreatePitchBuffer24kHzReader();
-  const size_t num_frames = pitch_buf_24kHz_reader.second;
-  std::array<float, kBufSize24kHz> pitch_buf_data;
-  // Read ground-truth.
   auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader();
-  ASSERT_EQ(num_frames, lp_residual_reader.second);
-  std::array<float, kBufSize24kHz> expected_lp_residual;
-  rtc::ArrayView<float, kBufSize24kHz> expected_lp_residual_view(
-      expected_lp_residual.data(), expected_lp_residual.size());
-  // Init pipeline.
+
+  // Buffers.
+  std::vector<float> pitch_buf_data(kBufSize24kHz);
   std::array<float, kNumLpcCoefficients> lpc_coeffs;
-  rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs_view(
-      lpc_coeffs.data(), kNumLpcCoefficients);
-  std::array<float, kBufSize24kHz> computed_lp_residual;
-  rtc::ArrayView<float, kBufSize24kHz> computed_lp_residual_view(
-      computed_lp_residual.data(), computed_lp_residual.size());
+  std::vector<float> computed_lp_residual(kBufSize24kHz);
+  std::vector<float> expected_lp_residual(kBufSize24kHz);
+
+  // Test length.
+  const size_t num_frames = std::min(pitch_buf_24kHz_reader.second,
+                                     static_cast<size_t>(300));  // Max 3 s.
+  ASSERT_GE(lp_residual_reader.second, num_frames);
+
   {
     // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
     // FloatingPointExceptionObserver fpe_observer;
     for (size_t i = 0; i < num_frames; ++i) {
-      SCOPED_TRACE(i);
-      // Read input and expected output.
-      pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data);
-      lp_residual_reader.first->ReadChunk(expected_lp_residual_view);
-      // Skip pitch gain and period.
+      // Read input.
+      ASSERT_TRUE(pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data));
+      // Read expected output (ignore pitch gain and period).
+      ASSERT_TRUE(lp_residual_reader.first->ReadChunk(expected_lp_residual));
       float unused;
-      lp_residual_reader.first->ReadValue(&unused);
-      lp_residual_reader.first->ReadValue(&unused);
-      // Run pipeline.
-      ComputeAndPostProcessLpcCoefficients(pitch_buf_data, lpc_coeffs_view);
-      ComputeLpResidual(lpc_coeffs_view, pitch_buf_data,
-                        computed_lp_residual_view);
-      // Compare.
-      ExpectNearAbsolute(expected_lp_residual_view, computed_lp_residual_view,
-                         kFloatMin);
+      ASSERT_TRUE(lp_residual_reader.first->ReadValue(&unused));
+      ASSERT_TRUE(lp_residual_reader.first->ReadValue(&unused));
+
+      // Check every 200 ms.
+      if (i % 20 != 0) {
+        continue;
+      }
+
+      SCOPED_TRACE(i);
+      ComputeAndPostProcessLpcCoefficients(pitch_buf_data, lpc_coeffs);
+      ComputeLpResidual(lpc_coeffs, pitch_buf_data, computed_lp_residual);
+      ExpectNearAbsolute(expected_lp_residual, computed_lp_residual, kFloatMin);
     }
   }
 }
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
index 7e29417..23ff49a 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
@@ -23,39 +23,45 @@
 namespace test {
 namespace {
 
-constexpr std::array<int, 2> kTestPitchPeriods = {
-    3 * kMinPitch48kHz / 2,
-    (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2,
-};
-constexpr std::array<float, 2> kTestPitchGains = {0.35f, 0.75f};
+constexpr int kTestPitchPeriodsLow = 3 * kMinPitch48kHz / 2;
+constexpr int kTestPitchPeriodsHigh = (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2;
+
+constexpr float kTestPitchGainsLow = 0.35f;
+constexpr float kTestPitchGainsHigh = 0.75f;
 
 }  // namespace
 
 class ComputePitchGainThresholdTest
     : public ::testing::Test,
-      public ::testing::WithParamInterface<
-          std::tuple<size_t, size_t, size_t, float, size_t, float, float>> {};
+      public ::testing::WithParamInterface<std::tuple<
+          /*candidate_pitch_period=*/size_t,
+          /*pitch_period_ratio=*/size_t,
+          /*initial_pitch_period=*/size_t,
+          /*initial_pitch_gain=*/float,
+          /*prev_pitch_period=*/size_t,
+          /*prev_pitch_gain=*/float,
+          /*threshold=*/float>> {};
 
-TEST_P(ComputePitchGainThresholdTest, BitExactness) {
+// Checks that the computed pitch gain is within tolerance given test input
+// data.
+TEST_P(ComputePitchGainThresholdTest, WithinTolerance) {
   const auto params = GetParam();
   const size_t candidate_pitch_period = std::get<0>(params);
   const size_t pitch_period_ratio = std::get<1>(params);
   const size_t initial_pitch_period = std::get<2>(params);
   const float initial_pitch_gain = std::get<3>(params);
   const size_t prev_pitch_period = std::get<4>(params);
-  const size_t prev_pitch_gain = std::get<5>(params);
+  const float prev_pitch_gain = std::get<5>(params);
   const float threshold = std::get<6>(params);
-
   {
     // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
     // FloatingPointExceptionObserver fpe_observer;
-
     EXPECT_NEAR(
         threshold,
         ComputePitchGainThreshold(candidate_pitch_period, pitch_period_ratio,
                                   initial_pitch_period, initial_pitch_gain,
                                   prev_pitch_period, prev_pitch_gain),
-        3e-6f);
+        5e-7f);
   }
 }
 
@@ -77,7 +83,9 @@
         std::make_tuple(31, 5, 153, 0.85069299f, 150, 0.79073799f, 0.72308898f),
         std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f)));
 
-TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) {
+// Checks that the frame-wise sliding square energy function produces output
+// within tolerance given test input data.
+TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesWithinTolerance) {
   PitchTestData test_data;
   std::array<float, kNumPitchBufSquareEnergies> computed_output;
   {
@@ -91,6 +99,7 @@
                      computed_output, 3e-2f);
 }
 
+// Checks that the estimated pitch period is bit-exact given test input data.
 TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
   PitchTestData test_data;
   std::array<float, kBufSize12kHz> pitch_buf_decimated;
@@ -104,14 +113,13 @@
         FindBestPitchPeriods({auto_corr_view.data(), auto_corr_view.size()},
                              pitch_buf_decimated, kMaxPitch12kHz);
   }
-  const std::array<size_t, 2> expected_output = {140, 142};
-  EXPECT_EQ(expected_output, pitch_candidates_inv_lags);
+  EXPECT_EQ(pitch_candidates_inv_lags[0], static_cast<size_t>(140));
+  EXPECT_EQ(pitch_candidates_inv_lags[1], static_cast<size_t>(142));
 }
 
+// Checks that the refined pitch period is bit-exact given test input data.
 TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
   PitchTestData test_data;
-  std::array<float, kBufSize12kHz> pitch_buf_decimated;
-  Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
   size_t pitch_inv_lag;
   {
     // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
@@ -125,10 +133,17 @@
 
 class CheckLowerPitchPeriodsAndComputePitchGainTest
     : public ::testing::Test,
-      public ::testing::WithParamInterface<
-          std::tuple<int, int, float, int, float>> {};
+      public ::testing::WithParamInterface<std::tuple<
+          /*initial_pitch_period=*/int,
+          /*prev_pitch_period=*/int,
+          /*prev_pitch_gain=*/float,
+          /*expected_pitch_period=*/int,
+          /*expected_pitch_gain=*/float>> {};
 
-TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) {
+// Checks that the computed pitch period is bit-exact and that the computed
+// pitch gain is within tolerance given test input data.
+TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest,
+       PeriodBitExactnessGainWithinTolerance) {
   const auto params = GetParam();
   const int initial_pitch_period = std::get<0>(params);
   const int prev_pitch_period = std::get<1>(params);
@@ -147,48 +162,49 @@
   }
 }
 
-INSTANTIATE_TEST_SUITE_P(RnnVadTest,
-                         CheckLowerPitchPeriodsAndComputePitchGainTest,
-                         ::testing::Values(std::make_tuple(kTestPitchPeriods[0],
-                                                           kTestPitchPeriods[0],
-                                                           kTestPitchGains[0],
-                                                           91,
-                                                           -0.0188608f),
-                                           std::make_tuple(kTestPitchPeriods[0],
-                                                           kTestPitchPeriods[0],
-                                                           kTestPitchGains[1],
-                                                           91,
-                                                           -0.0188608f),
-                                           std::make_tuple(kTestPitchPeriods[0],
-                                                           kTestPitchPeriods[1],
-                                                           kTestPitchGains[0],
-                                                           91,
-                                                           -0.0188608f),
-                                           std::make_tuple(kTestPitchPeriods[0],
-                                                           kTestPitchPeriods[1],
-                                                           kTestPitchGains[1],
-                                                           91,
-                                                           -0.0188608f),
-                                           std::make_tuple(kTestPitchPeriods[1],
-                                                           kTestPitchPeriods[0],
-                                                           kTestPitchGains[0],
-                                                           475,
-                                                           -0.0904344f),
-                                           std::make_tuple(kTestPitchPeriods[1],
-                                                           kTestPitchPeriods[0],
-                                                           kTestPitchGains[1],
-                                                           475,
-                                                           -0.0904344f),
-                                           std::make_tuple(kTestPitchPeriods[1],
-                                                           kTestPitchPeriods[1],
-                                                           kTestPitchGains[0],
-                                                           475,
-                                                           -0.0904344f),
-                                           std::make_tuple(kTestPitchPeriods[1],
-                                                           kTestPitchPeriods[1],
-                                                           kTestPitchGains[1],
-                                                           475,
-                                                           -0.0904344f)));
+INSTANTIATE_TEST_SUITE_P(
+    RnnVadTest,
+    CheckLowerPitchPeriodsAndComputePitchGainTest,
+    ::testing::Values(std::make_tuple(kTestPitchPeriodsLow,
+                                      kTestPitchPeriodsLow,
+                                      kTestPitchGainsLow,
+                                      91,
+                                      -0.0188608f),
+                      std::make_tuple(kTestPitchPeriodsLow,
+                                      kTestPitchPeriodsLow,
+                                      kTestPitchGainsHigh,
+                                      91,
+                                      -0.0188608f),
+                      std::make_tuple(kTestPitchPeriodsLow,
+                                      kTestPitchPeriodsHigh,
+                                      kTestPitchGainsLow,
+                                      91,
+                                      -0.0188608f),
+                      std::make_tuple(kTestPitchPeriodsLow,
+                                      kTestPitchPeriodsHigh,
+                                      kTestPitchGainsHigh,
+                                      91,
+                                      -0.0188608f),
+                      std::make_tuple(kTestPitchPeriodsHigh,
+                                      kTestPitchPeriodsLow,
+                                      kTestPitchGainsLow,
+                                      475,
+                                      -0.0904344f),
+                      std::make_tuple(kTestPitchPeriodsHigh,
+                                      kTestPitchPeriodsLow,
+                                      kTestPitchGainsHigh,
+                                      475,
+                                      -0.0904344f),
+                      std::make_tuple(kTestPitchPeriodsHigh,
+                                      kTestPitchPeriodsHigh,
+                                      kTestPitchGainsLow,
+                                      475,
+                                      -0.0904344f),
+                      std::make_tuple(kTestPitchPeriodsHigh,
+                                      kTestPitchPeriodsHigh,
+                                      kTestPitchGainsHigh,
+                                      475,
+                                      -0.0904344f)));
 
 }  // namespace test
 }  // namespace rnn_vad
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
index eac332e..494dfe7 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
@@ -12,7 +12,8 @@
 #include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
 #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
 
-#include <array>
+#include <algorithm>
+#include <vector>
 
 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
 // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
@@ -23,11 +24,13 @@
 namespace rnn_vad {
 namespace test {
 
-// TODO(bugs.webrtc.org/9076): Remove when the issue is fixed.
-TEST(RnnVadTest, PitchSearchBitExactness) {
+// Checks that the computed pitch period is bit-exact and that the computed
+// pitch gain is within tolerance given test input data.
+TEST(RnnVadTest, PitchSearchWithinTolerance) {
   auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader();
-  const size_t num_frames = lp_residual_reader.second;
-  std::array<float, 864> lp_residual;
+  const size_t num_frames = std::min(lp_residual_reader.second,
+                                     static_cast<size_t>(300));  // Max 3 s.
+  std::vector<float> lp_residual(kBufSize24kHz);
   float expected_pitch_period, expected_pitch_gain;
   PitchEstimator pitch_estimator;
   {
@@ -38,7 +41,8 @@
       lp_residual_reader.first->ReadChunk(lp_residual);
       lp_residual_reader.first->ReadValue(&expected_pitch_period);
       lp_residual_reader.first->ReadValue(&expected_pitch_gain);
-      PitchInfo pitch_info = pitch_estimator.Estimate(lp_residual);
+      PitchInfo pitch_info =
+          pitch_estimator.Estimate({lp_residual.data(), kBufSize24kHz});
       EXPECT_EQ(static_cast<int>(expected_pitch_period), pitch_info.period);
       EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f);
     }
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
index 289ce8d..933b555 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
@@ -8,9 +8,7 @@
  *  be found in the AUTHORS file in the root of the source tree.
  */
 
-#include <algorithm>
 #include <array>
-#include <vector>
 
 #include "modules/audio_processing/agc2/rnn_vad/rnn.h"
 #include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
@@ -63,7 +61,8 @@
 
 }  // namespace
 
-// Bit-exactness check for fully connected layers.
+// Checks that the output of a fully connected layer is within tolerance given
+// test input data.
 TEST(RnnVadTest, CheckFullyConnectedLayerOutput) {
   const std::array<int8_t, 1> bias = {-50};
   const std::array<int8_t, 24> weights = {
@@ -106,6 +105,8 @@
   }
 }
 
+// Checks that the output of a GRU layer is within tolerance given test input
+// data.
 TEST(RnnVadTest, CheckGatedRecurrentLayer) {
   const std::array<int8_t, 12> bias = {96,   -99, -81, -114, 49,  119,
                                        -118, 68,  -76, 91,   121, 125};
@@ -139,41 +140,6 @@
   }
 }
 
-// TODO(bugs.webrtc.org/9076): Remove when the issue is fixed.
-// Bit-exactness test checking that precomputed frame-wise features lead to the
-// expected VAD probabilities.
-TEST(RnnVadTest, RnnBitExactness) {
-  // Init.
-  auto features_reader = CreateSilenceFlagsFeatureMatrixReader();
-  auto vad_probs_reader = CreateVadProbsReader();
-  ASSERT_EQ(features_reader.second, vad_probs_reader.second);
-  const size_t num_frames = features_reader.second;
-  // Frame-wise buffers.
-  float expected_vad_probability;
-  float is_silence;
-  std::array<float, kFeatureVectorSize> features;
-
-  // Compute VAD probability using the precomputed features.
-  RnnBasedVad vad;
-  for (size_t i = 0; i < num_frames; ++i) {
-    SCOPED_TRACE(i);
-    // Read frame data.
-    RTC_CHECK(vad_probs_reader.first->ReadValue(&expected_vad_probability));
-    // The features file also includes a silence flag for each frame.
-    RTC_CHECK(features_reader.first->ReadValue(&is_silence));
-    RTC_CHECK(features_reader.first->ReadChunk(features));
-    // Compute and check VAD probability.
-    float vad_probability = vad.ComputeVadProbability(features, is_silence);
-    ASSERT_TRUE(is_silence == 0.f || is_silence == 1.f);
-    if (is_silence == 1.f) {
-      ASSERT_EQ(0.f, expected_vad_probability);
-      EXPECT_EQ(0.f, vad_probability);
-    } else {
-      EXPECT_NEAR(expected_vad_probability, vad_probability, 3e-6f);
-    }
-  }
-}
-
 }  // namespace test
 }  // namespace rnn_vad
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
index 4afe24b..8583d4b 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc
@@ -9,6 +9,7 @@
  */
 
 #include <array>
+#include <string>
 #include <vector>
 
 #include "common_audio/resampler/push_sinc_resampler.h"
@@ -43,8 +44,68 @@
   RTC_LOG(LS_INFO) << "speed: " << speed << "x";
 }
 
+// When the RNN VAD model is updated and the expected output changes, set the
+// constant below to true in order to write new expected output binary files.
+constexpr bool kWriteComputedOutputToFile = false;
+
 }  // namespace
 
+// Avoids that one forgets to set |kWriteComputedOutputToFile| back to false
+// when the expected output files are re-exported.
+TEST(RnnVadTest, CheckWriteComputedOutputIsFalse) {
+  ASSERT_FALSE(kWriteComputedOutputToFile)
+      << "Cannot land if kWriteComputedOutput is true.";
+}
+
+// Checks that the computed VAD probability for a test input sequence sampled at
+// 48 kHz is within tolerance.
+TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) {
+  // Init resampler, feature extractor and RNN.
+  PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
+  FeaturesExtractor features_extractor;
+  RnnBasedVad rnn_vad;
+
+  // Init input samples and expected output readers.
+  auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
+  auto expected_vad_prob_reader = CreateVadProbsReader();
+
+  // Input length.
+  const size_t num_frames = samples_reader.second;
+  ASSERT_GE(expected_vad_prob_reader.second, num_frames);
+
+  // Init buffers.
+  std::vector<float> samples_48k(kFrameSize10ms48kHz);
+  std::vector<float> samples_24k(kFrameSize10ms24kHz);
+  std::vector<float> feature_vector(kFeatureVectorSize);
+  std::vector<float> computed_vad_prob(num_frames);
+  std::vector<float> expected_vad_prob(num_frames);
+
+  // Read expected output.
+  ASSERT_TRUE(expected_vad_prob_reader.first->ReadChunk(expected_vad_prob));
+
+  // Compute VAD probabilities on the downsampled input.
+  float cumulative_error = 0.f;
+  for (size_t i = 0; i < num_frames; ++i) {
+    samples_reader.first->ReadChunk(samples_48k);
+    decimator.Resample(samples_48k.data(), samples_48k.size(),
+                       samples_24k.data(), samples_24k.size());
+    bool is_silence = features_extractor.CheckSilenceComputeFeatures(
+        {samples_24k.data(), kFrameSize10ms24kHz},
+        {feature_vector.data(), kFeatureVectorSize});
+    computed_vad_prob[i] = rnn_vad.ComputeVadProbability(
+        {feature_vector.data(), kFeatureVectorSize}, is_silence);
+    EXPECT_NEAR(computed_vad_prob[i], expected_vad_prob[i], 1e-3f);
+    cumulative_error += std::abs(computed_vad_prob[i] - expected_vad_prob[i]);
+  }
+  // Check average error.
+  EXPECT_LT(cumulative_error / num_frames, 1e-4f);
+
+  if (kWriteComputedOutputToFile) {
+    BinaryFileWriter<float> vad_prob_writer("new_vad_prob.dat");
+    vad_prob_writer.WriteChunk(computed_vad_prob);
+  }
+}
+
 // Performance test for the RNN VAD (pre-fetching and downsampling are
 // excluded). Keep disabled and only enable locally to measure performance as
 // follows:
diff --git a/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc
index d112eb7..ec81295 100644
--- a/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc
@@ -85,14 +85,18 @@
   }
 }
 
+// Checks that the computed band-wise auto-correlation is non-negative for a
+// simple input vector of FFT coefficients.
 TEST(RnnVadTest, SpectralCorrelatorValidOutput) {
-  SpectralCorrelator e;
+  // Input: vector of (1, 1j) values.
   Pffft fft(kFrameSize20ms24kHz, Pffft::FftType::kReal);
   auto in = fft.CreateBuffer();
   std::array<float, kOpusBands24kHz> out;
   auto in_view = in->GetView();
   std::fill(in_view.begin(), in_view.end(), 1.f);
   in_view[1] = 0.f;  // Nyquist frequency.
+  // Compute and check output.
+  SpectralCorrelator e;
   e.ComputeAutoCorrelation(in_view, out);
   for (size_t i = 0; i < kOpusBands24kHz; ++i) {
     SCOPED_TRACE(i);
@@ -100,6 +104,8 @@
   }
 }
 
+// Checks that the computed smoothed log magnitude spectrum is within tolerance
+// given hard-coded test input data.
 TEST(RnnVadTest, ComputeSmoothedLogMagnitudeSpectrumWithinTolerance) {
   constexpr std::array<float, kNumBands> input = {
       {86.060539245605f, 275.668334960938f, 43.406528472900f, 6.541896820068f,
@@ -124,7 +130,9 @@
   }
 }
 
-TEST(RnnVadTest, ComputeDctBitExactness) {
+// Checks that the computed DCT is within tolerance given hard-coded test input
+// data.
+TEST(RnnVadTest, ComputeDctWithinTolerance) {
   constexpr std::array<float, kNumBands> input = {
       {0.232155621052f,  0.678957760334f, 0.220818966627f,  -0.077363930643f,
        -0.559227049351f, 0.432545185089f, 0.353900641203f,  0.398993015289f,
diff --git a/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc b/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc
index 39b9f93..bc00e2c 100644
--- a/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc
@@ -67,6 +67,8 @@
 
 }  // namespace
 
+// Checks that silence is detected when the input signal is 0 and that the
+// feature vector is written only if the input signal is not tagged as silence.
 TEST(RnnVadTest, SpectralFeaturesWithAndWithoutSilence) {
   // Initialize.
   SpectralFeaturesExtractor sfe;
@@ -108,9 +110,10 @@
                            [](float x) { return x == kInitialFeatureVal; }));
 }
 
-// When the input signal does not change, the cepstral coefficients average does
-// not change and the derivatives are zero. Similarly, the cepstral variability
-// score does not change either.
+// Feeds a constant input signal and checks that:
+// - the cepstral coefficients average does not change;
+// - the derivatives are zero;
+// - the cepstral variability score does not change.
 TEST(RnnVadTest, CepstralFeaturesConstantAverageZeroDerivative) {
   // Initialize.
   SpectralFeaturesExtractor sfe;
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
index 14b84a4..8236d5f 100644
--- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc
@@ -46,13 +46,6 @@
   }
 }
 
-std::unique_ptr<BinaryFileReader<float>> CreatePitchSearchTestDataReader() {
-  constexpr size_t cols = 1396;
-  return absl::make_unique<BinaryFileReader<float>>(
-      ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
-      cols);
-}
-
 std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t>
 CreatePcmSamplesReader(const size_t frame_length) {
   auto ptr = absl::make_unique<BinaryFileReader<int16_t, float>>(
@@ -78,25 +71,6 @@
           rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs)};
 }
 
-ReaderPairType CreateFftCoeffsReader() {
-  constexpr size_t num_fft_points = 481;
-  constexpr size_t row_size = 2 * num_fft_points;  // Real and imaginary values.
-  auto ptr = absl::make_unique<BinaryFileReader<float>>(
-      test::ResourcePath("audio_processing/agc2/rnn_vad/fft", "dat"),
-      num_fft_points);
-  return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), row_size)};
-}
-
-ReaderPairType CreateSilenceFlagsFeatureMatrixReader() {
-  constexpr size_t feature_vector_size = 42;
-  auto ptr = absl::make_unique<BinaryFileReader<float>>(
-      test::ResourcePath("audio_processing/agc2/rnn_vad/sil_features", "dat"),
-      feature_vector_size);
-  // Features and silence flag.
-  return {std::move(ptr),
-          rtc::CheckedDivExact(ptr->data_length(), feature_vector_size + 1)};
-}
-
 ReaderPairType CreateVadProbsReader() {
   auto ptr = absl::make_unique<BinaryFileReader<float>>(
       test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob", "dat"));
@@ -104,23 +78,26 @@
 }
 
 PitchTestData::PitchTestData() {
-  auto test_data_reader = CreatePitchSearchTestDataReader();
-  test_data_reader->ReadChunk(test_data_);
+  BinaryFileReader<float> test_data_reader(
+      ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
+      static_cast<size_t>(1396));
+  test_data_reader.ReadChunk(test_data_);
 }
 
 PitchTestData::~PitchTestData() = default;
 
-rtc::ArrayView<const float, kBufSize24kHz> PitchTestData::GetPitchBufView() {
+rtc::ArrayView<const float, kBufSize24kHz> PitchTestData::GetPitchBufView()
+    const {
   return {test_data_.data(), kBufSize24kHz};
 }
 
 rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
-PitchTestData::GetPitchBufSquareEnergiesView() {
+PitchTestData::GetPitchBufSquareEnergiesView() const {
   return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies};
 }
 
 rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
-PitchTestData::GetPitchBufAutoCorrCoeffsView() {
+PitchTestData::GetPitchBufAutoCorrCoeffsView() const {
   return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies,
           kNumPitchBufAutoCorrCoeffs};
 }
diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h
index c11af7f..fbb270f 100644
--- a/modules/audio_processing/agc2/rnn_vad/test_utils.h
+++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h
@@ -17,6 +17,7 @@
 #include <limits>
 #include <memory>
 #include <string>
+#include <type_traits>
 #include <utility>
 #include <vector>
 
@@ -46,11 +47,10 @@
 template <typename T, typename D = T>
 class BinaryFileReader {
  public:
-  explicit BinaryFileReader(const std::string& file_path, size_t chunk_size = 1)
+  explicit BinaryFileReader(const std::string& file_path, size_t chunk_size = 0)
       : is_(file_path, std::ios::binary | std::ios::ate),
         data_length_(is_.tellg() / sizeof(T)),
         chunk_size_(chunk_size) {
-    RTC_CHECK_LT(0, chunk_size_);
     RTC_CHECK(is_);
     SeekBeginning();
     buf_.resize(chunk_size_);
@@ -69,9 +69,11 @@
     }
     return is_.gcount() == sizeof(T);
   }
+  // If |chunk_size| was specified in the ctor, it will check that the size of
+  // |dst| equals |chunk_size|.
   bool ReadChunk(rtc::ArrayView<D> dst) {
-    RTC_DCHECK_EQ(chunk_size_, dst.size());
-    const std::streamsize bytes_to_read = chunk_size_ * sizeof(T);
+    RTC_DCHECK((chunk_size_ == 0) || (chunk_size_ == dst.size()));
+    const std::streamsize bytes_to_read = dst.size() * sizeof(T);
     if (std::is_same<T, D>::value) {
       is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
     } else {
@@ -91,9 +93,26 @@
   std::vector<T> buf_;
 };
 
+// Writer for binary files.
+template <typename T>
+class BinaryFileWriter {
+ public:
+  explicit BinaryFileWriter(const std::string& file_path)
+      : os_(file_path, std::ios::binary) {}
+  BinaryFileWriter(const BinaryFileWriter&) = delete;
+  BinaryFileWriter& operator=(const BinaryFileWriter&) = delete;
+  ~BinaryFileWriter() = default;
+  static_assert(std::is_arithmetic<T>::value, "");
+  void WriteChunk(rtc::ArrayView<const T> value) {
+    const std::streamsize bytes_to_write = value.size() * sizeof(T);
+    os_.write(reinterpret_cast<const char*>(value.data()), bytes_to_write);
+  }
+
+ private:
+  std::ofstream os_;
+};
+
 // Factories for resource file readers.
-// Creates a reader for the pitch search test data.
-std::unique_ptr<BinaryFileReader<float>> CreatePitchSearchTestDataReader();
 // The functions below return a pair where the first item is a reader unique
 // pointer and the second the number of chunks that can be read from the file.
 // Creates a reader for the PCM samples that casts from S16 to float and reads
@@ -107,12 +126,6 @@
 // and gain values.
 std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
 CreateLpResidualAndPitchPeriodGainReader();
-// Creates a reader for the FFT coefficients.
-std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
-CreateFftCoeffsReader();
-// Creates a reader for the silence flags and the feature matrix.
-std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
-CreateSilenceFlagsFeatureMatrixReader();
 // Creates a reader for the VAD probabilities.
 std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
 CreateVadProbsReader();
@@ -128,11 +141,11 @@
  public:
   PitchTestData();
   ~PitchTestData();
-  rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView();
+  rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView() const;
   rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
-  GetPitchBufSquareEnergiesView();
+  GetPitchBufSquareEnergiesView() const;
   rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
-  GetPitchBufAutoCorrCoeffsView();
+  GetPitchBufAutoCorrCoeffsView() const;
 
  private:
   std::array<float, kPitchTestDataSize> test_data_;
diff --git a/resources/audio_processing/agc2/rnn_vad/fft.dat.sha1 b/resources/audio_processing/agc2/rnn_vad/fft.dat.sha1
deleted file mode 100644
index ebd5124..0000000
--- a/resources/audio_processing/agc2/rnn_vad/fft.dat.sha1
+++ /dev/null
@@ -1 +0,0 @@
-e62364d35abd123663bfc800fa233071d6d7fffd
\ No newline at end of file
diff --git a/resources/audio_processing/agc2/rnn_vad/sil_features.dat.sha1 b/resources/audio_processing/agc2/rnn_vad/sil_features.dat.sha1
deleted file mode 100644
index bc591e9..0000000
--- a/resources/audio_processing/agc2/rnn_vad/sil_features.dat.sha1
+++ /dev/null
@@ -1 +0,0 @@
-e0a92782c2903be9da10385d924d34e8bf212d5e
\ No newline at end of file
diff --git a/resources/audio_processing/agc2/rnn_vad/vad_prob.dat.sha1 b/resources/audio_processing/agc2/rnn_vad/vad_prob.dat.sha1
index 1aa3bd0..8ee78b1 100644
--- a/resources/audio_processing/agc2/rnn_vad/vad_prob.dat.sha1
+++ b/resources/audio_processing/agc2/rnn_vad/vad_prob.dat.sha1
@@ -1 +1 @@
-05735ede0b457318e307d12f5acfd11bbbbd0afd
\ No newline at end of file
+68640327266262c3fe047ec7f07a46a355ff90b9
\ No newline at end of file