AEC3: Move decimator filters to the new notation

Preparing for changing the filters of the decimator by moving the old
filters to the new zero, pole, gain notation.

Bug: webrtc:9288,chromium:846615
Change-Id: I2b01a2555d34617e0bf251c782703753f72cd56f
Reviewed-on: https://webrtc-review.googlesource.com/81189
Reviewed-by: Per Ã…hgren <peah@webrtc.org>
Commit-Queue: Gustaf Ullberg <gustaf@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#23528}
diff --git a/modules/audio_processing/aec3/cascaded_biquad_filter.cc b/modules/audio_processing/aec3/cascaded_biquad_filter.cc
index 5881d60..333d226 100644
--- a/modules/audio_processing/aec3/cascaded_biquad_filter.cc
+++ b/modules/audio_processing/aec3/cascaded_biquad_filter.cc
@@ -15,8 +15,14 @@
 
 CascadedBiQuadFilter::BiQuadParam::BiQuadParam(std::complex<float> zero,
                                                std::complex<float> pole,
-                                               float gain)
-    : zero(zero), pole(pole), gain(gain) {}
+                                               float gain,
+                                               bool mirror_zero_along_i_axis)
+    : zero(zero),
+      pole(pole),
+      gain(gain),
+      mirror_zero_along_i_axis(mirror_zero_along_i_axis) {}
+
+CascadedBiQuadFilter::BiQuadParam::BiQuadParam(const BiQuadParam&) = default;
 
 CascadedBiQuadFilter::BiQuad::BiQuad(
     const CascadedBiQuadFilter::BiQuadParam& param)
@@ -27,10 +33,20 @@
   float p_i = std::imag(param.pole);
   float gain = param.gain;
 
-  coefficients.b[0] = gain * 1.f;
-  coefficients.b[1] = gain * -2.f * z_r;
-  coefficients.b[2] = gain * (z_r * z_r + z_i * z_i);
+  if (param.mirror_zero_along_i_axis) {
+    // Assuming zeroes at z_r and -z_r.
+    RTC_DCHECK(z_i == 0.f);
+    coefficients.b[0] = gain * 1.f;
+    coefficients.b[1] = 0.f;
+    coefficients.b[2] = gain * -(z_r * z_r);
+  } else {
+    // Assuming zeros at (z_r + z_i*i) and (z_r - z_i*i).
+    coefficients.b[0] = gain * 1.f;
+    coefficients.b[1] = gain * -2.f * z_r;
+    coefficients.b[2] = gain * (z_r * z_r + z_i * z_i);
+  }
 
+  // Assuming poles at (p_r + p_i*i) and (p_r - p_i*i).
   coefficients.a[0] = -2.f * p_r;
   coefficients.a[1] = p_r * p_r + p_i * p_i;
 }
diff --git a/modules/audio_processing/aec3/cascaded_biquad_filter.h b/modules/audio_processing/aec3/cascaded_biquad_filter.h
index feae68d..1e09fa6 100644
--- a/modules/audio_processing/aec3/cascaded_biquad_filter.h
+++ b/modules/audio_processing/aec3/cascaded_biquad_filter.h
@@ -24,10 +24,15 @@
 class CascadedBiQuadFilter {
  public:
   struct BiQuadParam {
-    BiQuadParam(std::complex<float> zero, std::complex<float> pole, float gain);
+    BiQuadParam(std::complex<float> zero,
+                std::complex<float> pole,
+                float gain,
+                bool mirror_zero_along_i_axis = false);
+    BiQuadParam(const BiQuadParam&);
     std::complex<float> zero;
     std::complex<float> pole;
     float gain;
+    bool mirror_zero_along_i_axis;
   };
 
   struct BiQuadCoefficients {
diff --git a/modules/audio_processing/aec3/cascaded_biquad_filter_unittest.cc b/modules/audio_processing/aec3/cascaded_biquad_filter_unittest.cc
index 0f1b0db..57f4b04 100644
--- a/modules/audio_processing/aec3/cascaded_biquad_filter_unittest.cc
+++ b/modules/audio_processing/aec3/cascaded_biquad_filter_unittest.cc
@@ -123,4 +123,18 @@
   EXPECT_NEAR(filter.coefficients.a[1], 0.57406192f, epsilon);
 }
 
+// Verifies the conversion from zero, pole, gain to filter coefficients for
+// bandpass filter.
+TEST(CascadedBiquadFilter, BiQuadParamBandPass) {
+  CascadedBiQuadFilter::BiQuadParam param(
+      {1.0f, 0.0f}, {1.11022302e-16f, 0.71381051f}, 0.2452372752527856f, true);
+  CascadedBiQuadFilter::BiQuad filter(param);
+  const float epsilon = 1e-6f;
+  EXPECT_NEAR(filter.coefficients.b[0], 0.24523728f, epsilon);
+  EXPECT_NEAR(filter.coefficients.b[1], 0.f, epsilon);
+  EXPECT_NEAR(filter.coefficients.b[2], -0.24523728f, epsilon);
+  EXPECT_NEAR(filter.coefficients.a[0], -2.22044605e-16f, epsilon);
+  EXPECT_NEAR(filter.coefficients.a[1], 5.09525449e-01f, epsilon);
+}
+
 }  // namespace webrtc
