FFT-based auto correlation.

During pitch search in the RNN VAD, we calculate auto
correlation. Before this CL, we computed kNumInvertedLags12kHz=147 dot
products of vectors with kBufSize12kHz-kMaxPitch12kHz=240
elements. This was the most time consuming step of the new VAD.

This CL makes the computation happen in frequency domain. Profiling
shows a 3x speed increase. In future, we can try using a more efficient
FFT and to reduce the FFT length to some of e.g. 400, 405, 432.

# For minimal Clang plugin check change.
TBR: kwiberg@webrtc.org

Bug: webrtc:9076
Change-Id: I688251a415869d53175a37f390f441d4e035d954
Reviewed-on: https://webrtc-review.googlesource.com/73366
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Reviewed-by: Alessio Bazzica <alessiob@webrtc.org>
Commit-Queue: Alex Loiko <aleloi@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#23171}
diff --git a/common_audio/real_fourier_ooura.cc b/common_audio/real_fourier_ooura.cc
index 5d75717..ca043e4 100644
--- a/common_audio/real_fourier_ooura.cc
+++ b/common_audio/real_fourier_ooura.cc
@@ -45,6 +45,8 @@
   RTC_CHECK_GE(fft_order, 1);
 }
 
+RealFourierOoura::~RealFourierOoura() = default;
+
 void RealFourierOoura::Forward(const float* src, complex<float>* dest) const {
   {
     // This cast is well-defined since C++11. See "Non-static data members" at:
@@ -82,4 +84,8 @@
   std::for_each(dest, dest + length_, [scale](float& v) { v *= scale; });
 }
 
+int RealFourierOoura::order() const {
+  return order_;
+}
+
 }  // namespace webrtc
diff --git a/common_audio/real_fourier_ooura.h b/common_audio/real_fourier_ooura.h
index f885a34..bb8eef9 100644
--- a/common_audio/real_fourier_ooura.h
+++ b/common_audio/real_fourier_ooura.h
@@ -21,13 +21,12 @@
 class RealFourierOoura : public RealFourier {
  public:
   explicit RealFourierOoura(int fft_order);
+  ~RealFourierOoura() override;
 
   void Forward(const float* src, std::complex<float>* dest) const override;
   void Inverse(const std::complex<float>* src, float* dest) const override;
 
-  int order() const override {
-    return order_;
-  }
+  int order() const override;
 
  private:
   const int order_;
@@ -42,4 +41,3 @@
 }  // namespace webrtc
 
 #endif  // COMMON_AUDIO_REAL_FOURIER_OOURA_H_
-
diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index b0ca347..f35d5c3 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -36,6 +36,7 @@
   ]
   deps = [
     "../../../../api:array_view",
+    "../../../../common_audio/",
     "../../../../rtc_base:checks",
     "../../../../rtc_base:rtc_base_approved",
     "//third_party/rnnoise:kiss_fft",
@@ -95,6 +96,7 @@
       ":lib",
       ":lib_test",
       "../../../../api:array_view",
+      "../../../../common_audio/",
       "../../../../rtc_base:checks",
       "../../../../test:test_support",
       "//third_party/rnnoise:rnn_vad",
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc
index 4d83588..9261935 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc
@@ -15,7 +15,8 @@
 namespace rnn_vad {
 
 PitchInfo PitchSearch(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
-                      PitchInfo prev_pitch_48kHz) {
+                      PitchInfo prev_pitch_48kHz,
+                      RealFourier* fft) {
   // Perform the initial pitch search at 12 kHz.
   std::array<float, kBufSize12kHz> pitch_buf_decimated;
   Decimate2x(pitch_buf,
@@ -24,7 +25,8 @@
   std::array<float, kNumInvertedLags12kHz> auto_corr;
   ComputePitchAutoCorrelation(
       {pitch_buf_decimated.data(), pitch_buf_decimated.size()}, kMaxPitch12kHz,
-      {auto_corr.data(), auto_corr.size()});
+      {auto_corr.data(), auto_corr.size()}, fft);
+
   // Search for pitch at 12 kHz.
   std::array<size_t, 2> pitch_candidates_inv_lags = FindBestPitchPeriods(
       {auto_corr.data(), auto_corr.size()},
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.h b/modules/audio_processing/agc2/rnn_vad/pitch_search.h
index a0af0eb..21e7a05 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search.h
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.h
@@ -12,6 +12,7 @@
 #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_
 
 #include "api/array_view.h"
+#include "common_audio/real_fourier.h"
 #include "modules/audio_processing/agc2/rnn_vad/common.h"
 #include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
 
@@ -21,7 +22,8 @@
 // Searches the pitch period and gain. Return the pitch estimation data for
 // 48 kHz.
 PitchInfo PitchSearch(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
-                      PitchInfo prev_pitch_48kHz);
+                      PitchInfo prev_pitch_48kHz,
+                      RealFourier* fft);
 
 }  // namespace rnn_vad
 }  // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
index 1ff4621..99600e0 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
@@ -205,18 +205,62 @@
   }
 }
 
