Add PostProcessing interface to audio processing module.

This CL adds an interface for a generic PostProcessing module that
is optionally added to the APM at construction time.

(Parenthetically this CL also adds a missing lock check to
InitializeGainController2.)

Bug: webrtc:8201
Change-Id: I7de64cf8d5335ecec450da8a961660906141d42a
Reviewed-on: https://webrtc-review.googlesource.com/1570
Commit-Queue: Sam Zackrisson <saza@webrtc.org>
Reviewed-by: Per Ã…hgren <peah@webrtc.org>
Reviewed-by: Alex Loiko <aleloi@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#19973}
diff --git a/modules/audio_processing/audio_processing_impl.cc b/modules/audio_processing/audio_processing_impl.cc
index feea33d..99cd082 100644
--- a/modules/audio_processing/audio_processing_impl.cc
+++ b/modules/audio_processing/audio_processing_impl.cc
@@ -170,7 +170,9 @@
 // Throughout webrtc, it's assumed that success is represented by zero.
 static_assert(AudioProcessing::kNoError == 0, "kNoError must be zero");
 
-AudioProcessingImpl::ApmSubmoduleStates::ApmSubmoduleStates() {}
+AudioProcessingImpl::ApmSubmoduleStates::ApmSubmoduleStates(
+    bool capture_post_processor_enabled)
+    : capture_post_processor_enabled_(capture_post_processor_enabled) {}
 
 bool AudioProcessingImpl::ApmSubmoduleStates::Update(
     bool low_cut_filter_enabled,
@@ -250,7 +252,7 @@
 
 bool AudioProcessingImpl::ApmSubmoduleStates::CaptureFullBandProcessingActive()
     const {
-  return level_controller_enabled_;
+  return level_controller_enabled_ || capture_post_processor_enabled_;
 }
 
 bool AudioProcessingImpl::ApmSubmoduleStates::RenderMultiBandSubModulesActive()
@@ -289,8 +291,10 @@
 };
 
 struct AudioProcessingImpl::ApmPrivateSubmodules {
-  explicit ApmPrivateSubmodules(NonlinearBeamformer* beamformer)
-      : beamformer(beamformer) {}
+  ApmPrivateSubmodules(NonlinearBeamformer* beamformer,
+                       std::unique_ptr<PostProcessing> capture_post_processor)
+      : beamformer(beamformer),
+        capture_post_processor(std::move(capture_post_processor)) {}
   // Accessed internally from capture or during initialization
   std::unique_ptr<NonlinearBeamformer> beamformer;
   std::unique_ptr<AgcManagerDirect> agc_manager;
@@ -299,21 +303,29 @@
   std::unique_ptr<LevelController> level_controller;
   std::unique_ptr<ResidualEchoDetector> residual_echo_detector;
   std::unique_ptr<EchoCanceller3> echo_canceller3;
+  std::unique_ptr<PostProcessing> capture_post_processor;
 };
 
 AudioProcessing* AudioProcessing::Create() {
   webrtc::Config config;
-  return Create(config, nullptr);
+  return Create(config, nullptr, nullptr);
 }
 
 AudioProcessing* AudioProcessing::Create(const webrtc::Config& config) {
-  return Create(config, nullptr);
+  return Create(config, nullptr, nullptr);
 }
 
 AudioProcessing* AudioProcessing::Create(const webrtc::Config& config,
                                          NonlinearBeamformer* beamformer) {
-  AudioProcessingImpl* apm =
-      new rtc::RefCountedObject<AudioProcessingImpl>(config, beamformer);
+  return Create(config, nullptr, beamformer);
+}
+
+AudioProcessing* AudioProcessing::Create(
+    const webrtc::Config& config,
+    std::unique_ptr<PostProcessing> capture_post_processor,
+    NonlinearBeamformer* beamformer) {
+  AudioProcessingImpl* apm = new rtc::RefCountedObject<AudioProcessingImpl>(
+      config, std::move(capture_post_processor), beamformer);
   if (apm->Initialize() != kNoError) {
     delete apm;
     apm = nullptr;
@@ -323,13 +335,18 @@
 }
 
 AudioProcessingImpl::AudioProcessingImpl(const webrtc::Config& config)
-    : AudioProcessingImpl(config, nullptr) {}
+    : AudioProcessingImpl(config, nullptr, nullptr) {}
 
-AudioProcessingImpl::AudioProcessingImpl(const webrtc::Config& config,
-                                         NonlinearBeamformer* beamformer)
+AudioProcessingImpl::AudioProcessingImpl(
+    const webrtc::Config& config,
+    std::unique_ptr<PostProcessing> capture_post_processor,
+    NonlinearBeamformer* beamformer)
     : high_pass_filter_impl_(new HighPassFilterImpl(this)),
+      submodule_states_(!!capture_post_processor),
       public_submodules_(new ApmPublicSubmodules()),
-      private_submodules_(new ApmPrivateSubmodules(beamformer)),
+      private_submodules_(
+          new ApmPrivateSubmodules(beamformer,
+                                   std::move(capture_post_processor))),
       constants_(config.Get<ExperimentalAgc>().startup_min_volume,
                  config.Get<ExperimentalAgc>().clipped_level_min,
 #if defined(WEBRTC_ANDROID) || defined(WEBRTC_IOS)
@@ -371,6 +388,9 @@
     // TODO(peah): Move this creation to happen only when the level controller
     // is enabled.
     private_submodules_->level_controller.reset(new LevelController());
+
+    LOG(LS_INFO) << "Capture post processor activated: "
+                 << !!private_submodules_->capture_post_processor;
   }
 
   SetExtraOptions(config);
@@ -525,6 +545,7 @@
   InitializeResidualEchoDetector();
   InitializeEchoCanceller3();
   InitializeGainController2();
+  InitializePostProcessor();
 
   if (aec_dump_) {
     aec_dump_->WriteInitMessage(ToStreamsConfig(formats_.api_format));
@@ -1278,6 +1299,10 @@
     private_submodules_->level_controller->Process(capture_buffer);
   }
 
+  if (private_submodules_->capture_post_processor) {
+    private_submodules_->capture_post_processor->Process(capture_buffer);
+  }
+
   // The level estimator operates on the recombined data.
   public_submodules_->level_estimator->ProcessStream(capture_buffer);
 
@@ -1696,6 +1721,13 @@
   private_submodules_->residual_echo_detector->Initialize();
 }
 
