Added faster initial model adaptation speed in AEC3

Bug: webrtc:8746
Change-Id: Idcb65e2b1241a7da8c4a98622923e401d174b879
Reviewed-on: https://webrtc-review.googlesource.com/39506
Commit-Queue: Per Åhgren <peah@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#21619}
diff --git a/modules/audio_processing/aec3/adaptive_fir_filter.cc b/modules/audio_processing/aec3/adaptive_fir_filter.cc
index d92e538..e080b4b 100644
--- a/modules/audio_processing/aec3/adaptive_fir_filter.cc
+++ b/modules/audio_processing/aec3/adaptive_fir_filter.cc
@@ -22,6 +22,7 @@
 
 #include "modules/audio_processing/aec3/fft_data.h"
 #include "rtc_base/checks.h"
+#include "rtc_base/logging.h"
 
 namespace webrtc {
 
@@ -414,15 +415,16 @@
 
 }  // namespace aec3
 
-AdaptiveFirFilter::AdaptiveFirFilter(size_t size_partitions,
+AdaptiveFirFilter::AdaptiveFirFilter(size_t max_size_partitions,
                                      Aec3Optimization optimization,
                                      ApmDataDumper* data_dumper)
     : data_dumper_(data_dumper),
       fft_(),
       optimization_(optimization),
-      H_(size_partitions),
-      H2_(size_partitions, std::array<float, kFftLengthBy2Plus1>()),
-      h_(GetTimeDomainLength(size_partitions), 0.f) {
+      max_size_partitions_(max_size_partitions),
+      H_(max_size_partitions_),
+      H2_(max_size_partitions_, std::array<float, kFftLengthBy2Plus1>()),
+      h_(GetTimeDomainLength(max_size_partitions_), 0.f) {
   RTC_DCHECK(data_dumper_);
 
   for (auto& H_j : H_) {
@@ -437,16 +439,53 @@
 AdaptiveFirFilter::~AdaptiveFirFilter() = default;
 
 void AdaptiveFirFilter::HandleEchoPathChange() {
+  size_t current_h_size = h_.size();  // Active length, restored below.
+  h_.resize(GetTimeDomainLength(max_size_partitions_));
   std::fill(h_.begin(), h_.end(), 0.f);
+  h_.resize(current_h_size);
+  // Same pattern for H_: grow to the maximum, clear, restore the active size.
+  size_t current_size_partitions = H_.size();
+  H_.resize(max_size_partitions_);
   for (auto& H_j : H_) {
     H_j.Clear();
   }
+  H_.resize(current_size_partitions);
+  // H2_ shares the partition count of H_, so the saved size is reused.
+  H2_.resize(max_size_partitions_);
   for (auto& H2_k : H2_) {
     H2_k.fill(0.f);
   }
+  H2_.resize(current_size_partitions);
+  // The filter-based ERL estimate is reset together with the coefficients.
   erl_.fill(0.f);
 }
 
+void AdaptiveFirFilter::SetSizePartitions(size_t size) {
+  // Sets the number of active partitions, at most max_size_partitions_.
+  RTC_DCHECK_EQ(max_size_partitions_, H_.capacity());
+  RTC_DCHECK_EQ(max_size_partitions_, H2_.capacity());
+  RTC_DCHECK_EQ(GetTimeDomainLength(max_size_partitions_), h_.capacity());
+  RTC_DCHECK_EQ(H_.size(), H2_.size());
+  RTC_DCHECK_EQ(h_.size(), GetTimeDomainLength(H_.size()));
+  // An oversized request is logged and clamped rather than treated as fatal.
+  if (size > max_size_partitions_) {
+    RTC_LOG(LS_ERROR) << "Too large adaptive filter size specified: " << size;
+    size = max_size_partitions_;
+  }
+  // When shrinking, first zero the partitions and taps that get dropped.
+  if (size < H_.size()) {
+    for (size_t k = size; k < H_.size(); ++k) {
+      H_[k].Clear();
+      H2_[k].fill(0.f);
+    }
+    std::fill(h_.begin() + GetTimeDomainLength(size), h_.end(), 0.f);
+  }
+  // Resize all three filter representations consistently.
+  H_.resize(size);
+  H2_.resize(size);
+  h_.resize(GetTimeDomainLength(size));
+}
+
 void AdaptiveFirFilter::Filter(const RenderBuffer& render_buffer,
                                FftData* S) const {
   RTC_DCHECK(S);
diff --git a/modules/audio_processing/aec3/adaptive_fir_filter.h b/modules/audio_processing/aec3/adaptive_fir_filter.h
index 5fa86a1..e993c76 100644
--- a/modules/audio_processing/aec3/adaptive_fir_filter.h
+++ b/modules/audio_processing/aec3/adaptive_fir_filter.h
@@ -91,7 +91,7 @@
 // Provides a frequency domain adaptive filter functionality.
 class AdaptiveFirFilter {
  public:
-  AdaptiveFirFilter(size_t size_partitions,
+  AdaptiveFirFilter(size_t max_size_partitions,
                     Aec3Optimization optimization,
                     ApmDataDumper* data_dumper);
 
@@ -110,6 +110,9 @@
   // Returns the filter size.
   size_t SizePartitions() const { return H_.size(); }
 
+  // Sets the filter size in partitions, clamped to the constructor's maximum.
+  void SetSizePartitions(size_t size);
+
   // Returns the filter based echo return loss.
   const std::array<float, kFftLengthBy2Plus1>& Erl() const { return erl_; }
 
@@ -123,10 +126,13 @@
   const std::vector<float>& FilterImpulseResponse() const { return h_; }
 
   void DumpFilter(const char* name) {
+    size_t current_size_partitions = H_.size();  // Restored after the dump.
+    H_.resize(max_size_partitions_);  // Dump at the maximum size.
     for (auto& H : H_) {
       data_dumper_->DumpRaw(name, H.re);
       data_dumper_->DumpRaw(name, H.im);
     }
+    H_.resize(current_size_partitions);
   }
 
  private:
@@ -136,6 +142,7 @@
   ApmDataDumper* const data_dumper_;
   const Aec3Fft fft_;
   const Aec3Optimization optimization_;
+  const size_t max_size_partitions_;
   std::vector<FftData> H_;
   std::vector<std::array<float, kFftLengthBy2Plus1>> H2_;
   std::vector<float> h_;
diff --git a/modules/audio_processing/aec3/aec_state.cc b/modules/audio_processing/aec3/aec_state.cc
index 0fd035b..00677f7 100644
--- a/modules/audio_processing/aec3/aec_state.cc
+++ b/modules/audio_processing/aec3/aec_state.cc
@@ -78,6 +78,7 @@
     render_received_ = false;
     force_zero_gain_ = true;
     blocks_with_active_render_ = 0;
+    initial_state_ = true;
   };
 
   // TODO(peah): Refine the reset scheme according to the type of gain and
@@ -155,6 +156,9 @@
   filter_has_had_time_to_converge_ =
       blocks_with_proper_filter_adaptation_ >= 2 * kNumBlocksPerSecond;
 
+  initial_state_ =
+      blocks_with_proper_filter_adaptation_ < 5 * kNumBlocksPerSecond;
+
   // Flag whether the linear filter estimate is usable.
   usable_linear_estimate_ =
       !echo_saturation_ &&
diff --git a/modules/audio_processing/aec3/aec_state.h b/modules/audio_processing/aec3/aec_state.h
index 98a78dd..e2039ad 100644
--- a/modules/audio_processing/aec3/aec_state.h
+++ b/modules/audio_processing/aec3/aec_state.h
@@ -103,6 +103,9 @@
     return filter_has_had_time_to_converge_;
   }
 
+  // Returns whether the filter adaptation is still in the initial state.
+  bool InitialState() const { return initial_state_; }
+
   // Updates the aec state.
   void Update(const std::vector<std::array<float, kFftLengthBy2Plus1>>&
                   adaptive_filter_frequency_response,
@@ -161,6 +164,7 @@
   float reverb_decay_;
   bool saturating_echo_path_ = false;
   bool filter_has_had_time_to_converge_ = false;
+  bool initial_state_ = true;
 
   RTC_DISALLOW_COPY_AND_ASSIGN(AecState);
 };
diff --git a/modules/audio_processing/aec3/echo_remover.cc b/modules/audio_processing/aec3/echo_remover.cc
index fc0e680..a153deb 100644
--- a/modules/audio_processing/aec3/echo_remover.cc
+++ b/modules/audio_processing/aec3/echo_remover.cc
@@ -86,6 +86,7 @@
   bool echo_leakage_detected_ = false;
   AecState aec_state_;
   EchoRemoverMetrics metrics_;
+  bool initial_state_ = true;
 
   RTC_DISALLOW_COPY_AND_ASSIGN(EchoRemoverImpl);
 };
@@ -146,6 +147,7 @@
   if (echo_path_variability.AudioPathChanged()) {
     subtractor_.HandleEchoPathChange(echo_path_variability);
     aec_state_.HandleEchoPathChange(echo_path_variability);
+    initial_state_ = true;
   }
 
   std::array<float, kFftLengthBy2Plus1> Y2;
@@ -166,6 +168,10 @@
   render_signal_analyzer_.Update(*render_buffer, aec_state_.FilterDelay());
 
   // Perform linear echo cancellation.
+  if (initial_state_ && !aec_state_.InitialState()) {
+    subtractor_.ExitInitialState();
+    initial_state_ = false;
+  }
   subtractor_.Process(*render_buffer, y0, render_signal_analyzer_, aec_state_,
                       &subtractor_output);
 
diff --git a/modules/audio_processing/aec3/main_filter_update_gain.h b/modules/audio_processing/aec3/main_filter_update_gain.h
index 07cad28..8026768 100644
--- a/modules/audio_processing/aec3/main_filter_update_gain.h
+++ b/modules/audio_processing/aec3/main_filter_update_gain.h
@@ -44,10 +44,16 @@
                bool saturated_capture_signal,
                FftData* gain_fft);
 
+  // Replaces the stored configuration with |config|.
+  void SetConfig(
+      const EchoCanceller3Config::Filter::MainConfiguration& config) {
+    config_ = config;
+  }
+
  private:
   static int instance_count_;
   std::unique_ptr<ApmDataDumper> data_dumper_;
-  const EchoCanceller3Config::Filter::MainConfiguration config_;
+  EchoCanceller3Config::Filter::MainConfiguration config_;
   std::array<float, kFftLengthBy2Plus1> H_error_;
   size_t poor_excitation_counter_;
   size_t call_counter_ = 0;
diff --git a/modules/audio_processing/aec3/shadow_filter_update_gain.h b/modules/audio_processing/aec3/shadow_filter_update_gain.h
index e0322d5..5f7f5a1 100644
--- a/modules/audio_processing/aec3/shadow_filter_update_gain.h
+++ b/modules/audio_processing/aec3/shadow_filter_update_gain.h
@@ -36,8 +36,14 @@
                bool saturated_capture_signal,
                FftData* G);
 
+  // Replaces the stored configuration with |config|.
+  void SetConfig(
+      const EchoCanceller3Config::Filter::ShadowConfiguration& config) {
+    config_ = config;
+  }
+
  private:
-  const EchoCanceller3Config::Filter::ShadowConfiguration config_;
+  EchoCanceller3Config::Filter::ShadowConfiguration config_;
   // TODO(peah): Check whether this counter should instead be initialized to a
   // large value.
   size_t poor_signal_excitation_counter_ = 0;
diff --git a/modules/audio_processing/aec3/subtractor.cc b/modules/audio_processing/aec3/subtractor.cc
index 69899ac..b19bd64 100644
--- a/modules/audio_processing/aec3/subtractor.cc
+++ b/modules/audio_processing/aec3/subtractor.cc
@@ -51,19 +51,30 @@
     : fft_(),
       data_dumper_(data_dumper),
       optimization_(optimization),
-      main_filter_(config.filter.main.length_blocks,
+      config_(config),
+      main_filter_(config_.filter.main.length_blocks,
                    optimization,
                    data_dumper_),
-      shadow_filter_(config.filter.shadow.length_blocks,
+      shadow_filter_(config_.filter.shadow.length_blocks,
                      optimization,
                      data_dumper_),
-      G_main_(config.filter.main),
-      G_shadow_(config.filter.shadow) {
+      G_main_(config_.filter.main_initial),
+      G_shadow_(config_.filter.shadow_initial) {
   RTC_DCHECK(data_dumper_);
   // Currently, the rest of AEC3 requires the main and shadow filter lengths to
   // be identical.
-  RTC_DCHECK_EQ(config.filter.main.length_blocks,
-                config.filter.shadow.length_blocks);
+  RTC_DCHECK_EQ(config_.filter.main.length_blocks,
+                config_.filter.shadow.length_blocks);
+  RTC_DCHECK_EQ(config_.filter.main_initial.length_blocks,
+                config_.filter.shadow_initial.length_blocks);
+
+  RTC_DCHECK_GE(config_.filter.main.length_blocks,
+                config_.filter.main_initial.length_blocks);
+  RTC_DCHECK_GE(config_.filter.shadow.length_blocks,
+                config_.filter.shadow_initial.length_blocks);
+
+  main_filter_.SetSizePartitions(config_.filter.main_initial.length_blocks);
+  shadow_filter_.SetSizePartitions(config_.filter.shadow_initial.length_blocks);
 }
 
 Subtractor::~Subtractor() = default;
@@ -75,6 +86,12 @@
     shadow_filter_.HandleEchoPathChange();
     G_main_.HandleEchoPathChange(echo_path_variability);
     G_shadow_.HandleEchoPathChange();
+    G_main_.SetConfig(config_.filter.main_initial);
+    G_shadow_.SetConfig(config_.filter.shadow_initial);
+    main_filter_.SetSizePartitions(config_.filter.main_initial.length_blocks);
+    shadow_filter_.SetSizePartitions(
+        config_.filter.shadow_initial.length_blocks);
+
     converged_filter_ = false;
   };
 
@@ -93,6 +110,13 @@
   }
 }
 
+void Subtractor::ExitInitialState() {
+  G_main_.SetConfig(config_.filter.main);  // Non-initial main config.
+  G_shadow_.SetConfig(config_.filter.shadow);  // Non-initial shadow config.
+  main_filter_.SetSizePartitions(config_.filter.main.length_blocks);
+  shadow_filter_.SetSizePartitions(config_.filter.shadow.length_blocks);
+}
+
 void Subtractor::Process(const RenderBuffer& render_buffer,
                          const rtc::ArrayView<const float> capture,
                          const RenderSignalAnalyzer& render_signal_analyzer,
diff --git a/modules/audio_processing/aec3/subtractor.h b/modules/audio_processing/aec3/subtractor.h
index 12ae488..b267abf 100644
--- a/modules/audio_processing/aec3/subtractor.h
+++ b/modules/audio_processing/aec3/subtractor.h
@@ -47,6 +47,9 @@
 
   void HandleEchoPathChange(const EchoPathVariability& echo_path_variability);
 
+  // Exits the initial state by applying the main/shadow (non-initial) configs.
+  void ExitInitialState();
+
   // Returns the block-wise frequency response for the main adaptive filter.
   const std::vector<std::array<float, kFftLengthBy2Plus1>>&
   FilterFrequencyResponse() const {
@@ -64,6 +67,7 @@
   const Aec3Fft fft_;
   ApmDataDumper* data_dumper_;
   const Aec3Optimization optimization_;
+  const EchoCanceller3Config config_;
   AdaptiveFirFilter main_filter_;
   AdaptiveFirFilter shadow_filter_;
   MainFilterUpdateGain G_main_;
diff --git a/modules/audio_processing/include/audio_processing.h b/modules/audio_processing/include/audio_processing.h
index 4f68a6d..63da80c 100644
--- a/modules/audio_processing/include/audio_processing.h
+++ b/modules/audio_processing/include/audio_processing.h
@@ -1260,6 +1260,9 @@
 
     MainConfiguration main = {12, 0.005f, 0.05f, 0.001f, 20075344.f};
     ShadowConfiguration shadow = {12, 0.1f, 20075344.f};
+
+    MainConfiguration main_initial = {12, 0.01f, 0.1f, 0.001f, 20075344.f};
+    ShadowConfiguration shadow_initial = {12, 0.7f, 20075344.f};
   } filter;
 
   struct Erle {