blob: d2d5917387a501e2bb91d36d4399778817494a1b [file] [log] [blame]
peahca4cac72016-06-29 15:26:12 -07001/*
2 * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
Mirko Bonadei92ea95e2017-09-15 06:47:31 +020011#include "modules/audio_processing/level_controller/signal_classifier.h"
peahca4cac72016-06-29 15:26:12 -070012
13#include <algorithm>
14#include <numeric>
15#include <vector>
16
Mirko Bonadei92ea95e2017-09-15 06:47:31 +020017#include "api/array_view.h"
18#include "modules/audio_processing/audio_buffer.h"
19#include "modules/audio_processing/level_controller/down_sampler.h"
20#include "modules/audio_processing/level_controller/noise_spectrum_estimator.h"
21#include "modules/audio_processing/logging/apm_data_dumper.h"
22#include "rtc_base/constructormagic.h"
peahca4cac72016-06-29 15:26:12 -070023
24namespace webrtc {
25namespace {
26
27void RemoveDcLevel(rtc::ArrayView<float> x) {
kwibergaf476c72016-11-28 15:21:39 -080028 RTC_DCHECK_LT(0, x.size());
peahca4cac72016-06-29 15:26:12 -070029 float mean = std::accumulate(x.data(), x.data() + x.size(), 0.f);
30 mean /= x.size();
31
32 for (float& v : x) {
33 v -= mean;
34 }
35}
36
peah81b92912016-10-06 06:46:20 -070037void PowerSpectrum(const OouraFft* ooura_fft,
38 rtc::ArrayView<const float> x,
peahca4cac72016-06-29 15:26:12 -070039 rtc::ArrayView<float> spectrum) {
kwibergaf476c72016-11-28 15:21:39 -080040 RTC_DCHECK_EQ(65, spectrum.size());
41 RTC_DCHECK_EQ(128, x.size());
peahca4cac72016-06-29 15:26:12 -070042 float X[128];
43 std::copy(x.data(), x.data() + x.size(), X);
peah81b92912016-10-06 06:46:20 -070044 ooura_fft->Fft(X);
peahca4cac72016-06-29 15:26:12 -070045
46 float* X_p = X;
47 RTC_DCHECK_EQ(X_p, &X[0]);
48 spectrum[0] = (*X_p) * (*X_p);
49 ++X_p;
50 RTC_DCHECK_EQ(X_p, &X[1]);
51 spectrum[64] = (*X_p) * (*X_p);
52 for (int k = 1; k < 64; ++k) {
53 ++X_p;
54 RTC_DCHECK_EQ(X_p, &X[2 * k]);
55 spectrum[k] = (*X_p) * (*X_p);
56 ++X_p;
57 RTC_DCHECK_EQ(X_p, &X[2 * k + 1]);
58 spectrum[k] += (*X_p) * (*X_p);
59 }
60}
61
62webrtc::SignalClassifier::SignalType ClassifySignal(
63 rtc::ArrayView<const float> signal_spectrum,
64 rtc::ArrayView<const float> noise_spectrum,
65 ApmDataDumper* data_dumper) {
66 int num_stationary_bands = 0;
67 int num_highly_nonstationary_bands = 0;
68
69 // Detect stationary and highly nonstationary bands.
70 for (size_t k = 1; k < 40; k++) {
71 if (signal_spectrum[k] < 3 * noise_spectrum[k] &&
72 signal_spectrum[k] * 3 > noise_spectrum[k]) {
73 ++num_stationary_bands;
74 } else if (signal_spectrum[k] > 9 * noise_spectrum[k]) {
75 ++num_highly_nonstationary_bands;
76 }
77 }
78
79 data_dumper->DumpRaw("lc_num_stationary_bands", 1, &num_stationary_bands);
80 data_dumper->DumpRaw("lc_num_highly_nonstationary_bands", 1,
81 &num_highly_nonstationary_bands);
82
83 // Use the detected number of bands to classify the overall signal
84 // stationarity.
85 if (num_stationary_bands > 15) {
86 return SignalClassifier::SignalType::kStationary;
87 } else if (num_highly_nonstationary_bands > 15) {
88 return SignalClassifier::SignalType::kHighlyNonStationary;
89 } else {
90 return SignalClassifier::SignalType::kNonStationary;
91 }
92}
93
94} // namespace
95
kwiberg83ffe452016-08-29 14:46:07 -070096SignalClassifier::FrameExtender::FrameExtender(size_t frame_size,
97 size_t extended_frame_size)
98 : x_old_(extended_frame_size - frame_size, 0.f) {}
99
100SignalClassifier::FrameExtender::~FrameExtender() = default;
101
peahca4cac72016-06-29 15:26:12 -0700102void SignalClassifier::FrameExtender::ExtendFrame(
103 rtc::ArrayView<const float> x,
104 rtc::ArrayView<float> x_extended) {
105 RTC_DCHECK_EQ(x_old_.size() + x.size(), x_extended.size());
106 std::copy(x_old_.data(), x_old_.data() + x_old_.size(), x_extended.data());
107 std::copy(x.data(), x.data() + x.size(), x_extended.data() + x_old_.size());
108 std::copy(x_extended.data() + x_extended.size() - x_old_.size(),
109 x_extended.data() + x_extended.size(), x_old_.data());
110}
111
112SignalClassifier::SignalClassifier(ApmDataDumper* data_dumper)
113 : data_dumper_(data_dumper),
114 down_sampler_(data_dumper_),
115 noise_spectrum_estimator_(data_dumper_) {
116 Initialize(AudioProcessing::kSampleRate48kHz);
117}
118SignalClassifier::~SignalClassifier() {}
119
120void SignalClassifier::Initialize(int sample_rate_hz) {
peahca4cac72016-06-29 15:26:12 -0700121 down_sampler_.Initialize(sample_rate_hz);
122 noise_spectrum_estimator_.Initialize();
123 frame_extender_.reset(new FrameExtender(80, 128));
124 sample_rate_hz_ = sample_rate_hz;
125 initialization_frames_left_ = 2;
126 consistent_classification_counter_ = 3;
127 last_signal_type_ = SignalClassifier::SignalType::kNonStationary;
128}
129
130void SignalClassifier::Analyze(const AudioBuffer& audio,
131 SignalType* signal_type) {
kwiberg352444f2016-11-28 15:58:53 -0800132 RTC_DCHECK_EQ(audio.num_frames(), sample_rate_hz_ / 100);
peahca4cac72016-06-29 15:26:12 -0700133
134 // Compute the signal power spectrum.
135 float downsampled_frame[80];
136 down_sampler_.DownSample(rtc::ArrayView<const float>(
137 audio.channels_const_f()[0], audio.num_frames()),
138 downsampled_frame);
139 float extended_frame[128];
140 frame_extender_->ExtendFrame(downsampled_frame, extended_frame);
141 RemoveDcLevel(extended_frame);
142 float signal_spectrum[65];
peah81b92912016-10-06 06:46:20 -0700143 PowerSpectrum(&ooura_fft_, extended_frame, signal_spectrum);
peahca4cac72016-06-29 15:26:12 -0700144
145 // Classify the signal based on the estimate of the noise spectrum and the
146 // signal spectrum estimate.
147 *signal_type = ClassifySignal(signal_spectrum,
148 noise_spectrum_estimator_.GetNoiseSpectrum(),
149 data_dumper_);
150
151 // Update the noise spectrum based on the signal spectrum.
152 noise_spectrum_estimator_.Update(signal_spectrum,
153 initialization_frames_left_ > 0);
154
155 // Update the number of frames until a reliable signal spectrum is achieved.
156 initialization_frames_left_ = std::max(0, initialization_frames_left_ - 1);
157
158 if (last_signal_type_ == *signal_type) {
159 consistent_classification_counter_ =
160 std::max(0, consistent_classification_counter_ - 1);
161 } else {
162 last_signal_type_ = *signal_type;
163 consistent_classification_counter_ = 3;
164 }
165
166 if (consistent_classification_counter_ > 0) {
167 *signal_type = SignalClassifier::SignalType::kNonStationary;
168 }
169}
170
171} // namespace webrtc