+void AudioProcessingImpl::InitializePostProcessor() {
+  if (private_submodules_->capture_post_processor) {
+    private_submodules_->capture_post_processor->Initialize(
+        proc_sample_rate_hz(), num_proc_channels());
+  }
+}
+
 void AudioProcessingImpl::MaybeUpdateHistograms() {
   static const int kMinDiffDelayMs = 60;
 
diff --git a/modules/audio_processing/audio_processing_impl.h b/modules/audio_processing/audio_processing_impl.h
index 4ad9f7f..1e5695c 100644
--- a/modules/audio_processing/audio_processing_impl.h
+++ b/modules/audio_processing/audio_processing_impl.h
@@ -39,8 +39,10 @@
   // Methods forcing APM to run in a single-threaded manner.
   // Acquires both the render and capture locks.
   explicit AudioProcessingImpl(const webrtc::Config& config);
-  // AudioProcessingImpl takes ownership of beamformer.
+  // AudioProcessingImpl takes ownership of capture post processor and
+  // beamformer.
   AudioProcessingImpl(const webrtc::Config& config,
+                      std::unique_ptr<PostProcessing> capture_post_processor,
                       NonlinearBeamformer* beamformer);
   ~AudioProcessingImpl() override;
   int Initialize() override;
@@ -141,7 +143,7 @@
 
   class ApmSubmoduleStates {
    public:
-    ApmSubmoduleStates();
+    explicit ApmSubmoduleStates(bool capture_post_processor_enabled);
     // Updates the submodule state and returns true if it has changed.
     bool Update(bool low_cut_filter_enabled,
                 bool echo_canceller_enabled,
@@ -164,6 +166,7 @@
     bool RenderMultiBandProcessingActive() const;
 
    private:
+    const bool capture_post_processor_enabled_ = false;
     bool low_cut_filter_enabled_ = false;
     bool echo_canceller_enabled_ = false;
     bool mobile_echo_controller_enabled_ = false;
@@ -218,7 +221,8 @@
       RTC_EXCLUSIVE_LOCKS_REQUIRED(crit_render_, crit_capture_);
   void InitializeLowCutFilter() RTC_EXCLUSIVE_LOCKS_REQUIRED(crit_capture_);
   void InitializeEchoCanceller3() RTC_EXCLUSIVE_LOCKS_REQUIRED(crit_capture_);
-  void InitializeGainController2();
+  void InitializeGainController2() RTC_EXCLUSIVE_LOCKS_REQUIRED(crit_capture_);
+  void InitializePostProcessor() RTC_EXCLUSIVE_LOCKS_REQUIRED(crit_capture_);
 
   void EmptyQueuedRenderAudio();
   void AllocateRenderQueue()
diff --git a/modules/audio_processing/audio_processing_unittest.cc b/modules/audio_processing/audio_processing_unittest.cc
index b6e56c5..7cd2c95 100644
--- a/modules/audio_processing/audio_processing_unittest.cc
+++ b/modules/audio_processing/audio_processing_unittest.cc
@@ -24,6 +24,7 @@
 #include "modules/audio_processing/beamformer/mock_nonlinear_beamformer.h"
 #include "modules/audio_processing/common.h"
 #include "modules/audio_processing/include/audio_processing.h"
+#include "modules/audio_processing/include/mock_audio_processing.h"
 #include "modules/audio_processing/level_controller/level_controller_constants.h"
 #include "modules/audio_processing/test/protobuf_utils.h"
 #include "modules/audio_processing/test/test_utils.h"
@@ -1305,7 +1306,7 @@
   testing::NiceMock<MockNonlinearBeamformer>* beamformer =
       new testing::NiceMock<MockNonlinearBeamformer>(geometry, 1u);
   std::unique_ptr<AudioProcessing> apm(
-      AudioProcessing::Create(config, beamformer));
+      AudioProcessing::Create(config, nullptr, beamformer));
   EXPECT_EQ(kNoErr, apm->gain_control()->Enable(true));
   ChannelBuffer<float> src_buf(kSamplesPerChannel, kNumInputChannels);
   ChannelBuffer<float> dest_buf(kSamplesPerChannel, kNumOutputChannels);
@@ -2891,4 +2892,22 @@
               std::numeric_limits<float>::epsilon());
 }
 
