blob: 8778c494265797a4ff88b620e7804b944af53f5d [file] [log] [blame]
/*
 * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */
10
11#include "modules/audio_processing/agc2/signal_classifier.h"
12
13#include <algorithm>
14#include <numeric>
15#include <vector>
16
17#include "api/array_view.h"
18#include "modules/audio_processing/agc2/down_sampler.h"
19#include "modules/audio_processing/agc2/noise_spectrum_estimator.h"
20#include "modules/audio_processing/logging/apm_data_dumper.h"
Yves Gerey988cc082018-10-23 12:03:01 +020021#include "rtc_base/checks.h"
Alex Loiko4ed47d02018-04-04 15:05:57 +020022
23namespace webrtc {
24namespace {
25
26void RemoveDcLevel(rtc::ArrayView<float> x) {
27 RTC_DCHECK_LT(0, x.size());
28 float mean = std::accumulate(x.data(), x.data() + x.size(), 0.f);
29 mean /= x.size();
30
31 for (float& v : x) {
32 v -= mean;
33 }
34}
35
36void PowerSpectrum(const OouraFft* ooura_fft,
37 rtc::ArrayView<const float> x,
38 rtc::ArrayView<float> spectrum) {
39 RTC_DCHECK_EQ(65, spectrum.size());
40 RTC_DCHECK_EQ(128, x.size());
41 float X[128];
42 std::copy(x.data(), x.data() + x.size(), X);
43 ooura_fft->Fft(X);
44
45 float* X_p = X;
46 RTC_DCHECK_EQ(X_p, &X[0]);
47 spectrum[0] = (*X_p) * (*X_p);
48 ++X_p;
49 RTC_DCHECK_EQ(X_p, &X[1]);
50 spectrum[64] = (*X_p) * (*X_p);
51 for (int k = 1; k < 64; ++k) {
52 ++X_p;
53 RTC_DCHECK_EQ(X_p, &X[2 * k]);
54 spectrum[k] = (*X_p) * (*X_p);
55 ++X_p;
56 RTC_DCHECK_EQ(X_p, &X[2 * k + 1]);
57 spectrum[k] += (*X_p) * (*X_p);
58 }
59}
60
61webrtc::SignalClassifier::SignalType ClassifySignal(
62 rtc::ArrayView<const float> signal_spectrum,
63 rtc::ArrayView<const float> noise_spectrum,
64 ApmDataDumper* data_dumper) {
65 int num_stationary_bands = 0;
66 int num_highly_nonstationary_bands = 0;
67
68 // Detect stationary and highly nonstationary bands.
69 for (size_t k = 1; k < 40; k++) {
70 if (signal_spectrum[k] < 3 * noise_spectrum[k] &&
71 signal_spectrum[k] * 3 > noise_spectrum[k]) {
72 ++num_stationary_bands;
73 } else if (signal_spectrum[k] > 9 * noise_spectrum[k]) {
74 ++num_highly_nonstationary_bands;
75 }
76 }
77
78 data_dumper->DumpRaw("lc_num_stationary_bands", 1, &num_stationary_bands);
79 data_dumper->DumpRaw("lc_num_highly_nonstationary_bands", 1,
80 &num_highly_nonstationary_bands);
81
82 // Use the detected number of bands to classify the overall signal
83 // stationarity.
84 if (num_stationary_bands > 15) {
85 return SignalClassifier::SignalType::kStationary;
86 } else {
87 return SignalClassifier::SignalType::kNonStationary;
88 }
89}
90
91} // namespace
92
93SignalClassifier::FrameExtender::FrameExtender(size_t frame_size,
94 size_t extended_frame_size)
95 : x_old_(extended_frame_size - frame_size, 0.f) {}
96
97SignalClassifier::FrameExtender::~FrameExtender() = default;
98
99void SignalClassifier::FrameExtender::ExtendFrame(
100 rtc::ArrayView<const float> x,
101 rtc::ArrayView<float> x_extended) {
102 RTC_DCHECK_EQ(x_old_.size() + x.size(), x_extended.size());
103 std::copy(x_old_.data(), x_old_.data() + x_old_.size(), x_extended.data());
104 std::copy(x.data(), x.data() + x.size(), x_extended.data() + x_old_.size());
105 std::copy(x_extended.data() + x_extended.size() - x_old_.size(),
106 x_extended.data() + x_extended.size(), x_old_.data());
107}
108
109SignalClassifier::SignalClassifier(ApmDataDumper* data_dumper)
110 : data_dumper_(data_dumper),
111 down_sampler_(data_dumper_),
112 noise_spectrum_estimator_(data_dumper_) {
113 Initialize(48000);
114}
115SignalClassifier::~SignalClassifier() {}
116
117void SignalClassifier::Initialize(int sample_rate_hz) {
118 down_sampler_.Initialize(sample_rate_hz);
119 noise_spectrum_estimator_.Initialize();
120 frame_extender_.reset(new FrameExtender(80, 128));
121 sample_rate_hz_ = sample_rate_hz;
122 initialization_frames_left_ = 2;
123 consistent_classification_counter_ = 3;
124 last_signal_type_ = SignalClassifier::SignalType::kNonStationary;
125}
126
127SignalClassifier::SignalType SignalClassifier::Analyze(
128 rtc::ArrayView<const float> signal) {
129 RTC_DCHECK_EQ(signal.size(), sample_rate_hz_ / 100);
130
131 // Compute the signal power spectrum.
132 float downsampled_frame[80];
133 down_sampler_.DownSample(signal, downsampled_frame);
134 float extended_frame[128];
135 frame_extender_->ExtendFrame(downsampled_frame, extended_frame);
136 RemoveDcLevel(extended_frame);
137 float signal_spectrum[65];
138 PowerSpectrum(&ooura_fft_, extended_frame, signal_spectrum);
139
140 // Classify the signal based on the estimate of the noise spectrum and the
141 // signal spectrum estimate.
142 const SignalType signal_type = ClassifySignal(
143 signal_spectrum, noise_spectrum_estimator_.GetNoiseSpectrum(),
144 data_dumper_);
145
146 // Update the noise spectrum based on the signal spectrum.
147 noise_spectrum_estimator_.Update(signal_spectrum,
148 initialization_frames_left_ > 0);
149
150 // Update the number of frames until a reliable signal spectrum is achieved.
151 initialization_frames_left_ = std::max(0, initialization_frames_left_ - 1);
152
153 if (last_signal_type_ == signal_type) {
154 consistent_classification_counter_ =
155 std::max(0, consistent_classification_counter_ - 1);
156 } else {
157 last_signal_type_ = signal_type;
158 consistent_classification_counter_ = 3;
159 }
160
161 if (consistent_classification_counter_ > 0) {
162 return SignalClassifier::SignalType::kNonStationary;
163 }
164 return signal_type;
165}
166
167} // namespace webrtc