Redesign of the render buffering in AEC3

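This CL centralizes the render buffering in AEC3 so that all render
buffers are updated together and kept synchronized with the render
alignment buffer. The separate API-call jitter buffer is removed: render
blocks, their FFTs and spectra, and the downsampled render signal are
written directly into shared circular buffers, and Insert() and
PrepareCaptureCall() report render overrun, render underrun and API-call
skew as BufferingEvents.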

Bug: webrtc:8597, chromium:790905
Change-Id: I8a94e5c1f27316b6100b420eec9652ea31c1a91d
Reviewed-on: https://webrtc-review.googlesource.com/25680
Commit-Queue: Per Åhgren <peah@webrtc.org>
Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#20989}
diff --git a/modules/audio_processing/aec3/render_delay_buffer.cc b/modules/audio_processing/aec3/render_delay_buffer.cc
index d2ead63..9640131 100644
--- a/modules/audio_processing/aec3/render_delay_buffer.cc
+++ b/modules/audio_processing/aec3/render_delay_buffer.cc
@@ -14,9 +14,12 @@
 #include <algorithm>
 
 #include "modules/audio_processing/aec3/aec3_common.h"
+#include "modules/audio_processing/aec3/aec3_fft.h"
 #include "modules/audio_processing/aec3/block_processor.h"
 #include "modules/audio_processing/aec3/decimator.h"
+#include "modules/audio_processing/aec3/fft_buffer.h"
 #include "modules/audio_processing/aec3/fft_data.h"
+#include "modules/audio_processing/aec3/matrix_buffer.h"
 #include "rtc_base/atomicops.h"
 #include "rtc_base/checks.h"
 #include "rtc_base/constructormagic.h"
@@ -25,191 +28,161 @@
 namespace webrtc {
 namespace {
 
-class ApiCallJitterBuffer {
- public:
-  explicit ApiCallJitterBuffer(size_t num_bands) {
-    buffer_.fill(std::vector<std::vector<float>>(
-        num_bands, std::vector<float>(kBlockSize, 0.f)));
-  }
-
-  ~ApiCallJitterBuffer() = default;
-
-  void Reset() {
-    size_ = 0;
-    last_insert_index_ = 0;
-  }
-
-  void Insert(const std::vector<std::vector<float>>& block) {
-    RTC_DCHECK_LT(size_, buffer_.size());
-    last_insert_index_ = (last_insert_index_ + 1) % buffer_.size();
-    RTC_DCHECK_EQ(buffer_[last_insert_index_].size(), block.size());
-    RTC_DCHECK_EQ(buffer_[last_insert_index_][0].size(), block[0].size());
-    for (size_t k = 0; k < block.size(); ++k) {
-      std::copy(block[k].begin(), block[k].end(),
-                buffer_[last_insert_index_][k].begin());
-    }
-    ++size_;
-  }
-
-  void Remove(std::vector<std::vector<float>>* block) {
-    RTC_DCHECK_LT(0, size_);
-    --size_;
-    const size_t extract_index =
-        (last_insert_index_ - size_ + buffer_.size()) % buffer_.size();
-    for (size_t k = 0; k < block->size(); ++k) {
-      std::copy(buffer_[extract_index][k].begin(),
-                buffer_[extract_index][k].end(), (*block)[k].begin());
-    }
-  }
-
-  size_t Size() const { return size_; }
-  bool Full() const { return size_ >= (buffer_.size()); }
-  bool Empty() const { return size_ == 0; }
-
- private:
-  std::array<std::vector<std::vector<float>>, kMaxApiCallsJitterBlocks> buffer_;
-  size_t size_ = 0;
-  int last_insert_index_ = 0;
-};
+constexpr int kBufferHeadroom = kAdaptiveFilterLength;
 
 class RenderDelayBufferImpl final : public RenderDelayBuffer {
  public:
-  RenderDelayBufferImpl(size_t num_bands,
-                        size_t down_sampling_factor,
-                        size_t downsampled_render_buffer_size,
-                        size_t render_delay_buffer_size);
+  RenderDelayBufferImpl(const EchoCanceller3Config& config, size_t num_bands);
   ~RenderDelayBufferImpl() override;

   void Reset() override;
-  bool Insert(const std::vector<std::vector<float>>& block) override;
-  bool UpdateBuffers() override;
+  BufferingEvent Insert(const std::vector<std::vector<float>>& block) override;
+  BufferingEvent PrepareCaptureCall() override;
   void SetDelay(size_t delay) override;
   size_t Delay() const override { return delay_; }
-
-  const RenderBuffer& GetRenderBuffer() const override { return fft_buffer_; }
+  size_t MaxDelay() const override {
+    return blocks_.buffer.size() - 1 - kBufferHeadroom;
+  }
+  size_t MaxApiJitter() const override { return max_api_jitter_; }
+  const RenderBuffer& GetRenderBuffer() const override {
+    return echo_remover_buffer_;
+  }

   const DownsampledRenderBuffer& GetDownsampledRenderBuffer() const override {
-    return downsampled_render_buffer_;
+    return low_rate_;
   }

  private:
   static int instance_count_;
   std::unique_ptr<ApmDataDumper> data_dumper_;
   const Aec3Optimization optimization_;
-  const size_t down_sampling_factor_;
-  const size_t sub_block_size_;
-  std::vector<std::vector<std::vector<float>>> buffer_;
-  size_t delay_ = 0;
-  size_t last_insert_index_ = 0;
-  RenderBuffer fft_buffer_;
-  DownsampledRenderBuffer downsampled_render_buffer_;
+  const size_t api_call_jitter_blocks_;
+  const size_t min_echo_path_delay_blocks_;
+  const int sub_block_size_;  // Decimated samples per render block.
+  MatrixBuffer blocks_;       // Time-domain render blocks, all bands.
+  VectorBuffer spectra_;      // One spectrum per render block.
+  FftBuffer ffts_;            // One FFT per render block.
+  size_t delay_;              // Current render delay, in blocks.
+  int max_api_jitter_ = 0;  // Longest Insert() run between capture calls.
+  int render_surplus_ = 0;  // Insert() calls minus capture calls.
+  bool first_reset_occurred_ = false;  // Gates the jitter tracking.
+  RenderBuffer echo_remover_buffer_;
+  DownsampledRenderBuffer low_rate_;  // Decimated band-0 render signal.
   Decimator render_decimator_;
-  ApiCallJitterBuffer api_call_jitter_buffer_;
   const std::vector<std::vector<float>> zero_block_;
+  const Aec3Fft fft_;
+  size_t capture_call_counter_ = 0;
+  std::vector<float> render_ds_;  // Scratch output of the decimator.
+  int render_calls_in_a_row_ = 0;  // Cleared by each capture call.
+
+  void UpdateBuffersWithLatestBlock(size_t previous_write);
+  void IncreaseRead();
+  void IncreaseInsert();
+
   RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(RenderDelayBufferImpl);
 };
 
 int RenderDelayBufferImpl::instance_count_ = 0;
 
-RenderDelayBufferImpl::RenderDelayBufferImpl(
-    size_t num_bands,
-    size_t down_sampling_factor,
-    size_t downsampled_render_buffer_size,
-    size_t render_delay_buffer_size)
+RenderDelayBufferImpl::RenderDelayBufferImpl(const EchoCanceller3Config& config,
+                                             size_t num_bands)
     : data_dumper_(
           new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))),
       optimization_(DetectOptimization()),
-      down_sampling_factor_(down_sampling_factor),
-      sub_block_size_(down_sampling_factor_ > 0
-                          ? kBlockSize / down_sampling_factor
-                          : kBlockSize),
-      buffer_(
-          render_delay_buffer_size,
-          std::vector<std::vector<float>>(num_bands,
-                                          std::vector<float>(kBlockSize, 0.f))),
-      fft_buffer_(
-          optimization_,
-          num_bands,
-          std::max(kUnknownDelayRenderWindowSize, kAdaptiveFilterLength),
-          std::vector<size_t>(1, kAdaptiveFilterLength)),
-      downsampled_render_buffer_(downsampled_render_buffer_size),
-      render_decimator_(down_sampling_factor_),
-      api_call_jitter_buffer_(num_bands),
-      zero_block_(num_bands, std::vector<float>(kBlockSize, 0.f)) {
-  RTC_DCHECK_LT(buffer_.size(), downsampled_render_buffer_.buffer.size());
+      api_call_jitter_blocks_(config.delay.api_call_jitter_blocks),
+      min_echo_path_delay_blocks_(config.delay.min_echo_path_delay_blocks),
+      sub_block_size_(
+          static_cast<int>(config.delay.down_sampling_factor > 0
+                               ? kBlockSize / config.delay.down_sampling_factor
+                               : kBlockSize)),
+      blocks_(GetRenderDelayBufferSize(config.delay.down_sampling_factor,
+                                       config.delay.num_filters),
+              num_bands,
+              kBlockSize),
+      spectra_(blocks_.buffer.size(), kFftLengthBy2Plus1),
+      ffts_(blocks_.buffer.size()),
+      delay_(min_echo_path_delay_blocks_),
+      echo_remover_buffer_(kAdaptiveFilterLength, &blocks_, &spectra_, &ffts_),
+      low_rate_(GetDownSampledBufferSize(config.delay.down_sampling_factor,
+                                         config.delay.num_filters)),
+      render_decimator_(config.delay.down_sampling_factor),
+      zero_block_(num_bands, std::vector<float>(kBlockSize, 0.f)),
+      fft_(),
+      render_ds_(sub_block_size_, 0.f) {
+  RTC_DCHECK_EQ(blocks_.buffer.size(), ffts_.buffer.size());
+  RTC_DCHECK_EQ(spectra_.buffer.size(), ffts_.buffer.size());
+  Reset();  // Initializes delay_ and all read indices.
+  first_reset_occurred_ = false;  // The internal Reset() does not count.
 }
 
 RenderDelayBufferImpl::~RenderDelayBufferImpl() = default;
 
 void RenderDelayBufferImpl::Reset() {
-  // Empty all data in the buffers.
-  delay_ = 0;
-  last_insert_index_ = 0;
-  downsampled_render_buffer_.position = 0;
-  std::fill(downsampled_render_buffer_.buffer.begin(),
-            downsampled_render_buffer_.buffer.end(), 0.f);
-  fft_buffer_.Clear();
-  api_call_jitter_buffer_.Reset();
-  for (auto& c : buffer_) {
-    for (auto& b : c) {
-      std::fill(b.begin(), b.end(), 0.f);
-    }
-  }
+  delay_ = min_echo_path_delay_blocks_;  // Restart at the minimum delay.
+  const int offset1 = std::max<int>(
+      std::min(api_call_jitter_blocks_, min_echo_path_delay_blocks_), 1);
+  const int offset2 = static_cast<int>(delay_ + offset1);  // In blocks.
+  const int offset3 = offset1 * sub_block_size_;  // In decimated samples.
+  low_rate_.read = low_rate_.OffsetIndex(low_rate_.write, offset3);
+  blocks_.read = blocks_.OffsetIndex(blocks_.write, -offset2);
+  spectra_.read = spectra_.OffsetIndex(spectra_.write, offset2);
+  ffts_.read = ffts_.OffsetIndex(ffts_.write, offset2);
+  render_surplus_ = 0;
+  first_reset_occurred_ = true;
 }
 
-bool RenderDelayBufferImpl::Insert(
+RenderDelayBuffer::BufferingEvent RenderDelayBufferImpl::Insert(
     const std::vector<std::vector<float>>& block) {
-  RTC_DCHECK_EQ(block.size(), buffer_[0].size());
-  RTC_DCHECK_EQ(block[0].size(), buffer_[0][0].size());
+  RTC_DCHECK_EQ(block.size(), blocks_.buffer[0].size());
+  RTC_DCHECK_EQ(block[0].size(), blocks_.buffer[0][0].size());
+  BufferingEvent event = BufferingEvent::kNone;

-  if (api_call_jitter_buffer_.Full()) {
-    // Report buffer overrun and let the caller handle the overrun.
-    return false;
+  ++render_surplus_;  // Decremented again in PrepareCaptureCall().
+  if (first_reset_occurred_) {
+    ++render_calls_in_a_row_;
+    max_api_jitter_ = std::max(max_api_jitter_, render_calls_in_a_row_);
   }
-  api_call_jitter_buffer_.Insert(block);

-  return true;
+  const size_t previous_write = blocks_.write;  // Used by the padded FFT.
+  IncreaseInsert();
+
+  if (low_rate_.read == low_rate_.write || blocks_.read == blocks_.write) {
+    // Render overrun due to more render data being inserted than read. Discard
+    // the oldest render data.
+    event = BufferingEvent::kRenderOverrun;
+    IncreaseRead();
+  }
+
+  for (size_t k = 0; k < block.size(); ++k) {
+    std::copy(block[k].begin(), block[k].end(),
+              blocks_.buffer[blocks_.write][k].begin());
+  }
+
+  UpdateBuffersWithLatestBlock(previous_write);
+  return event;
 }
 
-bool RenderDelayBufferImpl::UpdateBuffers() {
-  bool underrun = true;
-  // Update the buffers with a new block if such is available, otherwise insert
-  // a block of silence.
-  if (api_call_jitter_buffer_.Size() > 0) {
-    last_insert_index_ = (last_insert_index_ + 1) % buffer_.size();
-    api_call_jitter_buffer_.Remove(&buffer_[last_insert_index_]);
-    underrun = false;
-  }
+RenderDelayBuffer::BufferingEvent RenderDelayBufferImpl::PrepareCaptureCall() {
+  BufferingEvent event = BufferingEvent::kNone;
+  render_calls_in_a_row_ = 0;  // A capture call ends the render run.

-  downsampled_render_buffer_.position =
-      (downsampled_render_buffer_.position - sub_block_size_ +
-       downsampled_render_buffer_.buffer.size()) %
-      downsampled_render_buffer_.buffer.size();
-
-  rtc::ArrayView<const float> input(
-      underrun ? zero_block_[0].data() : buffer_[last_insert_index_][0].data(),
-      kBlockSize);
-  rtc::ArrayView<float> output(downsampled_render_buffer_.buffer.data() +
-                                   downsampled_render_buffer_.position,
-                               sub_block_size_);
-  data_dumper_->DumpWav("aec3_render_decimator_input", input.size(),
-                        input.data(), 16000, 1);
-  render_decimator_.Decimate(input, output);
-  data_dumper_->DumpWav("aec3_render_decimator_output", output.size(),
-                        output.data(), 16000 / down_sampling_factor_, 1);
-  for (size_t k = 0; k < output.size() / 2; ++k) {
-    float tmp = output[k];
-    output[k] = output[output.size() - 1 - k];
-    output[output.size() - 1 - k] = tmp;
-  }
-
-  if (underrun) {
-    fft_buffer_.Insert(zero_block_);
+  if (low_rate_.read == low_rate_.write || blocks_.read == blocks_.write) {
+    event = BufferingEvent::kRenderUnderrun;  // Read caught up with write.
   } else {
-    fft_buffer_.Insert(buffer_[(last_insert_index_ - delay_ + buffer_.size()) %
-                               buffer_.size()]);
+    IncreaseRead();
   }
-  return !underrun;
+  --render_surplus_;  // Applied even on underrun.
+
+  echo_remover_buffer_.UpdateSpectralSum();
+
+  if (render_surplus_ >= static_cast<int>(api_call_jitter_blocks_)) {
+    event = BufferingEvent::kApiCallSkew;  // Overrides any underrun.
+    RTC_LOG(LS_WARNING) << "Api call skew detected at " << capture_call_counter_
+                        << ".";
+  }
+
+  ++capture_call_counter_;
+  return event;
 }
 
 void RenderDelayBufferImpl::SetDelay(size_t delay) {
@@ -217,37 +190,51 @@
     return;
   }

-  // If there is a new delay set, clear the fft buffer.
-  fft_buffer_.Clear();
-
-  if ((buffer_.size() - 1) < delay) {
-    // If the desired delay is larger than the delay buffer, shorten the delay
-    // buffer size to achieve the desired alignment with the available buffer
-    // size.
-    downsampled_render_buffer_.position =
-        (downsampled_render_buffer_.position +
-         sub_block_size_ * (delay - (buffer_.size() - 1))) %
-        downsampled_render_buffer_.buffer.size();
-
-    last_insert_index_ =
-        (last_insert_index_ - (delay - (buffer_.size() - 1)) + buffer_.size()) %
-        buffer_.size();
-    delay_ = buffer_.size() - 1;
-  } else {
-    delay_ = delay;
+  const size_t clamped = std::min(MaxDelay(), delay);
+  const int delta_delay = static_cast<int>(delay_) - static_cast<int>(clamped);
+  delay_ = clamped;
+  if (delay > MaxDelay()) {
+    RTC_NOTREACHED();  // Clamped above so delay_ and read indices agree.
   }
+
+  // Recompute the read indices according to the set delay.
+  blocks_.UpdateReadIndex(delta_delay);
+  spectra_.UpdateReadIndex(-delta_delay);
+  ffts_.UpdateReadIndex(-delta_delay);
 }
 
+void RenderDelayBufferImpl::UpdateBuffersWithLatestBlock(
+    size_t previous_write) {
+  render_decimator_.Decimate(blocks_.buffer[blocks_.write][0], render_ds_);
+  std::copy(render_ds_.rbegin(), render_ds_.rend(),  // Stored time-reversed.
+            low_rate_.buffer.begin() + low_rate_.write);
+  // PaddedFft() over the newest and the previously written band-0 blocks.
+  fft_.PaddedFft(blocks_.buffer[blocks_.write][0],
+                 blocks_.buffer[previous_write][0], &ffts_.buffer[ffts_.write]);
+
+  ffts_.buffer[ffts_.write].Spectrum(optimization_,
+                                     spectra_.buffer[spectra_.write]);
+}
+
+void RenderDelayBufferImpl::IncreaseRead() {
+  low_rate_.UpdateReadIndex(-sub_block_size_);  // One block of samples.
+  blocks_.IncReadIndex();
+  spectra_.DecReadIndex();  // spectra_ and ffts_ step opposite to blocks_.
+  ffts_.DecReadIndex();
+}
+
+void RenderDelayBufferImpl::IncreaseInsert() {
+  low_rate_.UpdateWriteIndex(-sub_block_size_);  // One block of samples.
+  blocks_.IncWriteIndex();
+  spectra_.DecWriteIndex();  // spectra_ and ffts_ step opposite to blocks_.
+  ffts_.DecWriteIndex();
+}
+
 }  // namespace
 
-RenderDelayBuffer* RenderDelayBuffer::Create(
-    size_t num_bands,
-    size_t down_sampling_factor,
-    size_t downsampled_render_buffer_size,
-    size_t render_delay_buffer_size) {
-  return new RenderDelayBufferImpl(num_bands, down_sampling_factor,
-                                   downsampled_render_buffer_size,
-                                   render_delay_buffer_size);
+RenderDelayBuffer* RenderDelayBuffer::Create(const EchoCanceller3Config& config,
+                                             size_t num_bands) {
+  return new RenderDelayBufferImpl(config, num_bands);
 }
 
 }  // namespace webrtc