-// TODO(bugs.webrtc.org/9076): Optimize using FFT and/or vectorization.
 void ComputePitchAutoCorrelation(
     rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
     size_t max_pitch_period,
-    rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
+    rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr,
+    webrtc::RealFourier* fft) {
   RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
   RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
-  // Compute auto-correlation coefficients.
-  for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
-    auto_corr[inv_lag] =
-        ComputeAutoCorrelationCoeff(pitch_buf, inv_lag, max_pitch_period);
+  RTC_DCHECK(fft);
+
+  constexpr size_t time_domain_fft_length = 1 << kAutoCorrelationFftOrder;
+  constexpr size_t freq_domain_fft_length = time_domain_fft_length / 2 + 1;
+
+  RTC_DCHECK_EQ(RealFourier::FftLength(fft->order()), time_domain_fft_length);
+  RTC_DCHECK_EQ(RealFourier::ComplexLength(fft->order()),
+                freq_domain_fft_length);
+
+  // Cross-correlation of y_i=pitch_buf[i:i+convolution_length] and
+  // x=pitch_buf[-convolution_length:] is equivalent to convolution of
+  // y_i and reversed(x). New notation: h=reversed(x), x=y.
+  std::array<float, time_domain_fft_length> h{};
+  std::array<float, time_domain_fft_length> x{};
+
+  const size_t convolution_length = kBufSize12kHz - max_pitch_period;
+  // Check that the FFT-length is big enough to avoid cyclic
+  // convolution errors.
+  RTC_DCHECK_GT(time_domain_fft_length,
+                kNumInvertedLags12kHz + convolution_length);
+
+  // h[0:convolution_length] is reversed pitch_buf[-convolution_length:].
+  std::reverse_copy(pitch_buf.end() - convolution_length, pitch_buf.end(),
+                    h.begin());
+
+  // x is pitch_buf[:kNumInvertedLags12kHz + convolution_length].
+  std::copy(pitch_buf.begin(),
+            pitch_buf.begin() + kNumInvertedLags12kHz + convolution_length,
+            x.begin());
+
+  // Shift to frequency domain.
+  std::array<std::complex<float>, freq_domain_fft_length> X{};
+  std::array<std::complex<float>, freq_domain_fft_length> H{};
+  fft->Forward(&x[0], &X[0]);
+  fft->Forward(&h[0], &H[0]);
+
+  // Convolve in frequency domain.
+  for (size_t i = 0; i < X.size(); ++i) {
+    X[i] *= H[i];
   }
+
+  // Shift back to time domain.
+  std::array<float, time_domain_fft_length> x_conv_h;
+  fft->Inverse(&X[0], &x_conv_h[0]);
+
+  // Collect the result.
+  std::copy(x_conv_h.begin() + convolution_length - 1,
+            x_conv_h.begin() + convolution_length + kNumInvertedLags12kHz - 1,
+            auto_corr.begin());
 }
 
 std::array<size_t, 2> FindBestPitchPeriods(
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h
index dfe1b35..75f7f17 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h
@@ -14,6 +14,7 @@
 #include <array>
 
 #include "api/array_view.h"
+#include "common_audio/real_fourier.h"
 #include "modules/audio_processing/agc2/rnn_vad/common.h"
 #include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
 
@@ -26,6 +27,11 @@
 static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
 constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
 constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
+constexpr int kAutoCorrelationFftOrder = 9;  // Length-512 FFT.
+
+static_assert(1 << kAutoCorrelationFftOrder >
+                  kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
+              "");
 
 // Performs 2x decimation without any anti-aliasing filter.
 void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
@@ -70,7 +76,8 @@
 void ComputePitchAutoCorrelation(
     rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
     size_t max_pitch_period,
-    rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr);
+    rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr,
+    webrtc::RealFourier* fft);
 
 // Given the auto-correlation coefficients stored according to
 // ComputePitchAutoCorrelation() (i.e., using inverted lags), returns the best
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
index 9a6a267..4b1be0d 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
@@ -9,6 +9,7 @@
  */
 
 #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
