Use the NS (noise suppression) noise estimate for the IE (IntelligibilityEnhancer)

Review URL: https://codereview.webrtc.org/1672343002

Cr-Commit-Position: refs/heads/master@{#11559}
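
This CL removes the IntelligibilityEnhancer's own capture-side path (AnalyzeCaptureAudio(), the capture LappedTransform, and ProcessNoiseBlock()/noise_variance_) and instead has AudioProcessingImpl forward the noise suppressor's spectral estimate through the new SetCaptureNoiseEstimate() API, right after NoiseSuppressionImpl has processed the capture audio. A minimal sketch of that hand-off, using only the APIs touched in this CL (ForwardNoiseEstimate is an illustrative helper name, not code from the patch):

#include "webrtc/base/checks.h"
#include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h"
#include "webrtc/modules/audio_processing/noise_suppression_impl.h"

// Illustrative helper mirroring the new call site in AudioProcessingImpl.
void ForwardNoiseEstimate(webrtc::NoiseSuppressionImpl* ns,
                          webrtc::IntelligibilityEnhancer* ie) {
  // NoiseEstimate() is the per-bin noise magnitude spectrum; the enhancer
  // squares it into powers and maps it onto its capture-side ERB bank.
  RTC_DCHECK(ns->is_enabled());
  ie->SetCaptureNoiseEstimate(ns->NoiseEstimate());
}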
diff --git a/webrtc/modules/audio_processing/audio_processing_impl.cc b/webrtc/modules/audio_processing/audio_processing_impl.cc
index 34937f2..b3f38f4 100644
--- a/webrtc/modules/audio_processing/audio_processing_impl.cc
+++ b/webrtc/modules/audio_processing/audio_processing_impl.cc
@@ -771,12 +771,6 @@
     ca->SplitIntoFrequencyBands();
   }
 
-  if (constants_.intelligibility_enabled) {
-    public_submodules_->intelligibility_enhancer->AnalyzeCaptureAudio(
-        ca->split_channels_f(kBand0To8kHz), capture_nonlocked_.split_rate,
-        ca->num_channels());
-  }
-
   if (capture_nonlocked_.beamformer_enabled) {
     private_submodules_->beamformer->ProcessChunk(*ca->split_data_f(),
                                                   ca->split_data_f());
@@ -793,6 +787,11 @@
     ca->CopyLowPassToReference();
   }
   public_submodules_->noise_suppression->ProcessCaptureAudio(ca);
+  if (constants_.intelligibility_enabled) {
+    RTC_DCHECK(public_submodules_->noise_suppression->is_enabled());
+    public_submodules_->intelligibility_enhancer->SetCaptureNoiseEstimate(
+        public_submodules_->noise_suppression->NoiseEstimate());
+  }
   RETURN_ON_ERR(
       public_submodules_->echo_control_mobile->ProcessCaptureAudio(ca));
   public_submodules_->voice_detection->ProcessCaptureAudio(ca);
diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc
index fe964ab..c42a173 100644
--- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc
+++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc
@@ -39,6 +39,26 @@
 const float kLambdaBot = -1.0f;      // Extreme values in bisection
 const float kLambdaTop = -10e-18f;  // search for lambda.
 
+// Returns dot product of vectors |a| and |b| with size |length|.
+float DotProduct(const float* a, const float* b, size_t length) {
+  float ret = 0.f;
+  for (size_t i = 0; i < length; ++i) {
+    ret = fmaf(a[i], b[i], ret);
+  }
+  return ret;
+}
+
+// Computes the power across ERB filters from the power spectral density |var|.
+// Stores it in |result|.
+void FilterVariance(const float* var,
+                    const std::vector<std::vector<float>>& filter_bank,
+                    float* result) {
+  for (size_t i = 0; i < filter_bank.size(); ++i) {
+    RTC_DCHECK_GT(filter_bank[i].size(), 0u);
+    result[i] = DotProduct(&filter_bank[i][0], var, filter_bank[i].size());
+  }
+}
+
 }  // namespace
 
 using std::complex;
@@ -47,9 +67,8 @@
 using VarianceType = intelligibility::VarianceArray::StepType;
 
 IntelligibilityEnhancer::TransformCallback::TransformCallback(
-    IntelligibilityEnhancer* parent,
-    IntelligibilityEnhancer::AudioSource source)
-    : parent_(parent), source_(source) {
+    IntelligibilityEnhancer* parent)
+    : parent_(parent) {
 }
 
 void IntelligibilityEnhancer::TransformCallback::ProcessAudioBlock(
@@ -60,7 +79,7 @@
     complex<float>* const* out_block) {
   RTC_DCHECK_EQ(parent_->freqs_, frames);
   for (size_t i = 0; i < in_channels; ++i) {
-    parent_->DispatchAudio(source_, in_block[i], out_block[i]);
+    parent_->ProcessClearBlock(in_block[i], out_block[i]);
   }
 }
 
@@ -85,27 +104,26 @@
                       config.var_type,
                       config.var_window_size,
                       config.var_decay_rate),
-      noise_variance_(freqs_,
-                      config.var_type,
-                      config.var_window_size,
-                      config.var_decay_rate),
       filtered_clear_var_(new float[bank_size_]),
       filtered_noise_var_(new float[bank_size_]),
-      filter_bank_(bank_size_),
       center_freqs_(new float[bank_size_]),
+      render_filter_bank_(CreateErbBank(freqs_)),
       rho_(new float[bank_size_]),
       gains_eq_(new float[bank_size_]),
       gain_applier_(freqs_, config.gain_change_limit),
       temp_render_out_buffer_(chunk_length_, num_render_channels_),
-      temp_capture_out_buffer_(chunk_length_, num_capture_channels_),
       kbd_window_(new float[window_size_]),
-      render_callback_(this, AudioSource::kRenderStream),
-      capture_callback_(this, AudioSource::kCaptureStream),
+      render_callback_(this),
       block_count_(0),
       analysis_step_(0) {
   RTC_DCHECK_LE(config.rho, 1.0f);
 
-  CreateErbBank();
+  memset(filtered_clear_var_.get(),
+         0,
+         bank_size_ * sizeof(filtered_clear_var_[0]));
+  memset(filtered_noise_var_.get(),
+         0,
+         bank_size_ * sizeof(filtered_noise_var_[0]));
 
   // Assumes all rho equal.
   for (size_t i = 0; i < bank_size_; ++i) {
@@ -122,9 +140,20 @@
   render_mangler_.reset(new LappedTransform(
       num_render_channels_, num_render_channels_, chunk_length_,
       kbd_window_.get(), window_size_, window_size_ / 2, &render_callback_));
-  capture_mangler_.reset(new LappedTransform(
-      num_capture_channels_, num_capture_channels_, chunk_length_,
-      kbd_window_.get(), window_size_, window_size_ / 2, &capture_callback_));
+}
+
+void IntelligibilityEnhancer::SetCaptureNoiseEstimate(
+    std::vector<float> noise) {
+  if (capture_filter_bank_.size() != bank_size_ ||
+      capture_filter_bank_[0].size() != noise.size()) {
+    capture_filter_bank_ = CreateErbBank(noise.size());
+  }
+  if (noise.size() != noise_power_.size()) {
+    noise_power_.resize(noise.size());
+  }
+  for (size_t i = 0; i < noise.size(); ++i) {
+    noise_power_[i] = noise[i] * noise[i];
+  }
 }
 
 void IntelligibilityEnhancer::ProcessRenderAudio(float* const* audio,
@@ -145,29 +174,6 @@
   }
 }
 
-void IntelligibilityEnhancer::AnalyzeCaptureAudio(float* const* audio,
-                                                  int sample_rate_hz,
-                                                  size_t num_channels) {
-  RTC_CHECK_EQ(sample_rate_hz_, sample_rate_hz);
-  RTC_CHECK_EQ(num_capture_channels_, num_channels);
-
-  capture_mangler_->ProcessChunk(audio, temp_capture_out_buffer_.channels());
-}
-
-void IntelligibilityEnhancer::DispatchAudio(
-    IntelligibilityEnhancer::AudioSource source,
-    const complex<float>* in_block,
-    complex<float>* out_block) {
-  switch (source) {
-    case kRenderStream:
-      ProcessClearBlock(in_block, out_block);
-      break;
-    case kCaptureStream:
-      ProcessNoiseBlock(in_block, out_block);
-      break;
-  }
-}
-
 void IntelligibilityEnhancer::ProcessClearBlock(const complex<float>* in_block,
                                                 complex<float>* out_block) {
   if (block_count_ < 2) {
@@ -194,9 +200,12 @@
 }
 
 void IntelligibilityEnhancer::AnalyzeClearBlock(float power_target) {
-  FilterVariance(clear_variance_.variance(), filtered_clear_var_.get());
-  FilterVariance(noise_variance_.variance(), filtered_noise_var_.get());
-
+  FilterVariance(clear_variance_.variance(),
+                 render_filter_bank_,
+                 filtered_clear_var_.get());
+  FilterVariance(&noise_power_[0],
+                 capture_filter_bank_,
+                 filtered_noise_var_.get());
   SolveForGainsGivenLambda(kLambdaTop, start_freq_, gains_eq_.get());
   const float power_top =
       DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
@@ -242,16 +251,11 @@
   for (size_t i = 0; i < freqs_; ++i) {
     gains[i] = 0.0f;
     for (size_t j = 0; j < bank_size_; ++j) {
-      gains[i] = fmaf(filter_bank_[j][i], gains_eq_[j], gains[i]);
+      gains[i] = fmaf(render_filter_bank_[j][i], gains_eq_[j], gains[i]);
     }
   }
 }
 
-void IntelligibilityEnhancer::ProcessNoiseBlock(const complex<float>* in_block,
-                                                complex<float>* /*out_block*/) {
-  noise_variance_.Step(in_block);
-}
-
 size_t IntelligibilityEnhancer::GetBankSize(int sample_rate,
                                             size_t erb_resolution) {
   float freq_limit = sample_rate / 2000.0f;
@@ -260,7 +264,9 @@
   return erb_scale * erb_resolution;
 }
 
-void IntelligibilityEnhancer::CreateErbBank() {
+std::vector<std::vector<float>> IntelligibilityEnhancer::CreateErbBank(
+    size_t num_freqs) {
+  std::vector<std::vector<float>> filter_bank(bank_size_);
   size_t lf = 1, rf = 4;
 
   for (size_t i = 0; i < bank_size_; ++i) {
@@ -274,58 +280,60 @@
   }
 
   for (size_t i = 0; i < bank_size_; ++i) {
-    filter_bank_[i].resize(freqs_);
+    filter_bank[i].resize(num_freqs);
   }
 
   for (size_t i = 1; i <= bank_size_; ++i) {
     size_t lll, ll, rr, rrr;
     static const size_t kOne = 1;  // Avoids repeated static_cast<>s below.
     lll = static_cast<size_t>(round(
-        center_freqs_[max(kOne, i - lf) - 1] * freqs_ /
+        center_freqs_[max(kOne, i - lf) - 1] * num_freqs /
             (0.5f * sample_rate_hz_)));
     ll = static_cast<size_t>(round(
-        center_freqs_[max(kOne, i) - 1] * freqs_ / (0.5f * sample_rate_hz_)));
-    lll = min(freqs_, max(lll, kOne)) - 1;
-    ll = min(freqs_, max(ll, kOne)) - 1;
+        center_freqs_[max(kOne, i) - 1] * num_freqs /
+            (0.5f * sample_rate_hz_)));
+    lll = min(num_freqs, max(lll, kOne)) - 1;
+    ll = min(num_freqs, max(ll, kOne)) - 1;
 
     rrr = static_cast<size_t>(round(
-        center_freqs_[min(bank_size_, i + rf) - 1] * freqs_ /
+        center_freqs_[min(bank_size_, i + rf) - 1] * num_freqs /
             (0.5f * sample_rate_hz_)));
     rr = static_cast<size_t>(round(
-        center_freqs_[min(bank_size_, i + 1) - 1] * freqs_ /
+        center_freqs_[min(bank_size_, i + 1) - 1] * num_freqs /
             (0.5f * sample_rate_hz_)));
-    rrr = min(freqs_, max(rrr, kOne)) - 1;
-    rr = min(freqs_, max(rr, kOne)) - 1;
+    rrr = min(num_freqs, max(rrr, kOne)) - 1;
+    rr = min(num_freqs, max(rr, kOne)) - 1;
 
     float step, element;
 
     step = 1.0f / (ll - lll);
     element = 0.0f;
     for (size_t j = lll; j <= ll; ++j) {
-      filter_bank_[i - 1][j] = element;
+      filter_bank[i - 1][j] = element;
       element += step;
     }
     step = 1.0f / (rrr - rr);
     element = 1.0f;
     for (size_t j = rr; j <= rrr; ++j) {
-      filter_bank_[i - 1][j] = element;
+      filter_bank[i - 1][j] = element;
       element -= step;
     }
     for (size_t j = ll; j <= rr; ++j) {
-      filter_bank_[i - 1][j] = 1.0f;
+      filter_bank[i - 1][j] = 1.0f;
     }
   }
 
   float sum;
-  for (size_t i = 0; i < freqs_; ++i) {
+  for (size_t i = 0; i < num_freqs; ++i) {
     sum = 0.0f;
     for (size_t j = 0; j < bank_size_; ++j) {
-      sum += filter_bank_[j][i];
+      sum += filter_bank[j][i];
     }
     for (size_t j = 0; j < bank_size_; ++j) {
-      filter_bank_[j][i] /= sum;
+      filter_bank[j][i] /= sum;
     }
   }
+  return filter_bank;
 }
 
 void IntelligibilityEnhancer::SolveForGainsGivenLambda(float lambda,
@@ -356,24 +364,6 @@
   }
 }
 
-void IntelligibilityEnhancer::FilterVariance(const float* var, float* result) {
-  RTC_DCHECK_GT(freqs_, 0u);
-  for (size_t i = 0; i < bank_size_; ++i) {
-    result[i] = DotProduct(&filter_bank_[i][0], var, freqs_);
-  }
-}
-
-float IntelligibilityEnhancer::DotProduct(const float* a,
-                                          const float* b,
-                                          size_t length) {
-  float ret = 0.0f;
-
-  for (size_t i = 0; i < length; ++i) {
-    ret = fmaf(a[i], b[i], ret);
-  }
-  return ret;
-}
-
 bool IntelligibilityEnhancer::active() const {
   return active_;
 }
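
The file-local DotProduct() and FilterVariance() helpers added at the top of this file replace the former member functions of the same names; together they reduce a power spectrum to per-ERB-band powers, using either render_filter_bank_ or the new capture_filter_bank_. A minimal standalone restatement of that computation (BandPowers and its parameter names are illustrative only):

#include <cstddef>
#include <vector>

// For each ERB band, the band power is the dot product of that band's filter
// row with the power spectrum. CreateErbBank() normalizes the bank so that
// each frequency bin contributes a total weight of one across all bands.
std::vector<float> BandPowers(
    const std::vector<float>& power_spectrum,
    const std::vector<std::vector<float>>& filter_bank) {
  std::vector<float> band_power(filter_bank.size(), 0.f);
  for (size_t band = 0; band < filter_bank.size(); ++band) {
    // Assumes power_spectrum.size() matches the bank's row length.
    for (size_t bin = 0; bin < filter_bank[band].size(); ++bin) {
      band_power[band] += filter_bank[band][bin] * power_spectrum[bin];
    }
  }
  return band_power;
}

The capture bank is rebuilt whenever the NS estimate length changes (see SetCaptureNoiseEstimate() above), since the NS spectrum resolution need not match the enhancer's render-side freqs_.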
diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h
index 1eb2234..fade144 100644
--- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h
+++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h
@@ -60,10 +60,8 @@
   explicit IntelligibilityEnhancer(const Config& config);
   IntelligibilityEnhancer();  // Initialize with default config.
 
-  // Reads and processes chunk of noise stream in time domain.
-  void AnalyzeCaptureAudio(float* const* audio,
-                           int sample_rate_hz,
-                           size_t num_channels);
+  // Sets the capture noise magnitude spectrum estimate.
+  void SetCaptureNoiseEstimate(std::vector<float> noise);
 
   // Reads chunk of speech in time domain and updates with modified signal.
   void ProcessRenderAudio(float* const* audio,
@@ -72,15 +70,10 @@
   bool active() const;
 
  private:
-  enum AudioSource {
-    kRenderStream = 0,  // Clear speech stream.
-    kCaptureStream,  // Noise stream.
-  };
-
   // Provides access point to the frequency domain.
   class TransformCallback : public LappedTransform::Callback {
    public:
-    TransformCallback(IntelligibilityEnhancer* parent, AudioSource source);
+    TransformCallback(IntelligibilityEnhancer* parent);
 
     // All in frequency domain, receives input |in_block|, applies
     // intelligibility enhancement, and writes result to |out_block|.
@@ -92,17 +85,11 @@
 
    private:
     IntelligibilityEnhancer* parent_;
-    AudioSource source_;
   };
   friend class TransformCallback;
   FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest, TestErbCreation);
   FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest, TestSolveForGains);
 
-  // Sends streams to ProcessClearBlock or ProcessNoiseBlock based on source.
-  void DispatchAudio(AudioSource source,
-                     const std::complex<float>* in_block,
-                     std::complex<float>* out_block);
-
   // Updates variance computation and analysis with |in_block_|,
   // and writes modified speech to |out_block|.
   void ProcessClearBlock(const std::complex<float>* in_block,
@@ -117,27 +104,16 @@
   // Transforms freq gains to ERB gains.
   void UpdateErbGains();
 
-  // Updates variance calculation for noise input with |in_block|.
-  void ProcessNoiseBlock(const std::complex<float>* in_block,
-                         std::complex<float>* out_block);
-
   // Returns number of ERB filters.
   static size_t GetBankSize(int sample_rate, size_t erb_resolution);
 
   // Initializes ERB filterbank.
-  void CreateErbBank();
+  std::vector<std::vector<float>> CreateErbBank(size_t num_freqs);
 
   // Analytically solves quadratic for optimal gains given |lambda|.
   // Negative gains are set to 0. Stores the results in |sols|.
   void SolveForGainsGivenLambda(float lambda, size_t start_freq, float* sols);
 
-  // Computes variance across ERB filters from freq variance |var|.
-  // Stores in |result|.
-  void FilterVariance(const float* var, float* result);
-
-  // Returns dot product of vectors specified by size |length| arrays |a|,|b|.
-  static float DotProduct(const float* a, const float* b, size_t length);
-
   const size_t freqs_;         // Num frequencies in frequency domain.
   const size_t window_size_;   // Window size in samples; also the block size.
   const size_t chunk_length_;  // Chunk size in samples.
@@ -152,11 +128,12 @@
                                // TODO(ekm): Add logic for updating |active_|.
 
   intelligibility::VarianceArray clear_variance_;
-  intelligibility::VarianceArray noise_variance_;
+  std::vector<float> noise_power_;
   rtc::scoped_ptr<float[]> filtered_clear_var_;
   rtc::scoped_ptr<float[]> filtered_noise_var_;
-  std::vector<std::vector<float>> filter_bank_;
   rtc::scoped_ptr<float[]> center_freqs_;
+  std::vector<std::vector<float>> capture_filter_bank_;
+  std::vector<std::vector<float>> render_filter_bank_;
   size_t start_freq_;
   rtc::scoped_ptr<float[]> rho_;  // Production and interpretation SNR.
                                   // for each ERB band.
@@ -166,13 +143,10 @@
   // Destination buffers used to reassemble blocked chunks before overwriting
   // the original input array with modifications.
   ChannelBuffer<float> temp_render_out_buffer_;
-  ChannelBuffer<float> temp_capture_out_buffer_;
 
   rtc::scoped_ptr<float[]> kbd_window_;
   TransformCallback render_callback_;
-  TransformCallback capture_callback_;
   rtc::scoped_ptr<LappedTransform> render_mangler_;
-  rtc::scoped_ptr<LappedTransform> capture_mangler_;
   int block_count_;
   int analysis_step_;
 };
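
With AnalyzeCaptureAudio() removed, the enhancer's public interface is a noise-estimate setter plus the unchanged render-side call. A minimal usage sketch against that interface (the helper name, the flat noise spectrum, and num_noise_bins are illustrative, not taken from the patch):

#include <cstddef>
#include <vector>

#include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h"

// Illustrative helper: feeds a noise estimate and enhances one render chunk.
void EnhanceOneChunk(webrtc::IntelligibilityEnhancer* enh,
                     float* const* render_audio,
                     int sample_rate_hz,
                     size_t num_render_channels,
                     size_t num_noise_bins) {
  // Capture side: hand over the latest noise magnitude spectrum. A constant
  // spectrum is used here purely for illustration.
  enh->SetCaptureNoiseEstimate(std::vector<float>(num_noise_bins, 0.01f));
  // Render side (unchanged by this CL): the speech chunk is modified in place.
  enh->ProcessRenderAudio(render_audio, sample_rate_hz, num_render_channels);
}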
diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc
index ce146de..436d174 100644
--- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc
+++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc
@@ -99,7 +99,6 @@
     float* clear_cursor = &clear_data_[0];
     float* noise_cursor = &noise_data_[0];
     for (int i = 0; i < kSamples; i += kFragmentSize) {
-      enh_->AnalyzeCaptureAudio(&noise_cursor, kSampleRate, kNumChannels);
       enh_->ProcessRenderAudio(&clear_cursor, kSampleRate, kNumChannels);
       clear_cursor += kFragmentSize;
       noise_cursor += kFragmentSize;
@@ -154,7 +153,7 @@
     EXPECT_NEAR(kTestCenterFreqs[i], enh_->center_freqs_[i], kMaxTestError);
     ASSERT_EQ(arraysize(kTestFilterBank[0]), enh_->freqs_);
     for (size_t j = 0; j < enh_->freqs_; ++j) {
-      EXPECT_NEAR(kTestFilterBank[i][j], enh_->filter_bank_[i][j],
+      EXPECT_NEAR(kTestFilterBank[i][j], enh_->render_filter_bank_[i][j],
                   kMaxTestError);
     }
   }
diff --git a/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc b/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc
index 4d2f5f4..e02e64e 100644
--- a/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc
+++ b/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc
@@ -23,10 +23,14 @@
 #include "gflags/gflags.h"
 #include "testing/gtest/include/gtest/gtest.h"
 #include "webrtc/base/checks.h"
+#include "webrtc/base/criticalsection.h"
 #include "webrtc/common_audio/real_fourier.h"
 #include "webrtc/common_audio/wav_file.h"
+#include "webrtc/modules/audio_processing/audio_buffer.h"
+#include "webrtc/modules/audio_processing/include/audio_processing.h"
 #include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h"
 #include "webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h"
+#include "webrtc/modules/audio_processing/noise_suppression_impl.h"
 #include "webrtc/system_wrappers/include/critical_section_wrapper.h"
 #include "webrtc/test/testsupport/fileutils.h"
 
@@ -115,6 +119,17 @@
   config.analysis_rate = FLAGS_ana_rate;
   config.gain_change_limit = FLAGS_gain_limit;
   IntelligibilityEnhancer enh(config);
+  rtc::CriticalSection crit;
+  NoiseSuppressionImpl ns(&crit);
+  ns.Initialize(kNumChannels, FLAGS_sample_rate);
+  ns.Enable(true);
+
+  AudioBuffer capture_audio(fragment_size,
+                            kNumChannels,
+                            fragment_size,
+                            kNumChannels,
+                            fragment_size);
+  StreamConfig stream_config(FLAGS_sample_rate, kNumChannels);
 
   // Slice the input into smaller chunks, as the APM would do, and feed them
   // through the enhancer.
@@ -122,7 +137,10 @@
   float* noise_cursor = &noise_fpcm[0];
 
   for (size_t i = 0; i < samples; i += fragment_size) {
-    enh.AnalyzeCaptureAudio(&noise_cursor, FLAGS_sample_rate, kNumChannels);
+    capture_audio.CopyFrom(&noise_cursor, stream_config);
+    ns.AnalyzeCaptureAudio(&capture_audio);
+    ns.ProcessCaptureAudio(&capture_audio);
+    enh.SetCaptureNoiseEstimate(ns.NoiseEstimate());
     enh.ProcessRenderAudio(&clear_cursor, FLAGS_sample_rate, kNumChannels);
     clear_cursor += fragment_size;
     noise_cursor += fragment_size;