blob: 3029e21619a981917f377afb78075bfd66def952 [file] [log] [blame]
ekm030249d2015-06-15 13:02:24 -07001/*
2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
ekmdb4fecf2015-06-22 17:49:08 -070011//
12// Implements core class for intelligibility enhancer.
13//
14// Details of the model and algorithm can be found in the original paper:
15// http://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=6882788
16//
17
ekm030249d2015-06-15 13:02:24 -070018#include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h"
19
20#include <cmath>
21#include <cstdlib>
22
23#include <algorithm>
ekmdb4fecf2015-06-22 17:49:08 -070024#include <numeric>
ekm030249d2015-06-15 13:02:24 -070025
26#include "webrtc/base/checks.h"
27#include "webrtc/common_audio/vad/include/webrtc_vad.h"
28#include "webrtc/common_audio/window_generator.h"
29
30using std::complex;
31using std::max;
32using std::min;
33
34namespace webrtc {
35
36const int IntelligibilityEnhancer::kErbResolution = 2;
37const int IntelligibilityEnhancer::kWindowSizeMs = 2;
ekmdb4fecf2015-06-22 17:49:08 -070038const int IntelligibilityEnhancer::kChunkSizeMs = 10; // Size provided by APM.
ekm030249d2015-06-15 13:02:24 -070039const int IntelligibilityEnhancer::kAnalyzeRate = 800;
40const int IntelligibilityEnhancer::kVarianceRate = 2;
41const float IntelligibilityEnhancer::kClipFreq = 200.0f;
42const float IntelligibilityEnhancer::kConfigRho = 0.02f;
43const float IntelligibilityEnhancer::kKbdAlpha = 1.5f;
ekmdb4fecf2015-06-22 17:49:08 -070044
45// To disable gain update smoothing, set gain limit to be VERY high.
46// TODO(ekmeyerson): Add option to disable gain smoothing altogether
47// to avoid the extra computation.
ekm030249d2015-06-15 13:02:24 -070048const float IntelligibilityEnhancer::kGainChangeLimit = 0.0125f;
49
50using VarianceType = intelligibility::VarianceArray::StepType;
51
52IntelligibilityEnhancer::TransformCallback::TransformCallback(
53 IntelligibilityEnhancer* parent,
54 IntelligibilityEnhancer::AudioSource source)
ekmdb4fecf2015-06-22 17:49:08 -070055 : parent_(parent), source_(source) {
56}
ekm030249d2015-06-15 13:02:24 -070057
58void IntelligibilityEnhancer::TransformCallback::ProcessAudioBlock(
59 const complex<float>* const* in_block,
ekmdb4fecf2015-06-22 17:49:08 -070060 int in_channels,
61 int frames,
62 int /* out_channels */,
ekm030249d2015-06-15 13:02:24 -070063 complex<float>* const* out_block) {
64 DCHECK_EQ(parent_->freqs_, frames);
65 for (int i = 0; i < in_channels; ++i) {
66 parent_->DispatchAudio(source_, in_block[i], out_block[i]);
67 }
68}
69
70IntelligibilityEnhancer::IntelligibilityEnhancer(int erb_resolution,
71 int sample_rate_hz,
72 int channels,
ekmdb4fecf2015-06-22 17:49:08 -070073 int cv_type,
74 float cv_alpha,
ekm030249d2015-06-15 13:02:24 -070075 int cv_win,
76 int analysis_rate,
77 int variance_rate,
78 float gain_limit)
ekmdb4fecf2015-06-22 17:49:08 -070079 : freqs_(RealFourier::ComplexLength(
80 RealFourier::FftOrder(sample_rate_hz * kWindowSizeMs / 1000))),
ekm030249d2015-06-15 13:02:24 -070081 window_size_(1 << RealFourier::FftOrder(freqs_)),
82 chunk_length_(sample_rate_hz * kChunkSizeMs / 1000),
83 bank_size_(GetBankSize(sample_rate_hz, erb_resolution)),
84 sample_rate_hz_(sample_rate_hz),
85 erb_resolution_(erb_resolution),
86 channels_(channels),
87 analysis_rate_(analysis_rate),
88 variance_rate_(variance_rate),
ekmdb4fecf2015-06-22 17:49:08 -070089 clear_variance_(freqs_,
90 static_cast<VarianceType>(cv_type),
91 cv_win,
ekm030249d2015-06-15 13:02:24 -070092 cv_alpha),
93 noise_variance_(freqs_, VarianceType::kStepInfinite, 475, 0.01f),
94 filtered_clear_var_(new float[bank_size_]),
95 filtered_noise_var_(new float[bank_size_]),
96 filter_bank_(nullptr),
97 center_freqs_(new float[bank_size_]),
98 rho_(new float[bank_size_]),
99 gains_eq_(new float[bank_size_]),
100 gain_applier_(freqs_, gain_limit),
101 temp_out_buffer_(nullptr),
ekmdb4fecf2015-06-22 17:49:08 -0700102 input_audio_(new float* [channels]),
ekm030249d2015-06-15 13:02:24 -0700103 kbd_window_(new float[window_size_]),
104 render_callback_(this, AudioSource::kRenderStream),
105 capture_callback_(this, AudioSource::kCaptureStream),
106 block_count_(0),
107 analysis_step_(0),
ekmdb4fecf2015-06-22 17:49:08 -0700108 vad_high_(WebRtcVad_Create()),
109 vad_low_(WebRtcVad_Create()),
ekm030249d2015-06-15 13:02:24 -0700110 vad_tmp_buffer_(new int16_t[chunk_length_]) {
111 DCHECK_LE(kConfigRho, 1.0f);
112
113 CreateErbBank();
114
ekm030249d2015-06-15 13:02:24 -0700115 WebRtcVad_Init(vad_high_);
ekmdb4fecf2015-06-22 17:49:08 -0700116 WebRtcVad_set_mode(vad_high_, 0); // High likelihood of speech.
ekm030249d2015-06-15 13:02:24 -0700117 WebRtcVad_Init(vad_low_);
ekmdb4fecf2015-06-22 17:49:08 -0700118 WebRtcVad_set_mode(vad_low_, 3); // Low likelihood of speech.
ekm030249d2015-06-15 13:02:24 -0700119
ekmdb4fecf2015-06-22 17:49:08 -0700120 temp_out_buffer_ = static_cast<float**>(
121 malloc(sizeof(*temp_out_buffer_) * channels_ +
122 sizeof(**temp_out_buffer_) * chunk_length_ * channels_));
ekm030249d2015-06-15 13:02:24 -0700123 for (int i = 0; i < channels_; ++i) {
ekmdb4fecf2015-06-22 17:49:08 -0700124 temp_out_buffer_[i] =
125 reinterpret_cast<float*>(temp_out_buffer_ + channels_) +
126 chunk_length_ * i;
ekm030249d2015-06-15 13:02:24 -0700127 }
128
ekmdb4fecf2015-06-22 17:49:08 -0700129 // Assumes all rho equal.
ekm030249d2015-06-15 13:02:24 -0700130 for (int i = 0; i < bank_size_; ++i) {
131 rho_[i] = kConfigRho * kConfigRho;
132 }
133
134 float freqs_khz = kClipFreq / 1000.0f;
ekmdb4fecf2015-06-22 17:49:08 -0700135 int erb_index = static_cast<int>(ceilf(
136 11.17f * logf((freqs_khz + 0.312f) / (freqs_khz + 14.6575f)) + 43.0f));
ekm030249d2015-06-15 13:02:24 -0700137 start_freq_ = max(1, erb_index * kErbResolution);
138
139 WindowGenerator::KaiserBesselDerived(kKbdAlpha, window_size_,
140 kbd_window_.get());
ekmdb4fecf2015-06-22 17:49:08 -0700141 render_mangler_.reset(new LappedTransform(
142 channels_, channels_, chunk_length_, kbd_window_.get(), window_size_,
143 window_size_ / 2, &render_callback_));
144 capture_mangler_.reset(new LappedTransform(
145 channels_, channels_, chunk_length_, kbd_window_.get(), window_size_,
146 window_size_ / 2, &capture_callback_));
ekm030249d2015-06-15 13:02:24 -0700147}
148
149IntelligibilityEnhancer::~IntelligibilityEnhancer() {
150 WebRtcVad_Free(vad_low_);
151 WebRtcVad_Free(vad_high_);
152 free(filter_bank_);
153}
154
155void IntelligibilityEnhancer::ProcessRenderAudio(float* const* audio) {
156 for (int i = 0; i < chunk_length_; ++i) {
157 vad_tmp_buffer_[i] = (int16_t)audio[0][i];
158 }
159 has_voice_low_ = WebRtcVad_Process(vad_low_, sample_rate_hz_,
160 vad_tmp_buffer_.get(), chunk_length_) == 1;
161
ekmdb4fecf2015-06-22 17:49:08 -0700162 // Process and enhance chunk of |audio|
ekm030249d2015-06-15 13:02:24 -0700163 render_mangler_->ProcessChunk(audio, temp_out_buffer_);
ekmdb4fecf2015-06-22 17:49:08 -0700164
ekm030249d2015-06-15 13:02:24 -0700165 for (int i = 0; i < channels_; ++i) {
166 memcpy(audio[i], temp_out_buffer_[i],
167 chunk_length_ * sizeof(**temp_out_buffer_));
168 }
169}
170
171void IntelligibilityEnhancer::ProcessCaptureAudio(float* const* audio) {
172 for (int i = 0; i < chunk_length_; ++i) {
173 vad_tmp_buffer_[i] = (int16_t)audio[0][i];
174 }
ekmdb4fecf2015-06-22 17:49:08 -0700175 // TODO(bercic): The VAD was always detecting voice in the noise stream,
176 // no matter what the aggressiveness, so it was temporarily disabled here.
ekm030249d2015-06-15 13:02:24 -0700177
ekmdb4fecf2015-06-22 17:49:08 -0700178 #if 0
179 if (WebRtcVad_Process(vad_high_, sample_rate_hz_, vad_tmp_buffer_.get(),
180 chunk_length_) == 1) {
181 printf("capture HAS speech\n");
182 return;
183 }
184 printf("capture NO speech\n");
185 #endif
186
ekm030249d2015-06-15 13:02:24 -0700187 capture_mangler_->ProcessChunk(audio, temp_out_buffer_);
188}
189
190void IntelligibilityEnhancer::DispatchAudio(
191 IntelligibilityEnhancer::AudioSource source,
ekmdb4fecf2015-06-22 17:49:08 -0700192 const complex<float>* in_block,
193 complex<float>* out_block) {
ekm030249d2015-06-15 13:02:24 -0700194 switch (source) {
195 case kRenderStream:
196 ProcessClearBlock(in_block, out_block);
197 break;
198 case kCaptureStream:
199 ProcessNoiseBlock(in_block, out_block);
200 break;
201 }
202}
203
204void IntelligibilityEnhancer::ProcessClearBlock(const complex<float>* in_block,
205 complex<float>* out_block) {
206 float power_target;
207
208 if (block_count_ < 2) {
209 memset(out_block, 0, freqs_ * sizeof(*out_block));
210 ++block_count_;
211 return;
212 }
213
ekmdb4fecf2015-06-22 17:49:08 -0700214 // For now, always assumes enhancement is necessary.
215 // TODO(ekmeyerson): Change to only enhance if necessary,
216 // based on experiments with different cutoffs.
ekm030249d2015-06-15 13:02:24 -0700217 if (has_voice_low_ || true) {
218 clear_variance_.Step(in_block, false);
219 power_target = std::accumulate(clear_variance_.variance(),
220 clear_variance_.variance() + freqs_, 0.0f);
221
222 if (block_count_ % analysis_rate_ == analysis_rate_ - 1) {
223 AnalyzeClearBlock(power_target);
224 ++analysis_step_;
225 if (analysis_step_ == variance_rate_) {
226 analysis_step_ = 0;
227 clear_variance_.Clear();
228 noise_variance_.Clear();
229 }
230 }
231 ++block_count_;
232 }
233
234 /* efidata(n,:) = sqrt(b(n)) * fidata(n,:) */
235 gain_applier_.Apply(in_block, out_block);
236}
237
238void IntelligibilityEnhancer::AnalyzeClearBlock(float power_target) {
239 FilterVariance(clear_variance_.variance(), filtered_clear_var_.get());
240 FilterVariance(noise_variance_.variance(), filtered_noise_var_.get());
241
ekmdb4fecf2015-06-22 17:49:08 -0700242 // Bisection search for optimal |lambda|
ekm030249d2015-06-15 13:02:24 -0700243
244 float lambda_bot = -1.0f, lambda_top = -10e-18f, lambda;
245 float power_bot, power_top, power;
ekmdb4fecf2015-06-22 17:49:08 -0700246 SolveForGainsGivenLambda(lambda_top, start_freq_, gains_eq_.get());
247 power_top =
248 DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
249 SolveForGainsGivenLambda(lambda_bot, start_freq_, gains_eq_.get());
250 power_bot =
251 DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
ekm030249d2015-06-15 13:02:24 -0700252 DCHECK(power_target >= power_bot && power_target <= power_top);
253
ekmdb4fecf2015-06-22 17:49:08 -0700254 float power_ratio = 2.0f; // Ratio of achieved power to target power.
255 const float kConvergeThresh = 0.001f; // TODO(ekmeyerson): Find best values
256 const int kMaxIters = 100; // for these, based on experiments.
ekm030249d2015-06-15 13:02:24 -0700257 int iters = 0;
ekmdb4fecf2015-06-22 17:49:08 -0700258 while (fabs(power_ratio - 1.0f) > kConvergeThresh && iters <= kMaxIters) {
ekm030249d2015-06-15 13:02:24 -0700259 lambda = lambda_bot + (lambda_top - lambda_bot) / 2.0f;
ekmdb4fecf2015-06-22 17:49:08 -0700260 SolveForGainsGivenLambda(lambda, start_freq_, gains_eq_.get());
ekm030249d2015-06-15 13:02:24 -0700261 power = DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
262 if (power < power_target) {
263 lambda_bot = lambda;
264 } else {
265 lambda_top = lambda;
266 }
267 power_ratio = fabs(power / power_target);
268 ++iters;
269 }
270
ekmdb4fecf2015-06-22 17:49:08 -0700271 // (ERB gain) = filterbank' * (freq gain)
ekm030249d2015-06-15 13:02:24 -0700272 float* gains = gain_applier_.target();
273 for (int i = 0; i < freqs_; ++i) {
274 gains[i] = 0.0f;
275 for (int j = 0; j < bank_size_; ++j) {
276 gains[i] = fmaf(filter_bank_[j][i], gains_eq_[j], gains[i]);
277 }
278 }
279}
280
281void IntelligibilityEnhancer::ProcessNoiseBlock(const complex<float>* in_block,
282 complex<float>* /*out_block*/) {
283 noise_variance_.Step(in_block);
284}
285
286int IntelligibilityEnhancer::GetBankSize(int sample_rate, int erb_resolution) {
287 float freq_limit = sample_rate / 2000.0f;
ekmdb4fecf2015-06-22 17:49:08 -0700288 int erb_scale = ceilf(
289 11.17f * logf((freq_limit + 0.312f) / (freq_limit + 14.6575f)) + 43.0f);
ekm030249d2015-06-15 13:02:24 -0700290 return erb_scale * erb_resolution;
291}
292
293void IntelligibilityEnhancer::CreateErbBank() {
294 int lf = 1, rf = 4;
295
296 for (int i = 0; i < bank_size_; ++i) {
297 float abs_temp = fabsf((i + 1.0f) / static_cast<float>(erb_resolution_));
298 center_freqs_[i] = 676170.4f / (47.06538f - expf(0.08950404f * abs_temp));
299 center_freqs_[i] -= 14678.49f;
300 }
301 float last_center_freq = center_freqs_[bank_size_ - 1];
302 for (int i = 0; i < bank_size_; ++i) {
303 center_freqs_[i] *= 0.5f * sample_rate_hz_ / last_center_freq;
304 }
305
ekmdb4fecf2015-06-22 17:49:08 -0700306 filter_bank_ = static_cast<float**>(
307 malloc(sizeof(*filter_bank_) * bank_size_ +
308 sizeof(**filter_bank_) * freqs_ * bank_size_));
ekm030249d2015-06-15 13:02:24 -0700309 for (int i = 0; i < bank_size_; ++i) {
ekmdb4fecf2015-06-22 17:49:08 -0700310 filter_bank_[i] =
311 reinterpret_cast<float*>(filter_bank_ + bank_size_) + freqs_ * i;
ekm030249d2015-06-15 13:02:24 -0700312 }
313
314 for (int i = 1; i <= bank_size_; ++i) {
315 int lll, ll, rr, rrr;
316 lll = round(center_freqs_[max(1, i - lf) - 1] * freqs_ /
ekmdb4fecf2015-06-22 17:49:08 -0700317 (0.5f * sample_rate_hz_));
318 ll =
319 round(center_freqs_[max(1, i) - 1] * freqs_ / (0.5f * sample_rate_hz_));
ekm030249d2015-06-15 13:02:24 -0700320 lll = min(freqs_, max(lll, 1)) - 1;
ekmdb4fecf2015-06-22 17:49:08 -0700321 ll = min(freqs_, max(ll, 1)) - 1;
ekm030249d2015-06-15 13:02:24 -0700322
323 rrr = round(center_freqs_[min(bank_size_, i + rf) - 1] * freqs_ /
ekmdb4fecf2015-06-22 17:49:08 -0700324 (0.5f * sample_rate_hz_));
325 rr = round(center_freqs_[min(bank_size_, i + 1) - 1] * freqs_ /
326 (0.5f * sample_rate_hz_));
ekm030249d2015-06-15 13:02:24 -0700327 rrr = min(freqs_, max(rrr, 1)) - 1;
ekmdb4fecf2015-06-22 17:49:08 -0700328 rr = min(freqs_, max(rr, 1)) - 1;
ekm030249d2015-06-15 13:02:24 -0700329
330 float step, element;
331
332 step = 1.0f / (ll - lll);
333 element = 0.0f;
334 for (int j = lll; j <= ll; ++j) {
335 filter_bank_[i - 1][j] = element;
336 element += step;
337 }
338 step = 1.0f / (rrr - rr);
339 element = 1.0f;
340 for (int j = rr; j <= rrr; ++j) {
341 filter_bank_[i - 1][j] = element;
342 element -= step;
343 }
344 for (int j = ll; j <= rr; ++j) {
345 filter_bank_[i - 1][j] = 1.0f;
346 }
347 }
348
349 float sum;
350 for (int i = 0; i < freqs_; ++i) {
351 sum = 0.0f;
352 for (int j = 0; j < bank_size_; ++j) {
353 sum += filter_bank_[j][i];
354 }
355 for (int j = 0; j < bank_size_; ++j) {
356 filter_bank_[j][i] /= sum;
357 }
358 }
359}
360
ekmdb4fecf2015-06-22 17:49:08 -0700361void IntelligibilityEnhancer::SolveForGainsGivenLambda(float lambda,
362 int start_freq,
363 float* sols) {
ekm030249d2015-06-15 13:02:24 -0700364 bool quadratic = (kConfigRho < 1.0f);
365 const float* var_x0 = filtered_clear_var_.get();
366 const float* var_n0 = filtered_noise_var_.get();
367
368 for (int n = 0; n < start_freq; ++n) {
369 sols[n] = 1.0f;
370 }
ekmdb4fecf2015-06-22 17:49:08 -0700371
372 // Analytic solution for optimal gains. See paper for derivation.
ekm030249d2015-06-15 13:02:24 -0700373 for (int n = start_freq - 1; n < bank_size_; ++n) {
374 float alpha0, beta0, gamma0;
375 gamma0 = 0.5f * rho_[n] * var_x0[n] * var_n0[n] +
ekmdb4fecf2015-06-22 17:49:08 -0700376 lambda * var_x0[n] * var_n0[n] * var_n0[n];
ekm030249d2015-06-15 13:02:24 -0700377 beta0 = lambda * var_x0[n] * (2 - rho_[n]) * var_x0[n] * var_n0[n];
378 if (quadratic) {
379 alpha0 = lambda * var_x0[n] * (1 - rho_[n]) * var_x0[n] * var_x0[n];
ekmdb4fecf2015-06-22 17:49:08 -0700380 sols[n] =
381 (-beta0 - sqrtf(beta0 * beta0 - 4 * alpha0 * gamma0)) / (2 * alpha0);
ekm030249d2015-06-15 13:02:24 -0700382 } else {
383 sols[n] = -gamma0 / beta0;
384 }
385 sols[n] = fmax(0, sols[n]);
386 }
387}
388
389void IntelligibilityEnhancer::FilterVariance(const float* var, float* result) {
390 for (int i = 0; i < bank_size_; ++i) {
391 result[i] = DotProduct(filter_bank_[i], var, freqs_);
392 }
393}
394
ekmdb4fecf2015-06-22 17:49:08 -0700395float IntelligibilityEnhancer::DotProduct(const float* a,
396 const float* b,
397 int length) {
ekm030249d2015-06-15 13:02:24 -0700398 float ret = 0.0f;
399
400 for (int i = 0; i < length; ++i) {
401 ret = fmaf(a[i], b[i], ret);
402 }
403 return ret;
404}
405
406} // namespace webrtc