+#include "common_audio/real_fourier.h"
 
 #include <array>
 #include <tuple>
@@ -415,16 +416,47 @@
   {
     // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
     // FloatingPointExceptionObserver fpe_observer;
-
+    std::unique_ptr<RealFourier> fft =
+        RealFourier::Create(kAutoCorrelationFftOrder);
     ComputePitchAutoCorrelation(
         {pitch_buf_decimated.data(), pitch_buf_decimated.size()},
-        kMaxPitch12kHz, {computed_output.data(), computed_output.size()});
+        kMaxPitch12kHz, {computed_output.data(), computed_output.size()},
+        fft.get());
   }
   ExpectNearAbsolute(
       {kPitchBufferAutoCorrCoeffs.data(), kPitchBufferAutoCorrCoeffs.size()},
       {computed_output.data(), computed_output.size()}, 3e-3f);
 }
 
+// Check that the auto correlation function computes the right thing for a
+// simple use case.
+TEST(RnnVadTest, ComputePitchAutoCorrelationConstantBuffer) {
+  // Create constant signal with no pitch.
+  std::array<float, kBufSize12kHz> pitch_buf_decimated;
+  std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f);
+
+  std::array<float, kPitchBufferAutoCorrCoeffs.size()> computed_output;
+  {
+    // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
+    // FloatingPointExceptionObserver fpe_observer;
+
+    std::unique_ptr<RealFourier> fft =
+        RealFourier::Create(kAutoCorrelationFftOrder);
+    ComputePitchAutoCorrelation(
+        {pitch_buf_decimated.data(), pitch_buf_decimated.size()},
+        kMaxPitch12kHz, {computed_output.data(), computed_output.size()},
+        fft.get());
+  }
+
+  // The expected output is constantly the length of the fixed 'x'
+  // array in ComputePitchAutoCorrelation.
+  std::array<float, kPitchBufferAutoCorrCoeffs.size()> expected_output;
+  std::fill(expected_output.begin(), expected_output.end(),
+            kBufSize12kHz - kMaxPitch12kHz);
+  ExpectNearAbsolute({expected_output.data(), expected_output.size()},
+                     {computed_output.data(), computed_output.size()}, 4e-5f);
+}
+
 TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
   std::array<float, kBufSize12kHz> pitch_buf_decimated;
   Decimate2x({kPitchBufferData.data(), kPitchBufferData.size()},
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
index 4417764..b25aba3 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
@@ -9,6 +9,7 @@
  */
 
 #include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
+#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
 
 #include <array>
 
@@ -28,6 +29,8 @@
   std::array<float, 864> lp_residual;
   float expected_pitch_period, expected_pitch_gain;
   PitchInfo last_pitch;
+  std::unique_ptr<RealFourier> fft =
+      RealFourier::Create(kAutoCorrelationFftOrder);
   {
     // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
     // FloatingPointExceptionObserver fpe_observer;
@@ -38,8 +41,8 @@
           {lp_residual.data(), lp_residual.size()});
       lp_residual_reader.first->ReadValue(&expected_pitch_period);
       lp_residual_reader.first->ReadValue(&expected_pitch_gain);
-      last_pitch =
-          PitchSearch({lp_residual.data(), lp_residual.size()}, last_pitch);
+      last_pitch = PitchSearch({lp_residual.data(), lp_residual.size()},
+                               last_pitch, fft.get());
       EXPECT_EQ(static_cast<size_t>(expected_pitch_period), last_pitch.period);
       EXPECT_NEAR(expected_pitch_gain, last_pitch.gain, 1e-5f);
     }