blob: 1e766875caedc519004077e4a2ebfc1f993c9262 [file] [log] [blame]
ekm030249d2015-06-15 13:02:24 -07001/*
2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
ekmdb4fecf2015-06-22 17:49:08 -070011//
12// Implements core class for intelligibility enhancer.
13//
14// Details of the model and algorithm can be found in the original paper:
15// http://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=6882788
16//
17
ekm030249d2015-06-15 13:02:24 -070018#include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h"
19
ekm35b72fb2015-07-10 14:11:52 -070020#include <math.h>
21#include <stdlib.h>
ekm030249d2015-06-15 13:02:24 -070022
23#include <algorithm>
ekmdb4fecf2015-06-22 17:49:08 -070024#include <numeric>
ekm030249d2015-06-15 13:02:24 -070025
26#include "webrtc/base/checks.h"
27#include "webrtc/common_audio/vad/include/webrtc_vad.h"
28#include "webrtc/common_audio/window_generator.h"
29
ekm35b72fb2015-07-10 14:11:52 -070030namespace webrtc {
31
namespace {

// Multiplier applied to the ERB index of kClipFreq when the constructor
// derives the first band whose gain is solved for.
const int kErbResolution = 2;

// Analysis window length and per-call chunk length, both in milliseconds.
const int kWindowSizeMs = 2;
const int kChunkSizeMs = 10;  // Size provided by APM.

// Frequency in Hz used by the constructor to locate the starting band.
const float kClipFreq = 200.0f;

// Default production and interpretation SNR.
const float kConfigRho = 0.02f;

// Alpha parameter passed to the Kaiser-Bessel-derived window generator.
const float kKbdAlpha = 1.5f;

// Extreme values in the bisection search for lambda.
const float kLambdaBot = -1.0f;
const float kLambdaTop = -10e-18f;

}  // namespace
44
ekm030249d2015-06-15 13:02:24 -070045using std::complex;
46using std::max;
47using std::min;
ekm030249d2015-06-15 13:02:24 -070048using VarianceType = intelligibility::VarianceArray::StepType;
49
50IntelligibilityEnhancer::TransformCallback::TransformCallback(
51 IntelligibilityEnhancer* parent,
52 IntelligibilityEnhancer::AudioSource source)
ekmdb4fecf2015-06-22 17:49:08 -070053 : parent_(parent), source_(source) {
54}
ekm030249d2015-06-15 13:02:24 -070055
56void IntelligibilityEnhancer::TransformCallback::ProcessAudioBlock(
57 const complex<float>* const* in_block,
ekmdb4fecf2015-06-22 17:49:08 -070058 int in_channels,
59 int frames,
60 int /* out_channels */,
ekm030249d2015-06-15 13:02:24 -070061 complex<float>* const* out_block) {
62 DCHECK_EQ(parent_->freqs_, frames);
63 for (int i = 0; i < in_channels; ++i) {
64 parent_->DispatchAudio(source_, in_block[i], out_block[i]);
65 }
66}
67
68IntelligibilityEnhancer::IntelligibilityEnhancer(int erb_resolution,
69 int sample_rate_hz,
70 int channels,
ekmdb4fecf2015-06-22 17:49:08 -070071 int cv_type,
72 float cv_alpha,
ekm030249d2015-06-15 13:02:24 -070073 int cv_win,
74 int analysis_rate,
75 int variance_rate,
76 float gain_limit)
ekmdb4fecf2015-06-22 17:49:08 -070077 : freqs_(RealFourier::ComplexLength(
78 RealFourier::FftOrder(sample_rate_hz * kWindowSizeMs / 1000))),
ekm030249d2015-06-15 13:02:24 -070079 window_size_(1 << RealFourier::FftOrder(freqs_)),
80 chunk_length_(sample_rate_hz * kChunkSizeMs / 1000),
81 bank_size_(GetBankSize(sample_rate_hz, erb_resolution)),
82 sample_rate_hz_(sample_rate_hz),
83 erb_resolution_(erb_resolution),
84 channels_(channels),
85 analysis_rate_(analysis_rate),
86 variance_rate_(variance_rate),
ekmdb4fecf2015-06-22 17:49:08 -070087 clear_variance_(freqs_,
88 static_cast<VarianceType>(cv_type),
89 cv_win,
ekm030249d2015-06-15 13:02:24 -070090 cv_alpha),
91 noise_variance_(freqs_, VarianceType::kStepInfinite, 475, 0.01f),
92 filtered_clear_var_(new float[bank_size_]),
93 filtered_noise_var_(new float[bank_size_]),
ekm35b72fb2015-07-10 14:11:52 -070094 filter_bank_(bank_size_),
ekm030249d2015-06-15 13:02:24 -070095 center_freqs_(new float[bank_size_]),
96 rho_(new float[bank_size_]),
97 gains_eq_(new float[bank_size_]),
98 gain_applier_(freqs_, gain_limit),
99 temp_out_buffer_(nullptr),
ekmdb4fecf2015-06-22 17:49:08 -0700100 input_audio_(new float* [channels]),
ekm030249d2015-06-15 13:02:24 -0700101 kbd_window_(new float[window_size_]),
102 render_callback_(this, AudioSource::kRenderStream),
103 capture_callback_(this, AudioSource::kCaptureStream),
104 block_count_(0),
105 analysis_step_(0),
ekmdb4fecf2015-06-22 17:49:08 -0700106 vad_high_(WebRtcVad_Create()),
107 vad_low_(WebRtcVad_Create()),
ekm030249d2015-06-15 13:02:24 -0700108 vad_tmp_buffer_(new int16_t[chunk_length_]) {
109 DCHECK_LE(kConfigRho, 1.0f);
110
111 CreateErbBank();
112
ekm030249d2015-06-15 13:02:24 -0700113 WebRtcVad_Init(vad_high_);
ekmdb4fecf2015-06-22 17:49:08 -0700114 WebRtcVad_set_mode(vad_high_, 0); // High likelihood of speech.
ekm030249d2015-06-15 13:02:24 -0700115 WebRtcVad_Init(vad_low_);
ekmdb4fecf2015-06-22 17:49:08 -0700116 WebRtcVad_set_mode(vad_low_, 3); // Low likelihood of speech.
ekm030249d2015-06-15 13:02:24 -0700117
ekmdb4fecf2015-06-22 17:49:08 -0700118 temp_out_buffer_ = static_cast<float**>(
119 malloc(sizeof(*temp_out_buffer_) * channels_ +
120 sizeof(**temp_out_buffer_) * chunk_length_ * channels_));
ekm030249d2015-06-15 13:02:24 -0700121 for (int i = 0; i < channels_; ++i) {
ekmdb4fecf2015-06-22 17:49:08 -0700122 temp_out_buffer_[i] =
123 reinterpret_cast<float*>(temp_out_buffer_ + channels_) +
124 chunk_length_ * i;
ekm030249d2015-06-15 13:02:24 -0700125 }
126
ekmdb4fecf2015-06-22 17:49:08 -0700127 // Assumes all rho equal.
ekm030249d2015-06-15 13:02:24 -0700128 for (int i = 0; i < bank_size_; ++i) {
129 rho_[i] = kConfigRho * kConfigRho;
130 }
131
132 float freqs_khz = kClipFreq / 1000.0f;
ekmdb4fecf2015-06-22 17:49:08 -0700133 int erb_index = static_cast<int>(ceilf(
134 11.17f * logf((freqs_khz + 0.312f) / (freqs_khz + 14.6575f)) + 43.0f));
ekm030249d2015-06-15 13:02:24 -0700135 start_freq_ = max(1, erb_index * kErbResolution);
136
137 WindowGenerator::KaiserBesselDerived(kKbdAlpha, window_size_,
138 kbd_window_.get());
ekmdb4fecf2015-06-22 17:49:08 -0700139 render_mangler_.reset(new LappedTransform(
140 channels_, channels_, chunk_length_, kbd_window_.get(), window_size_,
141 window_size_ / 2, &render_callback_));
142 capture_mangler_.reset(new LappedTransform(
143 channels_, channels_, chunk_length_, kbd_window_.get(), window_size_,
144 window_size_ / 2, &capture_callback_));
ekm030249d2015-06-15 13:02:24 -0700145}
146
147IntelligibilityEnhancer::~IntelligibilityEnhancer() {
148 WebRtcVad_Free(vad_low_);
149 WebRtcVad_Free(vad_high_);
ekm35b72fb2015-07-10 14:11:52 -0700150 free(temp_out_buffer_);
ekm030249d2015-06-15 13:02:24 -0700151}
152
153void IntelligibilityEnhancer::ProcessRenderAudio(float* const* audio) {
154 for (int i = 0; i < chunk_length_; ++i) {
155 vad_tmp_buffer_[i] = (int16_t)audio[0][i];
156 }
157 has_voice_low_ = WebRtcVad_Process(vad_low_, sample_rate_hz_,
158 vad_tmp_buffer_.get(), chunk_length_) == 1;
159
ekmdb4fecf2015-06-22 17:49:08 -0700160 // Process and enhance chunk of |audio|
ekm030249d2015-06-15 13:02:24 -0700161 render_mangler_->ProcessChunk(audio, temp_out_buffer_);
ekmdb4fecf2015-06-22 17:49:08 -0700162
ekm030249d2015-06-15 13:02:24 -0700163 for (int i = 0; i < channels_; ++i) {
164 memcpy(audio[i], temp_out_buffer_[i],
165 chunk_length_ * sizeof(**temp_out_buffer_));
166 }
167}
168
169void IntelligibilityEnhancer::ProcessCaptureAudio(float* const* audio) {
170 for (int i = 0; i < chunk_length_; ++i) {
171 vad_tmp_buffer_[i] = (int16_t)audio[0][i];
172 }
ekmdb4fecf2015-06-22 17:49:08 -0700173 // TODO(bercic): The VAD was always detecting voice in the noise stream,
174 // no matter what the aggressiveness, so it was temporarily disabled here.
ekm030249d2015-06-15 13:02:24 -0700175
ekmdb4fecf2015-06-22 17:49:08 -0700176 #if 0
177 if (WebRtcVad_Process(vad_high_, sample_rate_hz_, vad_tmp_buffer_.get(),
178 chunk_length_) == 1) {
179 printf("capture HAS speech\n");
180 return;
181 }
182 printf("capture NO speech\n");
183 #endif
184
ekm030249d2015-06-15 13:02:24 -0700185 capture_mangler_->ProcessChunk(audio, temp_out_buffer_);
186}
187
188void IntelligibilityEnhancer::DispatchAudio(
189 IntelligibilityEnhancer::AudioSource source,
ekmdb4fecf2015-06-22 17:49:08 -0700190 const complex<float>* in_block,
191 complex<float>* out_block) {
ekm030249d2015-06-15 13:02:24 -0700192 switch (source) {
193 case kRenderStream:
194 ProcessClearBlock(in_block, out_block);
195 break;
196 case kCaptureStream:
197 ProcessNoiseBlock(in_block, out_block);
198 break;
199 }
200}
201
202void IntelligibilityEnhancer::ProcessClearBlock(const complex<float>* in_block,
203 complex<float>* out_block) {
ekm030249d2015-06-15 13:02:24 -0700204 if (block_count_ < 2) {
205 memset(out_block, 0, freqs_ * sizeof(*out_block));
206 ++block_count_;
207 return;
208 }
209
ekmdb4fecf2015-06-22 17:49:08 -0700210 // For now, always assumes enhancement is necessary.
211 // TODO(ekmeyerson): Change to only enhance if necessary,
212 // based on experiments with different cutoffs.
ekm030249d2015-06-15 13:02:24 -0700213 if (has_voice_low_ || true) {
214 clear_variance_.Step(in_block, false);
ekm35b72fb2015-07-10 14:11:52 -0700215 const float power_target = std::accumulate(
216 clear_variance_.variance(), clear_variance_.variance() + freqs_, 0.0f);
ekm030249d2015-06-15 13:02:24 -0700217
218 if (block_count_ % analysis_rate_ == analysis_rate_ - 1) {
219 AnalyzeClearBlock(power_target);
220 ++analysis_step_;
221 if (analysis_step_ == variance_rate_) {
222 analysis_step_ = 0;
223 clear_variance_.Clear();
224 noise_variance_.Clear();
225 }
226 }
227 ++block_count_;
228 }
229
230 /* efidata(n,:) = sqrt(b(n)) * fidata(n,:) */
231 gain_applier_.Apply(in_block, out_block);
232}
233
234void IntelligibilityEnhancer::AnalyzeClearBlock(float power_target) {
235 FilterVariance(clear_variance_.variance(), filtered_clear_var_.get());
236 FilterVariance(noise_variance_.variance(), filtered_noise_var_.get());
237
ekm35b72fb2015-07-10 14:11:52 -0700238 SolveForGainsGivenLambda(kLambdaTop, start_freq_, gains_eq_.get());
239 const float power_top =
ekmdb4fecf2015-06-22 17:49:08 -0700240 DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
ekm35b72fb2015-07-10 14:11:52 -0700241 SolveForGainsGivenLambda(kLambdaBot, start_freq_, gains_eq_.get());
242 const float power_bot =
ekmdb4fecf2015-06-22 17:49:08 -0700243 DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
ekm35b72fb2015-07-10 14:11:52 -0700244 if (power_target >= power_bot && power_target <= power_top) {
245 SolveForLambda(power_target, power_bot, power_top);
246 UpdateErbGains();
247 } // Else experiencing variance underflow, so do nothing.
248}
ekm030249d2015-06-15 13:02:24 -0700249
ekm35b72fb2015-07-10 14:11:52 -0700250void IntelligibilityEnhancer::SolveForLambda(float power_target,
251 float power_bot,
252 float power_top) {
ekmdb4fecf2015-06-22 17:49:08 -0700253 const float kConvergeThresh = 0.001f; // TODO(ekmeyerson): Find best values
254 const int kMaxIters = 100; // for these, based on experiments.
ekm35b72fb2015-07-10 14:11:52 -0700255
256 const float reciprocal_power_target = 1.f / power_target;
257 float lambda_bot = kLambdaBot;
258 float lambda_top = kLambdaTop;
259 float power_ratio = 2.0f; // Ratio of achieved power to target power.
ekm030249d2015-06-15 13:02:24 -0700260 int iters = 0;
ekm35b72fb2015-07-10 14:11:52 -0700261 while (std::fabs(power_ratio - 1.0f) > kConvergeThresh &&
262 iters <= kMaxIters) {
263 const float lambda = lambda_bot + (lambda_top - lambda_bot) / 2.0f;
ekmdb4fecf2015-06-22 17:49:08 -0700264 SolveForGainsGivenLambda(lambda, start_freq_, gains_eq_.get());
ekm35b72fb2015-07-10 14:11:52 -0700265 const float power =
266 DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
ekm030249d2015-06-15 13:02:24 -0700267 if (power < power_target) {
268 lambda_bot = lambda;
269 } else {
270 lambda_top = lambda;
271 }
ekm35b72fb2015-07-10 14:11:52 -0700272 power_ratio = std::fabs(power * reciprocal_power_target);
ekm030249d2015-06-15 13:02:24 -0700273 ++iters;
274 }
ekm35b72fb2015-07-10 14:11:52 -0700275}
ekm030249d2015-06-15 13:02:24 -0700276
ekm35b72fb2015-07-10 14:11:52 -0700277void IntelligibilityEnhancer::UpdateErbGains() {
ekmdb4fecf2015-06-22 17:49:08 -0700278 // (ERB gain) = filterbank' * (freq gain)
ekm030249d2015-06-15 13:02:24 -0700279 float* gains = gain_applier_.target();
280 for (int i = 0; i < freqs_; ++i) {
281 gains[i] = 0.0f;
282 for (int j = 0; j < bank_size_; ++j) {
283 gains[i] = fmaf(filter_bank_[j][i], gains_eq_[j], gains[i]);
284 }
285 }
286}
287
288void IntelligibilityEnhancer::ProcessNoiseBlock(const complex<float>* in_block,
289 complex<float>* /*out_block*/) {
290 noise_variance_.Step(in_block);
291}
292
293int IntelligibilityEnhancer::GetBankSize(int sample_rate, int erb_resolution) {
294 float freq_limit = sample_rate / 2000.0f;
ekmdb4fecf2015-06-22 17:49:08 -0700295 int erb_scale = ceilf(
296 11.17f * logf((freq_limit + 0.312f) / (freq_limit + 14.6575f)) + 43.0f);
ekm030249d2015-06-15 13:02:24 -0700297 return erb_scale * erb_resolution;
298}
299
300void IntelligibilityEnhancer::CreateErbBank() {
301 int lf = 1, rf = 4;
302
303 for (int i = 0; i < bank_size_; ++i) {
304 float abs_temp = fabsf((i + 1.0f) / static_cast<float>(erb_resolution_));
305 center_freqs_[i] = 676170.4f / (47.06538f - expf(0.08950404f * abs_temp));
306 center_freqs_[i] -= 14678.49f;
307 }
308 float last_center_freq = center_freqs_[bank_size_ - 1];
309 for (int i = 0; i < bank_size_; ++i) {
310 center_freqs_[i] *= 0.5f * sample_rate_hz_ / last_center_freq;
311 }
312
ekm030249d2015-06-15 13:02:24 -0700313 for (int i = 0; i < bank_size_; ++i) {
ekm35b72fb2015-07-10 14:11:52 -0700314 filter_bank_[i].resize(freqs_);
ekm030249d2015-06-15 13:02:24 -0700315 }
316
317 for (int i = 1; i <= bank_size_; ++i) {
318 int lll, ll, rr, rrr;
319 lll = round(center_freqs_[max(1, i - lf) - 1] * freqs_ /
ekmdb4fecf2015-06-22 17:49:08 -0700320 (0.5f * sample_rate_hz_));
321 ll =
322 round(center_freqs_[max(1, i) - 1] * freqs_ / (0.5f * sample_rate_hz_));
ekm030249d2015-06-15 13:02:24 -0700323 lll = min(freqs_, max(lll, 1)) - 1;
ekmdb4fecf2015-06-22 17:49:08 -0700324 ll = min(freqs_, max(ll, 1)) - 1;
ekm030249d2015-06-15 13:02:24 -0700325
326 rrr = round(center_freqs_[min(bank_size_, i + rf) - 1] * freqs_ /
ekmdb4fecf2015-06-22 17:49:08 -0700327 (0.5f * sample_rate_hz_));
328 rr = round(center_freqs_[min(bank_size_, i + 1) - 1] * freqs_ /
329 (0.5f * sample_rate_hz_));
ekm030249d2015-06-15 13:02:24 -0700330 rrr = min(freqs_, max(rrr, 1)) - 1;
ekmdb4fecf2015-06-22 17:49:08 -0700331 rr = min(freqs_, max(rr, 1)) - 1;
ekm030249d2015-06-15 13:02:24 -0700332
333 float step, element;
334
335 step = 1.0f / (ll - lll);
336 element = 0.0f;
337 for (int j = lll; j <= ll; ++j) {
338 filter_bank_[i - 1][j] = element;
339 element += step;
340 }
341 step = 1.0f / (rrr - rr);
342 element = 1.0f;
343 for (int j = rr; j <= rrr; ++j) {
344 filter_bank_[i - 1][j] = element;
345 element -= step;
346 }
347 for (int j = ll; j <= rr; ++j) {
348 filter_bank_[i - 1][j] = 1.0f;
349 }
350 }
351
352 float sum;
353 for (int i = 0; i < freqs_; ++i) {
354 sum = 0.0f;
355 for (int j = 0; j < bank_size_; ++j) {
356 sum += filter_bank_[j][i];
357 }
358 for (int j = 0; j < bank_size_; ++j) {
359 filter_bank_[j][i] /= sum;
360 }
361 }
362}
363
ekmdb4fecf2015-06-22 17:49:08 -0700364void IntelligibilityEnhancer::SolveForGainsGivenLambda(float lambda,
365 int start_freq,
366 float* sols) {
ekm030249d2015-06-15 13:02:24 -0700367 bool quadratic = (kConfigRho < 1.0f);
368 const float* var_x0 = filtered_clear_var_.get();
369 const float* var_n0 = filtered_noise_var_.get();
370
371 for (int n = 0; n < start_freq; ++n) {
372 sols[n] = 1.0f;
373 }
ekmdb4fecf2015-06-22 17:49:08 -0700374
375 // Analytic solution for optimal gains. See paper for derivation.
ekm030249d2015-06-15 13:02:24 -0700376 for (int n = start_freq - 1; n < bank_size_; ++n) {
377 float alpha0, beta0, gamma0;
378 gamma0 = 0.5f * rho_[n] * var_x0[n] * var_n0[n] +
ekmdb4fecf2015-06-22 17:49:08 -0700379 lambda * var_x0[n] * var_n0[n] * var_n0[n];
ekm030249d2015-06-15 13:02:24 -0700380 beta0 = lambda * var_x0[n] * (2 - rho_[n]) * var_x0[n] * var_n0[n];
381 if (quadratic) {
382 alpha0 = lambda * var_x0[n] * (1 - rho_[n]) * var_x0[n] * var_x0[n];
ekmdb4fecf2015-06-22 17:49:08 -0700383 sols[n] =
384 (-beta0 - sqrtf(beta0 * beta0 - 4 * alpha0 * gamma0)) / (2 * alpha0);
ekm030249d2015-06-15 13:02:24 -0700385 } else {
386 sols[n] = -gamma0 / beta0;
387 }
388 sols[n] = fmax(0, sols[n]);
389 }
390}
391
392void IntelligibilityEnhancer::FilterVariance(const float* var, float* result) {
393 for (int i = 0; i < bank_size_; ++i) {
ekm35b72fb2015-07-10 14:11:52 -0700394 result[i] = DotProduct(filter_bank_[i].data(), var, freqs_);
ekm030249d2015-06-15 13:02:24 -0700395 }
396}
397
ekmdb4fecf2015-06-22 17:49:08 -0700398float IntelligibilityEnhancer::DotProduct(const float* a,
399 const float* b,
400 int length) {
ekm030249d2015-06-15 13:02:24 -0700401 float ret = 0.0f;
402
403 for (int i = 0; i < length; ++i) {
404 ret = fmaf(a[i], b[i], ret);
405 }
406 return ret;
407}
408
409} // namespace webrtc