Make the echo detector injectable.

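This adds a generic EchoDetector interface and makes it possible to inject an implementation into the audio processing module via AudioProcessingBuilder::SetEchoDetector. If no detector is injected, the existing ResidualEchoDetector is used.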

Bug: webrtc:8732
Change-Id: I30d97aeb829307b2ae9c4dbeb9a3e15ab7ec0912
Reviewed-on: https://webrtc-review.googlesource.com/38900
Commit-Queue: Ivo Creusen <ivoc@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#21588}
diff --git a/modules/audio_processing/audio_processing_impl.cc b/modules/audio_processing/audio_processing_impl.cc
index 4629ebf..accd03b 100644
--- a/modules/audio_processing/audio_processing_impl.cc
+++ b/modules/audio_processing/audio_processing_impl.cc
@@ -302,8 +302,10 @@
 struct AudioProcessingImpl::ApmPrivateSubmodules {
   ApmPrivateSubmodules(NonlinearBeamformer* beamformer,
                        std::unique_ptr<CustomProcessing> capture_post_processor,
-                       std::unique_ptr<CustomProcessing> render_pre_processor)
+                       std::unique_ptr<CustomProcessing> render_pre_processor,
+                       std::unique_ptr<EchoDetector> echo_detector)
       : beamformer(beamformer),
+        echo_detector(std::move(echo_detector)),
         capture_post_processor(std::move(capture_post_processor)),
         render_pre_processor(std::move(render_pre_processor)) {}
   // Accessed internally from capture or during initialization
@@ -312,7 +314,7 @@
   std::unique_ptr<GainController2> gain_controller2;
   std::unique_ptr<LowCutFilter> low_cut_filter;
   std::unique_ptr<LevelController> level_controller;
-  std::unique_ptr<ResidualEchoDetector> residual_echo_detector;
+  std::unique_ptr<EchoDetector> echo_detector;
   std::unique_ptr<EchoControl> echo_controller;
   std::unique_ptr<CustomProcessing> capture_post_processor;
   std::unique_ptr<CustomProcessing> render_pre_processor;
@@ -345,16 +347,27 @@
   return *this;
 }
 
+AudioProcessingBuilder& AudioProcessingBuilder::SetEchoDetector(
+    std::unique_ptr<EchoDetector> echo_detector) {
+  echo_detector_ = std::move(echo_detector);  // Discards any prior detector.
+  return *this;
+}
+
 AudioProcessing* AudioProcessingBuilder::Create() {
   webrtc::Config config;
   return Create(config);
 }
 
 AudioProcessing* AudioProcessingBuilder::Create(const webrtc::Config& config) {
-  return AudioProcessing::Create(config, std::move(capture_post_processing_),
-                                 std::move(render_pre_processing_),
-                                 std::move(echo_control_factory_),
-                                 nonlinear_beamformer_.release());
+  AudioProcessingImpl* apm = new rtc::RefCountedObject<AudioProcessingImpl>(
+      config, std::move(capture_post_processing_),
+      std::move(render_pre_processing_), std::move(echo_control_factory_),
+      std::move(echo_detector_), nonlinear_beamformer_.release());
+  if (apm->Initialize() != AudioProcessing::kNoError) {
+    delete apm;
+    apm = nullptr;
+  }
+  return apm;
 }
 
 AudioProcessing* AudioProcessing::Create() {
@@ -388,7 +401,7 @@
     NonlinearBeamformer* beamformer) {
   AudioProcessingImpl* apm = new rtc::RefCountedObject<AudioProcessingImpl>(
       config, std::move(capture_post_processor),
-      std::move(render_pre_processor), std::move(echo_control_factory),
+      std::move(render_pre_processor), std::move(echo_control_factory), nullptr,
       beamformer);
   if (apm->Initialize() != kNoError) {
     delete apm;
@@ -399,13 +412,15 @@
 }
 
 AudioProcessingImpl::AudioProcessingImpl(const webrtc::Config& config)
-    : AudioProcessingImpl(config, nullptr, nullptr, nullptr, nullptr) {}
+    : AudioProcessingImpl(config, nullptr, nullptr, nullptr, nullptr, nullptr) {
+}
 
 AudioProcessingImpl::AudioProcessingImpl(
     const webrtc::Config& config,
     std::unique_ptr<CustomProcessing> capture_post_processor,
     std::unique_ptr<CustomProcessing> render_pre_processor,
     std::unique_ptr<EchoControlFactory> echo_control_factory,
+    std::unique_ptr<EchoDetector> echo_detector,
     NonlinearBeamformer* beamformer)
     : high_pass_filter_impl_(new HighPassFilterImpl(this)),
       echo_control_factory_(std::move(echo_control_factory)),
@@ -414,7 +429,8 @@
       private_submodules_(
           new ApmPrivateSubmodules(beamformer,
                                    std::move(capture_post_processor),
-                                   std::move(render_pre_processor))),
+                                   std::move(render_pre_processor),
+                                   std::move(echo_detector))),
       constants_(config.Get<ExperimentalAgc>().startup_min_volume,
                  config.Get<ExperimentalAgc>().clipped_level_min,
 #if defined(WEBRTC_ANDROID) || defined(WEBRTC_IOS)
@@ -454,8 +470,11 @@
     public_submodules_->gain_control_for_experimental_agc.reset(
         new GainControlForExperimentalAgc(
             public_submodules_->gain_control.get(), &crit_capture_));
-    private_submodules_->residual_echo_detector.reset(
-        new ResidualEchoDetector());
+
+    // If no echo detector is injected, use the ResidualEchoDetector.
+    if (!private_submodules_->echo_detector) {
+      private_submodules_->echo_detector.reset(new ResidualEchoDetector());
+    }
 
     // TODO(peah): Move this creation to happen only when the level controller
     // is enabled.
@@ -1121,7 +1140,8 @@
   }
 
   while (red_render_signal_queue_->Remove(&red_capture_queue_buffer_)) {
-    private_submodules_->residual_echo_detector->AnalyzeRenderAudio(
+    RTC_DCHECK(private_submodules_->echo_detector);
+    private_submodules_->echo_detector->AnalyzeRenderAudio(
         red_capture_queue_buffer_);
   }
 }
@@ -1337,7 +1357,8 @@
   }
 
   if (config_.residual_echo_detector.enabled) {
-    private_submodules_->residual_echo_detector->AnalyzeCaptureAudio(
+    RTC_DCHECK(private_submodules_->echo_detector);
+    private_submodules_->echo_detector->AnalyzeCaptureAudio(
         rtc::ArrayView<const float>(capture_buffer->channels_f()[0],
                                     capture_buffer->num_frames()));
   }
@@ -1664,11 +1685,11 @@
   }
   {
     rtc::CritScope cs_capture(&crit_capture_);
-    stats.residual_echo_likelihood =
-        private_submodules_->residual_echo_detector->echo_likelihood();
+    RTC_DCHECK(private_submodules_->echo_detector);
+    auto ed_metrics = private_submodules_->echo_detector->GetMetrics();
+    stats.residual_echo_likelihood = ed_metrics.echo_likelihood;
     stats.residual_echo_likelihood_recent_max =
-        private_submodules_->residual_echo_detector
-            ->echo_likelihood_recent_max();
+        ed_metrics.echo_likelihood_recent_max;
   }
   public_submodules_->echo_cancellation->GetDelayMetrics(
       &stats.delay_median, &stats.delay_standard_deviation,
@@ -1705,11 +1726,11 @@
     }
     if (config_.residual_echo_detector.enabled) {
       rtc::CritScope cs_capture(&crit_capture_);
-      stats.residual_echo_likelihood = rtc::Optional<double>(
-          private_submodules_->residual_echo_detector->echo_likelihood());
+      RTC_DCHECK(private_submodules_->echo_detector);
+      auto ed_metrics = private_submodules_->echo_detector->GetMetrics();
+      stats.residual_echo_likelihood = ed_metrics.echo_likelihood;
       stats.residual_echo_likelihood_recent_max =
-          rtc::Optional<double>(private_submodules_->residual_echo_detector
-                                    ->echo_likelihood_recent_max());
+          ed_metrics.echo_likelihood_recent_max;
     }
     int delay_median, delay_std;
     float fraction_poor_delays;
@@ -1854,7 +1875,9 @@
 }
 
 void AudioProcessingImpl::InitializeResidualEchoDetector() {
-  private_submodules_->residual_echo_detector->Initialize();
+  RTC_DCHECK(private_submodules_->echo_detector);
+  private_submodules_->echo_detector->Initialize(proc_sample_rate_hz(),
+                                                 num_proc_channels());
 }
 
 void AudioProcessingImpl::InitializePostProcessor() {
diff --git a/modules/audio_processing/audio_processing_impl.h b/modules/audio_processing/audio_processing_impl.h
index c05d238..8ece029 100644
--- a/modules/audio_processing/audio_processing_impl.h
+++ b/modules/audio_processing/audio_processing_impl.h
@@ -45,6 +45,7 @@
                       std::unique_ptr<CustomProcessing> capture_post_processor,
                       std::unique_ptr<CustomProcessing> render_pre_processor,
                       std::unique_ptr<EchoControlFactory> echo_control_factory,
+                      std::unique_ptr<EchoDetector> echo_detector,
                       NonlinearBeamformer* beamformer);
   ~AudioProcessingImpl() override;
   int Initialize() override;
diff --git a/modules/audio_processing/include/audio_processing.h b/modules/audio_processing/include/audio_processing.h
index 60bf0c7..8951b8c 100644
--- a/modules/audio_processing/include/audio_processing.h
+++ b/modules/audio_processing/include/audio_processing.h
@@ -49,6 +49,7 @@
 class EchoCancellation;
 class EchoControlMobile;
 class EchoControlFactory;
+class EchoDetector;
 class GainControl;
 class HighPassFilter;
 class LevelEstimator;
@@ -665,6 +666,9 @@
   // The AudioProcessingBuilder takes ownership of the nonlinear beamformer.
   AudioProcessingBuilder& SetNonlinearBeamformer(
       std::unique_ptr<NonlinearBeamformer> nonlinear_beamformer);
+  // The AudioProcessingBuilder takes ownership of the echo_detector.
+  AudioProcessingBuilder& SetEchoDetector(
+      std::unique_ptr<EchoDetector> echo_detector);
   // This creates an APM instance using the previously set components. Calling
   // the Create function resets the AudioProcessingBuilder to its initial state.
   AudioProcessing* Create();
@@ -675,6 +679,7 @@
   std::unique_ptr<CustomProcessing> capture_post_processing_;
   std::unique_ptr<CustomProcessing> render_pre_processing_;
   std::unique_ptr<NonlinearBeamformer> nonlinear_beamformer_;
+  std::unique_ptr<EchoDetector> echo_detector_;
   RTC_DISALLOW_COPY_AND_ASSIGN(AudioProcessingBuilder);
 };
 
@@ -1147,6 +1152,34 @@
   virtual ~CustomProcessing() {}
 };
 
+// Interface for an echo detector submodule.
+class EchoDetector {
+ public:
+  // (Re-)Initializes the submodule for the given sample rate and channels.
+  virtual void Initialize(int sample_rate_hz, int num_channels) = 0;
+
+  // Analysis (not changing) of the render signal.
+  virtual void AnalyzeRenderAudio(rtc::ArrayView<const float> render_audio) = 0;
+
+  // Analysis (not changing) of the capture signal.
+  virtual void AnalyzeCaptureAudio(
+      rtc::ArrayView<const float> capture_audio) = 0;
+
+  // Replaces |packed_buffer| with a copy of the first channel of |audio|.
+  static void PackRenderAudioBuffer(AudioBuffer* audio,
+                                    std::vector<float>* packed_buffer);
+
+  struct Metrics {
+    double echo_likelihood = 0.0;
+    double echo_likelihood_recent_max = 0.0;
+  };
+
+  // Returns the current metrics of the echo detector.
+  virtual Metrics GetMetrics() const = 0;
+
+  virtual ~EchoDetector() = default;
+};
+
 // The voice activity detection (VAD) component analyzes the stream to
 // determine if voice is present. A facility is also provided to pass in an
 // external VAD decision.
diff --git a/modules/audio_processing/residual_echo_detector.cc b/modules/audio_processing/residual_echo_detector.cc
index b35c155..ef325a0 100644
--- a/modules/audio_processing/residual_echo_detector.cc
+++ b/modules/audio_processing/residual_echo_detector.cc
@@ -177,7 +177,8 @@
                               : 0;
 }
 
-void ResidualEchoDetector::Initialize() {
+void ResidualEchoDetector::Initialize(int /*sample_rate_hz*/,
+                                      int /*num_channels*/) {
   render_buffer_.Clear();
   std::fill(render_power_.begin(), render_power_.end(), 0.f);
   std::fill(render_power_mean_.begin(), render_power_mean_.end(), 0.f);
@@ -193,12 +194,17 @@
   reliability_ = 0.f;
 }
 
-void ResidualEchoDetector::PackRenderAudioBuffer(
-    AudioBuffer* audio,
-    std::vector<float>* packed_buffer) {
+void EchoDetector::PackRenderAudioBuffer(AudioBuffer* audio,
+                                         std::vector<float>* packed_buffer) {
   packed_buffer->clear();
   packed_buffer->insert(packed_buffer->end(), audio->channels_f()[0],
                         audio->channels_f()[0] + audio->num_frames());
 }
 
+EchoDetector::Metrics ResidualEchoDetector::GetMetrics() const {
+  EchoDetector::Metrics current;
+  current.echo_likelihood_recent_max = recent_likelihood_max_.max();
+  current.echo_likelihood = echo_likelihood_;
+  return current;
+}
 }  // namespace webrtc
diff --git a/modules/audio_processing/residual_echo_detector.h b/modules/audio_processing/residual_echo_detector.h
index de1b989..e8ae552 100644
--- a/modules/audio_processing/residual_echo_detector.h
+++ b/modules/audio_processing/residual_echo_detector.h
@@ -18,39 +18,32 @@
 #include "modules/audio_processing/echo_detector/mean_variance_estimator.h"
 #include "modules/audio_processing/echo_detector/moving_max.h"
 #include "modules/audio_processing/echo_detector/normalized_covariance_estimator.h"
+#include "modules/audio_processing/include/audio_processing.h"
 
 namespace webrtc {
 
 class ApmDataDumper;
 class AudioBuffer;
-class EchoDetector;
 
-class ResidualEchoDetector {
+class ResidualEchoDetector : public EchoDetector {
  public:
   ResidualEchoDetector();
-  ~ResidualEchoDetector();
+  ~ResidualEchoDetector() override;
 
   // This function should be called while holding the render lock.
-  void AnalyzeRenderAudio(rtc::ArrayView<const float> render_audio);
+  void AnalyzeRenderAudio(rtc::ArrayView<const float> render_audio) override;
 
   // This function should be called while holding the capture lock.
-  void AnalyzeCaptureAudio(rtc::ArrayView<const float> capture_audio);
+  void AnalyzeCaptureAudio(rtc::ArrayView<const float> capture_audio) override;
 
   // This function should be called while holding the capture lock.
-  void Initialize();
+  void Initialize(int sample_rate_hz, int num_channels) override;
 
   // This function is for testing purposes only.
   void SetReliabilityForTest(float value) { reliability_ = value; }
 
-  static void PackRenderAudioBuffer(AudioBuffer* audio,
-                                    std::vector<float>* packed_buffer);
-
   // This function should be called while holding the capture lock.
-  float echo_likelihood() const { return echo_likelihood_; }
-
-  float echo_likelihood_recent_max() const {
-    return recent_likelihood_max_.max();
-  }
+  EchoDetector::Metrics GetMetrics() const override;
 
  private:
   static int instance_count_;
diff --git a/modules/audio_processing/residual_echo_detector_unittest.cc b/modules/audio_processing/residual_echo_detector_unittest.cc
index baf83ba..7bfa0d2 100644
--- a/modules/audio_processing/residual_echo_detector_unittest.cc
+++ b/modules/audio_processing/residual_echo_detector_unittest.cc
@@ -37,7 +37,8 @@
     }
   }
   // We expect to detect echo with near certain likelihood.
-  EXPECT_NEAR(1.f, echo_detector.echo_likelihood(), 0.01f);
+  auto ed_metrics = echo_detector.GetMetrics();
+  EXPECT_NEAR(1.f, ed_metrics.echo_likelihood, 0.01f);
 }
 
 TEST(ResidualEchoDetectorTests, NoEcho) {
@@ -57,7 +58,8 @@
     echo_detector.AnalyzeCaptureAudio(zeros);
   }
   // We expect to not detect any echo.
-  EXPECT_NEAR(0.f, echo_detector.echo_likelihood(), 0.01f);
+  auto ed_metrics = echo_detector.GetMetrics();
+  EXPECT_NEAR(0.f, ed_metrics.echo_likelihood, 0.01f);
 }
 
 TEST(ResidualEchoDetectorTests, EchoWithRenderClockDrift) {
@@ -92,7 +94,8 @@
   // A growing buffer can be caused by jitter or clock drift and it's not
   // possible to make this decision right away. For this reason we only expect
   // an echo likelihood of 75% in this test.
-  EXPECT_GT(echo_detector.echo_likelihood(), 0.75f);
+  auto ed_metrics = echo_detector.GetMetrics();
+  EXPECT_GT(ed_metrics.echo_likelihood, 0.75f);
 }
 
 TEST(ResidualEchoDetectorTests, EchoWithCaptureClockDrift) {
@@ -122,7 +125,8 @@
     }
   }
   // We expect to detect echo with near certain likelihood.
-  EXPECT_NEAR(1.f, echo_detector.echo_likelihood(), 0.01f);
+  auto ed_metrics = echo_detector.GetMetrics();
+  EXPECT_NEAR(1.f, ed_metrics.echo_likelihood, 0.01f);
 }
 
 }  // namespace webrtc