AEC3: Change adaptation speed of the matched filter after a delay is found

This change enables the use of two different adaptation speeds of the
matched filter of the delay estimator of AEC3.

One speed is used when no delay has been found, and one is used after a
reliable delay has been found. The purpose is to use a slower adaptation
speed to reduce the risk of divergence during double-talk without
slowing down the search for the initial delay.

The CL prepares for experimentation by adding field trials for
controlling the two adaptation speeds.

Bug: webrtc:12775
Change-Id: I817a1ab5ded0f78d20de45edcf04c708290173fc
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/219083
Commit-Queue: Gustaf Ullberg <gustaf@webrtc.org>
Reviewed-by: Per Ã…hgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#34055}
diff --git a/api/audio/echo_canceller3_config.h b/api/audio/echo_canceller3_config.h
index 8ffc3d9..d4a04cd 100644
--- a/api/audio/echo_canceller3_config.h
+++ b/api/audio/echo_canceller3_config.h
@@ -43,6 +43,7 @@
     size_t hysteresis_limit_blocks = 1;
     size_t fixed_capture_delay_samples = 0;
     float delay_estimate_smoothing = 0.7f;
+    float delay_estimate_smoothing_delay_found = 0.7f;
     float delay_candidate_detection_threshold = 0.2f;
     struct DelaySelectionThresholds {
       int initial;
diff --git a/api/audio/echo_canceller3_config_json.cc b/api/audio/echo_canceller3_config_json.cc
index 89256b3..39713a1 100644
--- a/api/audio/echo_canceller3_config_json.cc
+++ b/api/audio/echo_canceller3_config_json.cc
@@ -191,6 +191,8 @@
               &cfg.delay.fixed_capture_delay_samples);
     ReadParam(section, "delay_estimate_smoothing",
               &cfg.delay.delay_estimate_smoothing);
+    ReadParam(section, "delay_estimate_smoothing_delay_found",
+              &cfg.delay.delay_estimate_smoothing_delay_found);
     ReadParam(section, "delay_candidate_detection_threshold",
               &cfg.delay.delay_candidate_detection_threshold);
 
@@ -425,6 +427,8 @@
       << config.delay.fixed_capture_delay_samples << ",";
   ost << "\"delay_estimate_smoothing\": "
       << config.delay.delay_estimate_smoothing << ",";
+  ost << "\"delay_estimate_smoothing_delay_found\": "
+      << config.delay.delay_estimate_smoothing_delay_found << ",";
   ost << "\"delay_candidate_detection_threshold\": "
       << config.delay.delay_candidate_detection_threshold << ",";
 
diff --git a/modules/audio_processing/aec3/echo_canceller3.cc b/modules/audio_processing/aec3/echo_canceller3.cc
index 35a2cff..4e4632d 100644
--- a/modules/audio_processing/aec3/echo_canceller3.cc
+++ b/modules/audio_processing/aec3/echo_canceller3.cc
@@ -572,6 +572,12 @@
   RetrieveFieldTrialValue("WebRTC-Aec3SuppressorEpStrengthDefaultLenOverride",
                           -1.f, 1.f, &adjusted_cfg.ep_strength.default_len);
 
+  // Field trial-based overrides of individual delay estimator parameters.
+  RetrieveFieldTrialValue("WebRTC-Aec3DelayEstimateSmoothingOverride", 0.f, 1.f,
+                          &adjusted_cfg.delay.delay_estimate_smoothing);
+  RetrieveFieldTrialValue(
+      "WebRTC-Aec3DelayEstimateSmoothingDelayFoundOverride", 0.f, 1.f,
+      &adjusted_cfg.delay.delay_estimate_smoothing_delay_found);
   return adjusted_cfg;
 }
 
diff --git a/modules/audio_processing/aec3/echo_path_delay_estimator.cc b/modules/audio_processing/aec3/echo_path_delay_estimator.cc
index 2c987f9..8a78834 100644
--- a/modules/audio_processing/aec3/echo_path_delay_estimator.cc
+++ b/modules/audio_processing/aec3/echo_path_delay_estimator.cc
@@ -42,6 +42,7 @@
               ? config.render_levels.poor_excitation_render_limit_ds8
               : config.render_levels.poor_excitation_render_limit,
           config.delay.delay_estimate_smoothing,
