Use the noise suppressor's (NS) noise estimate in the IntelligibilityEnhancer (IE)

The IE no longer runs its own capture-side LappedTransform to estimate the
noise spectrum. Instead, AudioProcessingImpl forwards the NS noise estimate to
the new IntelligibilityEnhancer::SetCaptureNoiseEstimate() right after
NoiseSuppression::ProcessCaptureAudio() (the IE now requires NS to be enabled,
which is DCHECKed), and the capture-side path (AnalyzeCaptureAudio,
DispatchAudio, ProcessNoiseBlock, capture_mangler_) is removed. Since the NS
estimate is a magnitude spectrum whose length can differ from the render
transform's, the IE squares it into noise_power_ and builds a separate ERB
filter bank per stream via CreateErbBank(num_freqs).
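
For reference, a minimal sketch of the new wiring, modeled on the updated
intelligibility_proc.cc test tool below. kSampleRate, kNumChannels,
fragment_size and RunIeWithNsEstimate are illustrative names, not part of
this CL, and the default IE config is assumed to match the 16 kHz mono
setup:

    #include "webrtc/base/criticalsection.h"
    #include "webrtc/modules/audio_processing/audio_buffer.h"
    #include "webrtc/modules/audio_processing/include/audio_processing.h"
    #include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h"
    #include "webrtc/modules/audio_processing/noise_suppression_impl.h"

    namespace webrtc {

    // Feeds 10 ms chunks of render (clear) and capture (noise) audio through
    // NS and the IE, handing the NS noise estimate to the IE on every chunk.
    void RunIeWithNsEstimate(float* clear, float* noise, size_t samples) {
      const int kSampleRate = 16000;   // Illustrative; must match the IE config.
      const size_t kNumChannels = 1;   // Illustrative.
      const size_t fragment_size = kSampleRate / 100;  // 10 ms, as in the APM.

      rtc::CriticalSection crit;
      NoiseSuppressionImpl ns(&crit);
      ns.Initialize(kNumChannels, kSampleRate);
      ns.Enable(true);

      IntelligibilityEnhancer enh;  // Default config (assumed 16 kHz, mono).
      AudioBuffer capture_audio(fragment_size, kNumChannels, fragment_size,
                                kNumChannels, fragment_size);
      StreamConfig stream_config(kSampleRate, kNumChannels);

      for (size_t i = 0; i + fragment_size <= samples; i += fragment_size) {
        float* noise_cursor = noise + i;
        float* clear_cursor = clear + i;
        // NS analyzes and processes the capture chunk, updating its estimate.
        capture_audio.CopyFrom(&noise_cursor, stream_config);
        ns.AnalyzeCaptureAudio(&capture_audio);
        ns.ProcessCaptureAudio(&capture_audio);
        // The IE consumes the NS magnitude estimate instead of running its own
        // capture-side transform.
        enh.SetCaptureNoiseEstimate(ns.NoiseEstimate());
        enh.ProcessRenderAudio(&clear_cursor, kSampleRate, kNumChannels);
      }
    }

    }  // namespace webrtc
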
Review URL: https://codereview.webrtc.org/1672343002
Cr-Commit-Position: refs/heads/master@{#11559}
diff --git a/webrtc/modules/audio_processing/audio_processing_impl.cc b/webrtc/modules/audio_processing/audio_processing_impl.cc
index 34937f2..b3f38f4 100644
--- a/webrtc/modules/audio_processing/audio_processing_impl.cc
+++ b/webrtc/modules/audio_processing/audio_processing_impl.cc
@@ -771,12 +771,6 @@
ca->SplitIntoFrequencyBands();
}
- if (constants_.intelligibility_enabled) {
- public_submodules_->intelligibility_enhancer->AnalyzeCaptureAudio(
- ca->split_channels_f(kBand0To8kHz), capture_nonlocked_.split_rate,
- ca->num_channels());
- }
-
if (capture_nonlocked_.beamformer_enabled) {
private_submodules_->beamformer->ProcessChunk(*ca->split_data_f(),
ca->split_data_f());
@@ -793,6 +787,11 @@
ca->CopyLowPassToReference();
}
public_submodules_->noise_suppression->ProcessCaptureAudio(ca);
+ if (constants_.intelligibility_enabled) {
+ RTC_DCHECK(public_submodules_->noise_suppression->is_enabled());
+ public_submodules_->intelligibility_enhancer->SetCaptureNoiseEstimate(
+ public_submodules_->noise_suppression->NoiseEstimate());
+ }
RETURN_ON_ERR(
public_submodules_->echo_control_mobile->ProcessCaptureAudio(ca));
public_submodules_->voice_detection->ProcessCaptureAudio(ca);
diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc
index fe964ab..c42a173 100644
--- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc
+++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc
@@ -39,6 +39,26 @@
const float kLambdaBot = -1.0f; // Extreme values in bisection
const float kLambdaTop = -10e-18f; // search for lambda.
+// Returns dot product of vectors |a| and |b| with size |length|.
+float DotProduct(const float* a, const float* b, size_t length) {
+ float ret = 0.f;
+ for (size_t i = 0; i < length; ++i) {
+ ret = fmaf(a[i], b[i], ret);
+ }
+ return ret;
+}
+
+// Computes the power across ERB filters from the power spectral density |var|.
+// Stores it in |result|.
+void FilterVariance(const float* var,
+ const std::vector<std::vector<float>>& filter_bank,
+ float* result) {
+ for (size_t i = 0; i < filter_bank.size(); ++i) {
+ RTC_DCHECK_GT(filter_bank[i].size(), 0u);
+ result[i] = DotProduct(&filter_bank[i][0], var, filter_bank[i].size());
+ }
+}
+
} // namespace
using std::complex;
@@ -47,9 +67,8 @@
using VarianceType = intelligibility::VarianceArray::StepType;
IntelligibilityEnhancer::TransformCallback::TransformCallback(
- IntelligibilityEnhancer* parent,
- IntelligibilityEnhancer::AudioSource source)
- : parent_(parent), source_(source) {
+ IntelligibilityEnhancer* parent)
+ : parent_(parent) {
}
void IntelligibilityEnhancer::TransformCallback::ProcessAudioBlock(
@@ -60,7 +79,7 @@
complex<float>* const* out_block) {
RTC_DCHECK_EQ(parent_->freqs_, frames);
for (size_t i = 0; i < in_channels; ++i) {
- parent_->DispatchAudio(source_, in_block[i], out_block[i]);
+ parent_->ProcessClearBlock(in_block[i], out_block[i]);
}
}
@@ -85,27 +104,26 @@
config.var_type,
config.var_window_size,
config.var_decay_rate),
- noise_variance_(freqs_,
- config.var_type,
- config.var_window_size,
- config.var_decay_rate),
filtered_clear_var_(new float[bank_size_]),
filtered_noise_var_(new float[bank_size_]),
- filter_bank_(bank_size_),
center_freqs_(new float[bank_size_]),
+ render_filter_bank_(CreateErbBank(freqs_)),
rho_(new float[bank_size_]),
gains_eq_(new float[bank_size_]),
gain_applier_(freqs_, config.gain_change_limit),
temp_render_out_buffer_(chunk_length_, num_render_channels_),
- temp_capture_out_buffer_(chunk_length_, num_capture_channels_),
kbd_window_(new float[window_size_]),
- render_callback_(this, AudioSource::kRenderStream),
- capture_callback_(this, AudioSource::kCaptureStream),
+ render_callback_(this),
block_count_(0),
analysis_step_(0) {
RTC_DCHECK_LE(config.rho, 1.0f);
- CreateErbBank();
+ memset(filtered_clear_var_.get(),
+ 0,
+ bank_size_ * sizeof(filtered_clear_var_[0]));
+ memset(filtered_noise_var_.get(),
+ 0,
+ bank_size_ * sizeof(filtered_noise_var_[0]));
// Assumes all rho equal.
for (size_t i = 0; i < bank_size_; ++i) {
@@ -122,9 +140,20 @@
render_mangler_.reset(new LappedTransform(
num_render_channels_, num_render_channels_, chunk_length_,
kbd_window_.get(), window_size_, window_size_ / 2, &render_callback_));
- capture_mangler_.reset(new LappedTransform(
- num_capture_channels_, num_capture_channels_, chunk_length_,
- kbd_window_.get(), window_size_, window_size_ / 2, &capture_callback_));
+}
+
+void IntelligibilityEnhancer::SetCaptureNoiseEstimate(
+ std::vector<float> noise) {
+ if (capture_filter_bank_.size() != bank_size_ ||
+ capture_filter_bank_[0].size() != noise.size()) {
+ capture_filter_bank_ = CreateErbBank(noise.size());
+ }
+ if (noise.size() != noise_power_.size()) {
+ noise_power_.resize(noise.size());
+ }
+ for (size_t i = 0; i < noise.size(); ++i) {
+ noise_power_[i] = noise[i] * noise[i];
+ }
}
void IntelligibilityEnhancer::ProcessRenderAudio(float* const* audio,
@@ -145,29 +174,6 @@
}
}
-void IntelligibilityEnhancer::AnalyzeCaptureAudio(float* const* audio,
- int sample_rate_hz,
- size_t num_channels) {
- RTC_CHECK_EQ(sample_rate_hz_, sample_rate_hz);
- RTC_CHECK_EQ(num_capture_channels_, num_channels);
-
- capture_mangler_->ProcessChunk(audio, temp_capture_out_buffer_.channels());
-}
-
-void IntelligibilityEnhancer::DispatchAudio(
- IntelligibilityEnhancer::AudioSource source,
- const complex<float>* in_block,
- complex<float>* out_block) {
- switch (source) {
- case kRenderStream:
- ProcessClearBlock(in_block, out_block);
- break;
- case kCaptureStream:
- ProcessNoiseBlock(in_block, out_block);
- break;
- }
-}
-
void IntelligibilityEnhancer::ProcessClearBlock(const complex<float>* in_block,
complex<float>* out_block) {
if (block_count_ < 2) {
@@ -194,9 +200,12 @@
}
void IntelligibilityEnhancer::AnalyzeClearBlock(float power_target) {
- FilterVariance(clear_variance_.variance(), filtered_clear_var_.get());
- FilterVariance(noise_variance_.variance(), filtered_noise_var_.get());
-
+ FilterVariance(clear_variance_.variance(),
+ render_filter_bank_,
+ filtered_clear_var_.get());
+ FilterVariance(&noise_power_[0],
+ capture_filter_bank_,
+ filtered_noise_var_.get());
SolveForGainsGivenLambda(kLambdaTop, start_freq_, gains_eq_.get());
const float power_top =
DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
@@ -242,16 +251,11 @@
for (size_t i = 0; i < freqs_; ++i) {
gains[i] = 0.0f;
for (size_t j = 0; j < bank_size_; ++j) {
- gains[i] = fmaf(filter_bank_[j][i], gains_eq_[j], gains[i]);
+ gains[i] = fmaf(render_filter_bank_[j][i], gains_eq_[j], gains[i]);
}
}
}
-void IntelligibilityEnhancer::ProcessNoiseBlock(const complex<float>* in_block,
- complex<float>* /*out_block*/) {
- noise_variance_.Step(in_block);
-}
-
size_t IntelligibilityEnhancer::GetBankSize(int sample_rate,
size_t erb_resolution) {
float freq_limit = sample_rate / 2000.0f;
@@ -260,7 +264,9 @@
return erb_scale * erb_resolution;
}
-void IntelligibilityEnhancer::CreateErbBank() {
+std::vector<std::vector<float>> IntelligibilityEnhancer::CreateErbBank(
+ size_t num_freqs) {
+ std::vector<std::vector<float>> filter_bank(bank_size_);
size_t lf = 1, rf = 4;
for (size_t i = 0; i < bank_size_; ++i) {
@@ -274,58 +280,60 @@
}
for (size_t i = 0; i < bank_size_; ++i) {
- filter_bank_[i].resize(freqs_);
+ filter_bank[i].resize(num_freqs);
}
for (size_t i = 1; i <= bank_size_; ++i) {
size_t lll, ll, rr, rrr;
static const size_t kOne = 1; // Avoids repeated static_cast<>s below.
lll = static_cast<size_t>(round(
- center_freqs_[max(kOne, i - lf) - 1] * freqs_ /
+ center_freqs_[max(kOne, i - lf) - 1] * num_freqs /
(0.5f * sample_rate_hz_)));
ll = static_cast<size_t>(round(
- center_freqs_[max(kOne, i) - 1] * freqs_ / (0.5f * sample_rate_hz_)));
- lll = min(freqs_, max(lll, kOne)) - 1;
- ll = min(freqs_, max(ll, kOne)) - 1;
+ center_freqs_[max(kOne, i) - 1] * num_freqs /
+ (0.5f * sample_rate_hz_)));
+ lll = min(num_freqs, max(lll, kOne)) - 1;
+ ll = min(num_freqs, max(ll, kOne)) - 1;
rrr = static_cast<size_t>(round(
- center_freqs_[min(bank_size_, i + rf) - 1] * freqs_ /
+ center_freqs_[min(bank_size_, i + rf) - 1] * num_freqs /
(0.5f * sample_rate_hz_)));
rr = static_cast<size_t>(round(
- center_freqs_[min(bank_size_, i + 1) - 1] * freqs_ /
+ center_freqs_[min(bank_size_, i + 1) - 1] * num_freqs /
(0.5f * sample_rate_hz_)));
- rrr = min(freqs_, max(rrr, kOne)) - 1;
- rr = min(freqs_, max(rr, kOne)) - 1;
+ rrr = min(num_freqs, max(rrr, kOne)) - 1;
+ rr = min(num_freqs, max(rr, kOne)) - 1;
float step, element;
step = 1.0f / (ll - lll);
element = 0.0f;
for (size_t j = lll; j <= ll; ++j) {
- filter_bank_[i - 1][j] = element;
+ filter_bank[i - 1][j] = element;
element += step;
}
step = 1.0f / (rrr - rr);
element = 1.0f;
for (size_t j = rr; j <= rrr; ++j) {
- filter_bank_[i - 1][j] = element;
+ filter_bank[i - 1][j] = element;
element -= step;
}
for (size_t j = ll; j <= rr; ++j) {
- filter_bank_[i - 1][j] = 1.0f;
+ filter_bank[i - 1][j] = 1.0f;
}
}
float sum;
- for (size_t i = 0; i < freqs_; ++i) {
+ for (size_t i = 0; i < num_freqs; ++i) {
sum = 0.0f;
for (size_t j = 0; j < bank_size_; ++j) {
- sum += filter_bank_[j][i];
+ sum += filter_bank[j][i];
}
for (size_t j = 0; j < bank_size_; ++j) {
- filter_bank_[j][i] /= sum;
+ filter_bank[j][i] /= sum;
}
}
+ return filter_bank;
}
void IntelligibilityEnhancer::SolveForGainsGivenLambda(float lambda,
@@ -356,24 +364,6 @@
}
}
-void IntelligibilityEnhancer::FilterVariance(const float* var, float* result) {
- RTC_DCHECK_GT(freqs_, 0u);
- for (size_t i = 0; i < bank_size_; ++i) {
- result[i] = DotProduct(&filter_bank_[i][0], var, freqs_);
- }
-}
-
-float IntelligibilityEnhancer::DotProduct(const float* a,
- const float* b,
- size_t length) {
- float ret = 0.0f;
-
- for (size_t i = 0; i < length; ++i) {
- ret = fmaf(a[i], b[i], ret);
- }
- return ret;
-}
-
bool IntelligibilityEnhancer::active() const {
return active_;
}
diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h
index 1eb2234..fade144 100644
--- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h
+++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h
@@ -60,10 +60,8 @@
explicit IntelligibilityEnhancer(const Config& config);
IntelligibilityEnhancer(); // Initialize with default config.
- // Reads and processes chunk of noise stream in time domain.
- void AnalyzeCaptureAudio(float* const* audio,
- int sample_rate_hz,
- size_t num_channels);
+ // Sets the capture noise magnitude spectrum estimate.
+ void SetCaptureNoiseEstimate(std::vector<float> noise);
// Reads chunk of speech in time domain and updates with modified signal.
void ProcessRenderAudio(float* const* audio,
@@ -72,15 +70,10 @@
bool active() const;
private:
- enum AudioSource {
- kRenderStream = 0, // Clear speech stream.
- kCaptureStream, // Noise stream.
- };
-
// Provides access point to the frequency domain.
class TransformCallback : public LappedTransform::Callback {
public:
- TransformCallback(IntelligibilityEnhancer* parent, AudioSource source);
+ TransformCallback(IntelligibilityEnhancer* parent);
// All in frequency domain, receives input |in_block|, applies
// intelligibility enhancement, and writes result to |out_block|.
@@ -92,17 +85,11 @@
private:
IntelligibilityEnhancer* parent_;
- AudioSource source_;
};
friend class TransformCallback;
FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest, TestErbCreation);
FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest, TestSolveForGains);
- // Sends streams to ProcessClearBlock or ProcessNoiseBlock based on source.
- void DispatchAudio(AudioSource source,
- const std::complex<float>* in_block,
- std::complex<float>* out_block);
-
// Updates variance computation and analysis with |in_block_|,
// and writes modified speech to |out_block|.
void ProcessClearBlock(const std::complex<float>* in_block,
@@ -117,27 +104,16 @@
// Transforms freq gains to ERB gains.
void UpdateErbGains();
- // Updates variance calculation for noise input with |in_block|.
- void ProcessNoiseBlock(const std::complex<float>* in_block,
- std::complex<float>* out_block);
-
// Returns number of ERB filters.
static size_t GetBankSize(int sample_rate, size_t erb_resolution);
// Initializes ERB filterbank.
- void CreateErbBank();
+ std::vector<std::vector<float>> CreateErbBank(size_t num_freqs);
// Analytically solves quadratic for optimal gains given |lambda|.
// Negative gains are set to 0. Stores the results in |sols|.
void SolveForGainsGivenLambda(float lambda, size_t start_freq, float* sols);
- // Computes variance across ERB filters from freq variance |var|.
- // Stores in |result|.
- void FilterVariance(const float* var, float* result);
-
- // Returns dot product of vectors specified by size |length| arrays |a|,|b|.
- static float DotProduct(const float* a, const float* b, size_t length);
-
const size_t freqs_; // Num frequencies in frequency domain.
const size_t window_size_; // Window size in samples; also the block size.
const size_t chunk_length_; // Chunk size in samples.
@@ -152,11 +128,12 @@
// TODO(ekm): Add logic for updating |active_|.
intelligibility::VarianceArray clear_variance_;
- intelligibility::VarianceArray noise_variance_;
+ std::vector<float> noise_power_;
rtc::scoped_ptr<float[]> filtered_clear_var_;
rtc::scoped_ptr<float[]> filtered_noise_var_;
- std::vector<std::vector<float>> filter_bank_;
rtc::scoped_ptr<float[]> center_freqs_;
+ std::vector<std::vector<float>> capture_filter_bank_;
+ std::vector<std::vector<float>> render_filter_bank_;
size_t start_freq_;
rtc::scoped_ptr<float[]> rho_; // Production and interpretation SNR.
// for each ERB band.
@@ -166,13 +143,10 @@
// Destination buffers used to reassemble blocked chunks before overwriting
// the original input array with modifications.
ChannelBuffer<float> temp_render_out_buffer_;
- ChannelBuffer<float> temp_capture_out_buffer_;
rtc::scoped_ptr<float[]> kbd_window_;
TransformCallback render_callback_;
- TransformCallback capture_callback_;
rtc::scoped_ptr<LappedTransform> render_mangler_;
- rtc::scoped_ptr<LappedTransform> capture_mangler_;
int block_count_;
int analysis_step_;
};
diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc
index ce146de..436d174 100644
--- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc
+++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc
@@ -99,7 +99,6 @@
float* clear_cursor = &clear_data_[0];
float* noise_cursor = &noise_data_[0];
for (int i = 0; i < kSamples; i += kFragmentSize) {
- enh_->AnalyzeCaptureAudio(&noise_cursor, kSampleRate, kNumChannels);
enh_->ProcessRenderAudio(&clear_cursor, kSampleRate, kNumChannels);
clear_cursor += kFragmentSize;
noise_cursor += kFragmentSize;
@@ -154,7 +153,7 @@
EXPECT_NEAR(kTestCenterFreqs[i], enh_->center_freqs_[i], kMaxTestError);
ASSERT_EQ(arraysize(kTestFilterBank[0]), enh_->freqs_);
for (size_t j = 0; j < enh_->freqs_; ++j) {
- EXPECT_NEAR(kTestFilterBank[i][j], enh_->filter_bank_[i][j],
+ EXPECT_NEAR(kTestFilterBank[i][j], enh_->render_filter_bank_[i][j],
kMaxTestError);
}
}
diff --git a/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc b/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc
index 4d2f5f4..e02e64e 100644
--- a/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc
+++ b/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc
@@ -23,10 +23,14 @@
#include "gflags/gflags.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "webrtc/base/checks.h"
+#include "webrtc/base/criticalsection.h"
#include "webrtc/common_audio/real_fourier.h"
#include "webrtc/common_audio/wav_file.h"
+#include "webrtc/modules/audio_processing/audio_buffer.h"
+#include "webrtc/modules/audio_processing/include/audio_processing.h"
#include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h"
#include "webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h"
+#include "webrtc/modules/audio_processing/noise_suppression_impl.h"
#include "webrtc/system_wrappers/include/critical_section_wrapper.h"
#include "webrtc/test/testsupport/fileutils.h"
@@ -115,6 +119,17 @@
config.analysis_rate = FLAGS_ana_rate;
config.gain_change_limit = FLAGS_gain_limit;
IntelligibilityEnhancer enh(config);
+ rtc::CriticalSection crit;
+ NoiseSuppressionImpl ns(&crit);
+ ns.Initialize(kNumChannels, FLAGS_sample_rate);
+ ns.Enable(true);
+
+ AudioBuffer capture_audio(fragment_size,
+ kNumChannels,
+ fragment_size,
+ kNumChannels,
+ fragment_size);
+ StreamConfig stream_config(FLAGS_sample_rate, kNumChannels);
// Slice the input into smaller chunks, as the APM would do, and feed them
// through the enhancer.
@@ -122,7 +137,10 @@
float* noise_cursor = &noise_fpcm[0];
for (size_t i = 0; i < samples; i += fragment_size) {
- enh.AnalyzeCaptureAudio(&noise_cursor, FLAGS_sample_rate, kNumChannels);
+ capture_audio.CopyFrom(&noise_cursor, stream_config);
+ ns.AnalyzeCaptureAudio(&capture_audio);
+ ns.ProcessCaptureAudio(&capture_audio);
+ enh.SetCaptureNoiseEstimate(ns.NoiseEstimate());
enh.ProcessRenderAudio(&clear_cursor, FLAGS_sample_rate, kNumChannels);
clear_cursor += fragment_size;
noise_cursor += fragment_size;