+TEST(ApmConfiguration, EnablePostProcessing) {
+  // Verify that apm uses a capture post processing module if one is provided.
+  webrtc::Config webrtc_config;
+  auto mock_post_processor_ptr =
+      new testing::NiceMock<test::MockPostProcessing>();
+  auto mock_post_processor =
+      std::unique_ptr<PostProcessing>(mock_post_processor_ptr);
+  rtc::scoped_refptr<AudioProcessing> apm = AudioProcessing::Create(
+      webrtc_config, std::move(mock_post_processor), nullptr);
+
+  AudioFrame audio;
+  audio.num_channels_ = 1;
+  SetFrameSampleRate(&audio, AudioProcessing::NativeRate::kSampleRate16kHz);
+
+  EXPECT_CALL(*mock_post_processor_ptr, Process(testing::_)).Times(1);
+  std::cout << apm->ProcessStream(&audio) << std::endl;
+}
+
 }  // namespace webrtc
diff --git a/modules/audio_processing/include/audio_processing.h b/modules/audio_processing/include/audio_processing.h
index f8df3e9..6baa691 100644
--- a/modules/audio_processing/include/audio_processing.h
+++ b/modules/audio_processing/include/audio_processing.h
@@ -23,6 +23,7 @@
 #include "modules/audio_processing/beamformer/array_util.h"
 #include "modules/audio_processing/include/config.h"
 #include "rtc_base/arraysize.h"
+#include "rtc_base/deprecation.h"
 #include "rtc_base/platform_file.h"
 #include "rtc_base/refcount.h"
 #include "typedefs.h"  // NOLINT(build/include)
@@ -32,6 +33,7 @@
 struct AecCore;
 
 class AecDump;
+class AudioBuffer;
 class AudioFrame;
 
 class NonlinearBeamformer;
@@ -45,6 +47,7 @@
 class HighPassFilter;
 class LevelEstimator;
 class NoiseSuppression;
+class PostProcessing;
 class VoiceDetection;
 
 // Use to enable the extended filter mode in the AEC, along with robustness
@@ -359,9 +362,15 @@
   static AudioProcessing* Create();
   // Allows passing in an optional configuration at create-time.
   static AudioProcessing* Create(const webrtc::Config& config);
-  // Only for testing.
+  // Deprecated. Use the Create below, with nullptr PostProcessing.
+  RTC_DEPRECATED
   static AudioProcessing* Create(const webrtc::Config& config,
                                  NonlinearBeamformer* beamformer);
+  // Allows passing in optional user-defined processing modules.
+  static AudioProcessing* Create(
+      const webrtc::Config& config,
+      std::unique_ptr<PostProcessing> capture_post_processor,
+      NonlinearBeamformer* beamformer);
   ~AudioProcessing() override {}
 
   // Initializes internal states, while retaining all user settings. This
@@ -1087,6 +1096,19 @@
   virtual ~NoiseSuppression() {}
 };
 
+// Interface for a post processing submodule.
+class PostProcessing {
+ public:
+  // (Re-)Initializes the submodule.
+  virtual void Initialize(int sample_rate_hz, int num_channels) = 0;
+  // Processes the given capture or render signal.
+  virtual void Process(AudioBuffer* audio) = 0;
+  // Returns a string representation of the module state.
+  virtual std::string ToString() const = 0;
+
+  virtual ~PostProcessing() {}
+};
+
 // The voice activity detection (VAD) component analyzes the stream to
 // determine if voice is present. A facility is also provided to pass in an
 // external VAD decision.
diff --git a/modules/audio_processing/include/mock_audio_processing.h b/modules/audio_processing/include/mock_audio_processing.h
index 7a1d447..c7a0f51 100644
--- a/modules/audio_processing/include/mock_audio_processing.h
+++ b/modules/audio_processing/include/mock_audio_processing.h
@@ -104,6 +104,14 @@
   MOCK_METHOD0(NoiseEstimate, std::vector<float>());
 };
 
+class MockPostProcessing : public PostProcessing {
+ public:
+  virtual ~MockPostProcessing() {}
+  MOCK_METHOD2(Initialize, void(int sample_rate_hz, int num_channels));
+  MOCK_METHOD1(Process, void(AudioBuffer* audio));
+  MOCK_CONST_METHOD0(ToString, std::string());
+};
+
 class MockVoiceDetection : public VoiceDetection {
  public:
   virtual ~MockVoiceDetection() {}