+          config.delay.delay_estimate_smoothing_delay_found,
           config.delay.delay_candidate_detection_threshold),
       matched_filter_lag_aggregator_(data_dumper_,
                                      matched_filter_.GetMaxFilterLag(),
@@ -71,7 +72,8 @@
   data_dumper_->DumpWav("aec3_capture_decimator_output",
                         downsampled_capture.size(), downsampled_capture.data(),
                         16000 / down_sampling_factor_, 1);
-  matched_filter_.Update(render_buffer, downsampled_capture);
+  matched_filter_.Update(render_buffer, downsampled_capture,
+                         matched_filter_lag_aggregator_.ReliableDelayFound());
 
   absl::optional<DelayEstimate> aggregated_matched_filter_lag =
       matched_filter_lag_aggregator_.Aggregate(
diff --git a/modules/audio_processing/aec3/matched_filter.cc b/modules/audio_processing/aec3/matched_filter.cc
index 64b2d4e..1721e9c 100644
--- a/modules/audio_processing/aec3/matched_filter.cc
+++ b/modules/audio_processing/aec3/matched_filter.cc
@@ -307,7 +307,8 @@
                              int num_matched_filters,
                              size_t alignment_shift_sub_blocks,
                              float excitation_limit,
-                             float smoothing,
+                             float smoothing_fast,
+                             float smoothing_slow,
                              float matching_filter_threshold)
     : data_dumper_(data_dumper),
       optimization_(optimization),
@@ -319,7 +320,8 @@
       lag_estimates_(num_matched_filters),
       filters_offsets_(num_matched_filters, 0),
       excitation_limit_(excitation_limit),
-      smoothing_(smoothing),
+      smoothing_fast_(smoothing_fast),
+      smoothing_slow_(smoothing_slow),
       matching_filter_threshold_(matching_filter_threshold) {
   RTC_DCHECK(data_dumper);
   RTC_DCHECK_LT(0, window_size_sub_blocks);
@@ -340,10 +342,14 @@
 }
 
 void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
-                           rtc::ArrayView<const float> capture) {
+                           rtc::ArrayView<const float> capture,
+                           bool use_slow_smoothing) {
   RTC_DCHECK_EQ(sub_block_size_, capture.size());
   auto& y = capture;
 
+  const float smoothing =
+      use_slow_smoothing ? smoothing_slow_ : smoothing_fast_;
+
   const float x2_sum_threshold =
       filters_[0].size() * excitation_limit_ * excitation_limit_;
 
@@ -360,25 +366,25 @@
     switch (optimization_) {
 #if defined(WEBRTC_ARCH_X86_FAMILY)
       case Aec3Optimization::kSse2:
-        aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold,
-                                     smoothing_, render_buffer.buffer, y,
-                                     filters_[n], &filters_updated, &error_sum);
+        aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold, smoothing,
+                                     render_buffer.buffer, y, filters_[n],
+                                     &filters_updated, &error_sum);
         break;
       case Aec3Optimization::kAvx2:
-        aec3::MatchedFilterCore_AVX2(x_start_index, x2_sum_threshold,
-                                     smoothing_, render_buffer.buffer, y,
-                                     filters_[n], &filters_updated, &error_sum);
+        aec3::MatchedFilterCore_AVX2(x_start_index, x2_sum_threshold, smoothing,
+                                     render_buffer.buffer, y, filters_[n],
+                                     &filters_updated, &error_sum);
         break;
 #endif
 #if defined(WEBRTC_HAS_NEON)
       case Aec3Optimization::kNeon:
-        aec3::MatchedFilterCore_NEON(x_start_index, x2_sum_threshold,
-                                     smoothing_, render_buffer.buffer, y,
-                                     filters_[n], &filters_updated, &error_sum);
+        aec3::MatchedFilterCore_NEON(x_start_index, x2_sum_threshold, smoothing,
+                                     render_buffer.buffer, y, filters_[n],
+                                     &filters_updated, &error_sum);
         break;
 #endif
       default:
-        aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing_,
+        aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing,
                                 render_buffer.buffer, y, filters_[n],
                                 &filters_updated, &error_sum);
     }
diff --git a/modules/audio_processing/aec3/matched_filter.h b/modules/audio_processing/aec3/matched_filter.h
index fa44eb2..c6410ab 100644
--- a/modules/audio_processing/aec3/matched_filter.h
+++ b/modules/audio_processing/aec3/matched_filter.h
@@ -100,7 +100,8 @@
                 int num_matched_filters,
                 size_t alignment_shift_sub_blocks,
                 float excitation_limit,
-                float smoothing,
+                float smoothing_fast,
+                float smoothing_slow,
                 float matching_filter_threshold);
 
   MatchedFilter() = delete;
@@ -111,7 +112,8 @@
 
   // Updates the correlation with the values in the capture buffer.
   void Update(const DownsampledRenderBuffer& render_buffer,
-              rtc::ArrayView<const float> capture);
+              rtc::ArrayView<const float> capture,
+              bool use_slow_smoothing);
 
   // Resets the matched filter.
   void Reset();
@@ -140,7 +142,8 @@
   std::vector<LagEstimate> lag_estimates_;
   std::vector<size_t> filters_offsets_;
   const float excitation_limit_;
-  const float smoothing_;
+  const float smoothing_fast_;
+  const float smoothing_slow_;
   const float matching_filter_threshold_;
 };
 
