FFT-based auto correlation.
During pitch search in the RNN VAD, we calculate auto
correlation. Before this CL, we computed kNumInvertedLags12kHz=147 dot
products of vectors with kBufSize12kHz-kMaxPitch12kHz=240
elements. This was the most time consuming step of the new VAD.
This CL makes the computation happen in frequency domain. Profiling
shows a 3x speed increase. In future, we can try using a more efficient
FFT and to reduce the FFT length to some of e.g. 400, 405, 432.
# For minimal Clang plugin check change.
TBR: kwiberg@webrtc.org
Bug: webrtc:9076
Change-Id: I688251a415869d53175a37f390f441d4e035d954
Reviewed-on: https://webrtc-review.googlesource.com/73366
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Reviewed-by: Alessio Bazzica <alessiob@webrtc.org>
Commit-Queue: Alex Loiko <aleloi@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#23171}
diff --git a/common_audio/real_fourier_ooura.cc b/common_audio/real_fourier_ooura.cc
index 5d75717..ca043e4 100644
--- a/common_audio/real_fourier_ooura.cc
+++ b/common_audio/real_fourier_ooura.cc
@@ -45,6 +45,8 @@
RTC_CHECK_GE(fft_order, 1);
}
+RealFourierOoura::~RealFourierOoura() = default;
+
void RealFourierOoura::Forward(const float* src, complex<float>* dest) const {
{
// This cast is well-defined since C++11. See "Non-static data members" at:
@@ -82,4 +84,8 @@
std::for_each(dest, dest + length_, [scale](float& v) { v *= scale; });
}
+int RealFourierOoura::order() const {
+ return order_;
+}
+
} // namespace webrtc
diff --git a/common_audio/real_fourier_ooura.h b/common_audio/real_fourier_ooura.h
index f885a34..bb8eef9 100644
--- a/common_audio/real_fourier_ooura.h
+++ b/common_audio/real_fourier_ooura.h
@@ -21,13 +21,12 @@
class RealFourierOoura : public RealFourier {
public:
explicit RealFourierOoura(int fft_order);
+ ~RealFourierOoura() override;
void Forward(const float* src, std::complex<float>* dest) const override;
void Inverse(const std::complex<float>* src, float* dest) const override;
- int order() const override {
- return order_;
- }
+ int order() const override;
private:
const int order_;
@@ -42,4 +41,3 @@
} // namespace webrtc
#endif // COMMON_AUDIO_REAL_FOURIER_OOURA_H_
-
diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
index b0ca347..f35d5c3 100644
--- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn
+++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn
@@ -36,6 +36,7 @@
]
deps = [
"../../../../api:array_view",
+ "../../../../common_audio/",
"../../../../rtc_base:checks",
"../../../../rtc_base:rtc_base_approved",
"//third_party/rnnoise:kiss_fft",
@@ -95,6 +96,7 @@
":lib",
":lib_test",
"../../../../api:array_view",
+ "../../../../common_audio/",
"../../../../rtc_base:checks",
"../../../../test:test_support",
"//third_party/rnnoise:rnn_vad",
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc
index 4d83588..9261935 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc
@@ -15,7 +15,8 @@
namespace rnn_vad {
PitchInfo PitchSearch(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
- PitchInfo prev_pitch_48kHz) {
+ PitchInfo prev_pitch_48kHz,
+ RealFourier* fft) {
// Perform the initial pitch search at 12 kHz.
std::array<float, kBufSize12kHz> pitch_buf_decimated;
Decimate2x(pitch_buf,
@@ -24,7 +25,8 @@
std::array<float, kNumInvertedLags12kHz> auto_corr;
ComputePitchAutoCorrelation(
{pitch_buf_decimated.data(), pitch_buf_decimated.size()}, kMaxPitch12kHz,
- {auto_corr.data(), auto_corr.size()});
+ {auto_corr.data(), auto_corr.size()}, fft);
+
// Search for pitch at 12 kHz.
std::array<size_t, 2> pitch_candidates_inv_lags = FindBestPitchPeriods(
{auto_corr.data(), auto_corr.size()},
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.h b/modules/audio_processing/agc2/rnn_vad/pitch_search.h
index a0af0eb..21e7a05 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search.h
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.h
@@ -12,6 +12,7 @@
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_
#include "api/array_view.h"
+#include "common_audio/real_fourier.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
@@ -21,7 +22,8 @@
// Searches the pitch period and gain. Return the pitch estimation data for
// 48 kHz.
PitchInfo PitchSearch(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
- PitchInfo prev_pitch_48kHz);
+ PitchInfo prev_pitch_48kHz,
+ RealFourier* fft);
} // namespace rnn_vad
} // namespace webrtc
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
index 1ff4621..99600e0 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc
@@ -205,18 +205,62 @@
}
}
-// TODO(bugs.webrtc.org/9076): Optimize using FFT and/or vectorization.
void ComputePitchAutoCorrelation(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
size_t max_pitch_period,
- rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
+ rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr,
+ webrtc::RealFourier* fft) {
RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
- // Compute auto-correlation coefficients.
- for (size_t inv_lag = 0; inv_lag < auto_corr.size(); ++inv_lag) {
- auto_corr[inv_lag] =
- ComputeAutoCorrelationCoeff(pitch_buf, inv_lag, max_pitch_period);
+ RTC_DCHECK(fft);
+
+ constexpr size_t time_domain_fft_length = 1 << kAutoCorrelationFftOrder;
+ constexpr size_t freq_domain_fft_length = time_domain_fft_length / 2 + 1;
+
+ RTC_DCHECK_EQ(RealFourier::FftLength(fft->order()), time_domain_fft_length);
+ RTC_DCHECK_EQ(RealFourier::ComplexLength(fft->order()),
+ freq_domain_fft_length);
+
+ // Cross-correlation of y_i=pitch_buf[i:i+convolution_length] and
+ // x=pitch_buf[-convolution_length:] is equivalent to convolution of
+ // y_i and reversed(x). New notation: h=reversed(x), x=y.
+ std::array<float, time_domain_fft_length> h{};
+ std::array<float, time_domain_fft_length> x{};
+
+ const size_t convolution_length = kBufSize12kHz - max_pitch_period;
+ // Check that the FFT-length is big enough to avoid cyclic
+ // convolution errors.
+ RTC_DCHECK_GT(time_domain_fft_length,
+ kNumInvertedLags12kHz + convolution_length);
+
+ // h[0:convolution_length] is reversed pitch_buf[-convolution_length:].
+ std::reverse_copy(pitch_buf.end() - convolution_length, pitch_buf.end(),
+ h.begin());
+
+ // x is pitch_buf[:kNumInvertedLags12kHz + convolution_length].
+ std::copy(pitch_buf.begin(),
+ pitch_buf.begin() + kNumInvertedLags12kHz + convolution_length,
+ x.begin());
+
+ // Shift to frequency domain.
+ std::array<std::complex<float>, freq_domain_fft_length> X{};
+ std::array<std::complex<float>, freq_domain_fft_length> H{};
+ fft->Forward(&x[0], &X[0]);
+ fft->Forward(&h[0], &H[0]);
+
+ // Convolve in frequency domain.
+ for (size_t i = 0; i < X.size(); ++i) {
+ X[i] *= H[i];
}
+
+ // Shift back to time domain.
+ std::array<float, time_domain_fft_length> x_conv_h;
+ fft->Inverse(&X[0], &x_conv_h[0]);
+
+ // Collect the result.
+ std::copy(x_conv_h.begin() + convolution_length - 1,
+ x_conv_h.begin() + convolution_length + kNumInvertedLags12kHz - 1,
+ auto_corr.begin());
}
std::array<size_t, 2> FindBestPitchPeriods(
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h
index dfe1b35..75f7f17 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h
@@ -14,6 +14,7 @@
#include <array>
#include "api/array_view.h"
+#include "common_audio/real_fourier.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
@@ -26,6 +27,11 @@
static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
+constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
+
+static_assert(1 << kAutoCorrelationFftOrder >
+ kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
+ "");
// Performs 2x decimation without any anti-aliasing filter.
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
@@ -70,7 +76,8 @@
void ComputePitchAutoCorrelation(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
size_t max_pitch_period,
- rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr);
+ rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr,
+ webrtc::RealFourier* fft);
// Given the auto-correlation coefficients stored according to
// ComputePitchAutoCorrelation() (i.e., using inverted lags), returns the best
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
index 9a6a267..4b1be0d 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc
@@ -9,6 +9,7 @@
*/
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
+#include "common_audio/real_fourier.h"
#include <array>
#include <tuple>
@@ -415,16 +416,47 @@
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
-
+ std::unique_ptr<RealFourier> fft =
+ RealFourier::Create(kAutoCorrelationFftOrder);
ComputePitchAutoCorrelation(
{pitch_buf_decimated.data(), pitch_buf_decimated.size()},
- kMaxPitch12kHz, {computed_output.data(), computed_output.size()});
+ kMaxPitch12kHz, {computed_output.data(), computed_output.size()},
+ fft.get());
}
ExpectNearAbsolute(
{kPitchBufferAutoCorrCoeffs.data(), kPitchBufferAutoCorrCoeffs.size()},
{computed_output.data(), computed_output.size()}, 3e-3f);
}
+// Check that the auto correlation function computes the right thing for a
+// simple use case.
+TEST(RnnVadTest, ComputePitchAutoCorrelationConstantBuffer) {
+ // Create constant signal with no pitch.
+ std::array<float, kBufSize12kHz> pitch_buf_decimated;
+ std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f);
+
+ std::array<float, kPitchBufferAutoCorrCoeffs.size()> computed_output;
+ {
+ // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
+ // FloatingPointExceptionObserver fpe_observer;
+
+ std::unique_ptr<RealFourier> fft =
+ RealFourier::Create(kAutoCorrelationFftOrder);
+ ComputePitchAutoCorrelation(
+ {pitch_buf_decimated.data(), pitch_buf_decimated.size()},
+ kMaxPitch12kHz, {computed_output.data(), computed_output.size()},
+ fft.get());
+ }
+
+ // The expected output is constantly the length of the fixed 'x'
+ // array in ComputePitchAutoCorrelation.
+ std::array<float, kPitchBufferAutoCorrCoeffs.size()> expected_output;
+ std::fill(expected_output.begin(), expected_output.end(),
+ kBufSize12kHz - kMaxPitch12kHz);
+ ExpectNearAbsolute({expected_output.data(), expected_output.size()},
+ {computed_output.data(), computed_output.size()}, 4e-5f);
+}
+
TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
std::array<float, kBufSize12kHz> pitch_buf_decimated;
Decimate2x({kPitchBufferData.data(), kPitchBufferData.size()},
diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
index 4417764..b25aba3 100644
--- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc
@@ -9,6 +9,7 @@
*/
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
+#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
#include <array>
@@ -28,6 +29,8 @@
std::array<float, 864> lp_residual;
float expected_pitch_period, expected_pitch_gain;
PitchInfo last_pitch;
+ std::unique_ptr<RealFourier> fft =
+ RealFourier::Create(kAutoCorrelationFftOrder);
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
@@ -38,8 +41,8 @@
{lp_residual.data(), lp_residual.size()});
lp_residual_reader.first->ReadValue(&expected_pitch_period);
lp_residual_reader.first->ReadValue(&expected_pitch_gain);
- last_pitch =
- PitchSearch({lp_residual.data(), lp_residual.size()}, last_pitch);
+ last_pitch = PitchSearch({lp_residual.data(), lp_residual.size()},
+ last_pitch, fft.get());
EXPECT_EQ(static_cast<size_t>(expected_pitch_period), last_pitch.period);
EXPECT_NEAR(expected_pitch_gain, last_pitch.gain, 1e-5f);
}