diff --git a/modules/audio_processing/aec3/decimator.cc b/modules/audio_processing/aec3/decimator.cc
index 8fffc8a..0ebc7db 100644
--- a/modules/audio_processing/aec3/decimator.cc
+++ b/modules/audio_processing/aec3/decimator.cc
@@ -14,53 +14,42 @@
 namespace webrtc {
 namespace {
 
-// b, a = signal.butter(2, 3400/8000.0, 'lowpass', analog=False) which are the
-// same as b, a = signal.butter(2, 1700/4000.0, 'lowpass', analog=False).
-const CascadedBiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients2 = {
-    {0.22711796f, 0.45423593f, 0.22711796f},
-    {-0.27666461f, 0.18513647f}};
-constexpr int kNumFilters2 = 3;
+// signal.butter(2, 3400/8000.0, 'lowpass', analog=False)
+const std::vector<CascadedBiQuadFilter::BiQuadParam> kLowPassFilterDS2 = {
+    {{-1.f, 0.f}, {0.13833231f, 0.40743176f}, 0.22711796393486466f},
+    {{-1.f, 0.f}, {0.13833231f, 0.40743176f}, 0.22711796393486466f},
+    {{-1.f, 0.f}, {0.13833231f, 0.40743176f}, 0.22711796393486466f}};
 
-// b, a = signal.butter(2, 750/8000.0, 'lowpass', analog=False) which are the
-// same as b, a = signal.butter(2, 375/4000.0, 'lowpass', analog=False).
-const CascadedBiQuadFilter::BiQuadCoefficients kLowPassFilterCoefficients4 = {
-    {0.0179f, 0.0357f, 0.0179f},
-    {-1.5879f, 0.6594f}};
-constexpr int kNumFilters4 = 3;
+// signal.butter(2, 750/8000.0, 'lowpass', analog=False)
+const std::vector<CascadedBiQuadFilter::BiQuadParam> kLowPassFilterDS4 = {
+    {{-1.f, 0.f}, {0.79396855f, 0.17030506f}, 0.017863192751682862f},
+    {{-1.f, 0.f}, {0.79396855f, 0.17030506f}, 0.017863192751682862f},
+    {{-1.f, 0.f}, {0.79396855f, 0.17030506f}, 0.017863192751682862f}};
 
-// b, a = signal.cheby1(1, 6, [1000/8000, 2000/8000], btype='bandpass',
-// analog=False)
-const CascadedBiQuadFilter::BiQuadCoefficients kBandPassFilterCoefficients8 = {
-    {0.10330478f, 0.f, -0.10330478f},
-    {-1.520363f, 0.79339043f}};
-constexpr int kNumFilters8 = 5;
+// signal.cheby1(1, 6, [1000/8000, 2000/8000], btype='bandpass', analog=False)
+const std::vector<CascadedBiQuadFilter::BiQuadParam> kBandPassFilterDS8 = {
+    {{1.f, 0.f}, {0.7601815f, 0.46423542f}, 0.10330478266505948f, true},
+    {{1.f, 0.f}, {0.7601815f, 0.46423542f}, 0.10330478266505948f, true},
+    {{1.f, 0.f}, {0.7601815f, 0.46423542f}, 0.10330478266505948f, true},
+    {{1.f, 0.f}, {0.7601815f, 0.46423542f}, 0.10330478266505948f, true},
+    {{1.f, 0.f}, {0.7601815f, 0.46423542f}, 0.10330478266505948f, true}};
 
-// b, a = signal.butter(2, 1000/8000.0, 'highpass', analog=False)
-const CascadedBiQuadFilter::BiQuadCoefficients kHighPassFilterCoefficients = {
-    {0.75707638f, -1.51415275f, 0.75707638f},
-    {-1.45424359f, 0.57406192f}};
-constexpr int kNumFiltersHP2 = 1;
-constexpr int kNumFiltersHP4 = 1;
-constexpr int kNumFiltersHP8 = 0;
+// signal.butter(2, 1000/8000.0, 'highpass', analog=False)
+const std::vector<CascadedBiQuadFilter::BiQuadParam> kHighPassFilter = {
+    {{1.f, 0.f}, {0.72712179f, 0.21296904f}, 0.7570763753338849f}};
 
+const std::vector<CascadedBiQuadFilter::BiQuadParam> kPassThroughFilter = {};
 }  // namespace
 
 Decimator::Decimator(size_t down_sampling_factor)
     : down_sampling_factor_(down_sampling_factor),
-      anti_aliasing_filter_(
-          down_sampling_factor_ == 4
-              ? kLowPassFilterCoefficients4
-              : (down_sampling_factor_ == 8 ? kBandPassFilterCoefficients8
-                                            : kLowPassFilterCoefficients2),
-          down_sampling_factor_ == 4
-              ? kNumFilters4
-              : (down_sampling_factor_ == 8 ? kNumFilters8 : kNumFilters2)),
-      noise_reduction_filter_(
-          kHighPassFilterCoefficients,
-          down_sampling_factor_ == 4
-              ? kNumFiltersHP4
-              : (down_sampling_factor_ == 8 ? kNumFiltersHP8
-                                            : kNumFiltersHP2)) {
+      anti_aliasing_filter_(down_sampling_factor_ == 4
+                                ? kLowPassFilterDS4
+                                : (down_sampling_factor_ == 8
+                                       ? kBandPassFilterDS8
+                                       : kLowPassFilterDS2)),
+      noise_reduction_filter_(down_sampling_factor_ == 8 ? kPassThroughFilter
+                                                         : kHighPassFilter) {
   RTC_DCHECK(down_sampling_factor_ == 2 || down_sampling_factor_ == 4 ||
              down_sampling_factor_ == 8);
 }