diff --git a/modules/audio_processing/aec3/matched_filter_lag_aggregator.h b/modules/audio_processing/aec3/matched_filter_lag_aggregator.h
index d48011e..612bd5d 100644
--- a/modules/audio_processing/aec3/matched_filter_lag_aggregator.h
+++ b/modules/audio_processing/aec3/matched_filter_lag_aggregator.h
@@ -45,6 +45,9 @@
   absl::optional<DelayEstimate> Aggregate(
       rtc::ArrayView<const MatchedFilter::LagEstimate> lag_estimates);
 
+  // Returns whether a reliable delay estimate has been found.
+  bool ReliableDelayFound() const { return significant_candidate_found_; }
+
  private:
   ApmDataDumper* const data_dumper_;
   std::vector<int> histogram_;
diff --git a/modules/audio_processing/aec3/matched_filter_unittest.cc b/modules/audio_processing/aec3/matched_filter_unittest.cc
index 137275f..37b51fa 100644
--- a/modules/audio_processing/aec3/matched_filter_unittest.cc
+++ b/modules/audio_processing/aec3/matched_filter_unittest.cc
@@ -206,6 +206,7 @@
                            kWindowSizeSubBlocks, kNumMatchedFilters,
                            kAlignmentShiftSubBlocks, 150,
                            config.delay.delay_estimate_smoothing,
+                           config.delay.delay_estimate_smoothing_delay_found,
                            config.delay.delay_candidate_detection_threshold);
 
       std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
@@ -231,7 +232,7 @@
             downsampled_capture_data.data(), sub_block_size);
         capture_decimator.Decimate(capture[0], downsampled_capture);
         filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(),
-                      downsampled_capture);
+                      downsampled_capture, false);
       }
 
       // Obtain the lag estimates.
@@ -318,6 +319,7 @@
                          kWindowSizeSubBlocks, kNumMatchedFilters,
                          kAlignmentShiftSubBlocks, 150,
                          config.delay.delay_estimate_smoothing,
+                         config.delay.delay_estimate_smoothing_delay_found,
                          config.delay.delay_candidate_detection_threshold);
 
     // Analyze the correlation between render and capture.
@@ -325,7 +327,8 @@
       RandomizeSampleVector(&random_generator, render[0][0]);
       RandomizeSampleVector(&random_generator, capture);
       render_delay_buffer->Insert(render);
-      filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(), capture);
+      filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(), capture,
+                    false);
     }
 
     // Obtain the lag estimates.
@@ -361,6 +364,7 @@
                          kWindowSizeSubBlocks, kNumMatchedFilters,
                          kAlignmentShiftSubBlocks, 150,
                          config.delay.delay_estimate_smoothing,
+                         config.delay.delay_estimate_smoothing_delay_found,
                          config.delay.delay_candidate_detection_threshold);
     std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
         RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz,
@@ -379,7 +383,7 @@
                                                 sub_block_size);
       capture_decimator.Decimate(capture[0], downsampled_capture);
       filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(),
-                    downsampled_capture);
+                    downsampled_capture, false);
     }
 
     // Obtain the lag estimates.
@@ -407,6 +411,7 @@
       MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
                            32, num_matched_filters, 1, 150,
                            config.delay.delay_estimate_smoothing,
+                           config.delay.delay_estimate_smoothing_delay_found,
                            config.delay.delay_candidate_detection_threshold);
       EXPECT_EQ(num_matched_filters, filter.GetLagEstimates().size());
     }
@@ -421,6 +426,7 @@
   EchoCanceller3Config config;
   EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 16, 0, 1, 1,
                              150, config.delay.delay_estimate_smoothing,
+                             config.delay.delay_estimate_smoothing_delay_found,
                              config.delay.delay_candidate_detection_threshold),
                "");
 }
@@ -430,6 +436,7 @@
   EchoCanceller3Config config;
   EXPECT_DEATH(MatchedFilter(nullptr, DetectOptimization(), 16, 1, 1, 1, 150,
                              config.delay.delay_estimate_smoothing,
+                             config.delay.delay_estimate_smoothing_delay_found,
                              config.delay.delay_candidate_detection_threshold),
                "");
 }
@@ -441,6 +448,7 @@
   EchoCanceller3Config config;
   EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 15, 1, 1, 1,
                              150, config.delay.delay_estimate_smoothing,
+                             config.delay.delay_estimate_smoothing_delay_found,
                              config.delay.delay_candidate_detection_threshold),
                "");
 }
@@ -453,6 +461,7 @@
   EchoCanceller3Config config;
   EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 12, 1, 1, 1,
                              150, config.delay.delay_estimate_smoothing,
+                             config.delay.delay_estimate_smoothing_delay_found,
                              config.delay.delay_candidate_detection_threshold),
                "");
 }