AEC3: Improved the accuracy of the adaptive filter

This CL adds functionality that jump-starts the
AEC3 shadow filter whenever it performs consistently
worse than the main filter.
The jump-start is done such that the shadow filter
is re-initialized using the main filter coefficients.

The effect of this is a significantly more accurate
main linear filter, which leads to less echo leakage
and better transparency.

Bug: webrtc:9565, chromium:867873
Change-Id: Ie0b23cd536adc7ce96fc3ed2a7db112aec7437f1
Reviewed-on: https://webrtc-review.googlesource.com/90413
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Commit-Queue: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#24117}
This commit is contained in:
Per Åhgren
2018-07-26 15:32:24 +02:00
committed by Commit Bot
parent 5e55881b94
commit e4db6a1518
13 changed files with 136 additions and 42 deletions

View File

@ -46,7 +46,7 @@ struct EchoCanceller3Config {
float noise_gate;
};
MainConfiguration main = {13, 0.0005f, 0.01f, 0.001f, 20075344.f};
MainConfiguration main = {13, 0.00005f, 0.01f, 0.1f, 20075344.f};
ShadowConfiguration shadow = {13, 0.7f, 20075344.f};
MainConfiguration main_initial = {12, 0.005f, 0.5f, 0.001f, 20075344.f};

View File

@ -96,6 +96,7 @@ rtc_static_library("aec3") {
"stationarity_estimator.h",
"subtractor.cc",
"subtractor.h",
"subtractor_output.cc",
"subtractor_output.h",
"subtractor_output_analyzer.cc",
"subtractor_output_analyzer.h",

View File

@ -632,4 +632,14 @@ void AdaptiveFirFilter::ScaleFilter(float factor) {
}
}
// Set the filter coefficients.
void AdaptiveFirFilter::SetFilter(const std::vector<FftData>& H) {
RTC_DCHECK_EQ(H_.size(), H.size());
const size_t num_partitions = std::min(H_.size(), H.size());
for (size_t k = 0; k < num_partitions; ++k) {
std::copy(H[k].re.begin(), H[k].re.end(), H_[k].re.begin());
std::copy(H[k].im.begin(), H[k].im.end(), H_[k].im.begin());
}
}
} // namespace webrtc

View File

@ -147,6 +147,12 @@ class AdaptiveFirFilter {
// Scale the filter impulse response and spectrum by a factor.
void ScaleFilter(float factor);
// Set the filter coefficients.
void SetFilter(const std::vector<FftData>& H);
// Gets the filter coefficients. Returns a read-only reference to the internal
// coefficient storage (H_); no copy is made.
const std::vector<FftData>& GetFilter() const { return H_; }
private:
// Constrain the filter partitions in a cyclic manner.
void Constrain();

View File

@ -123,7 +123,7 @@ void AecState::Update(
const SubtractorOutput& subtractor_output,
rtc::ArrayView<const float> y) {
// Analyze the filter output.
subtractor_output_analyzer_.Update(y, subtractor_output);
subtractor_output_analyzer_.Update(subtractor_output);
const bool converged_filter = subtractor_output_analyzer_.ConvergedFilter();
const bool diverged_filter = subtractor_output_analyzer_.DivergedFilter();

View File

@ -52,6 +52,7 @@ TEST(AecState, NormalUsage) {
GetTimeDomainLength(config.filter.main.length_blocks), 0.f);
// Verify that linear AEC usability is false when the filter is diverged.
output.UpdatePowers(y);
state.Update(delay_estimate, diverged_filter_frequency_response,
impulse_response, *render_delay_buffer->GetRenderBuffer(),
E2_main, Y2, output, y);
@ -61,6 +62,7 @@ TEST(AecState, NormalUsage) {
std::fill(x[0].begin(), x[0].end(), 101.f);
for (int k = 0; k < 3000; ++k) {
render_delay_buffer->Insert(x);
output.UpdatePowers(y);
state.Update(delay_estimate, converged_filter_frequency_response,
impulse_response, *render_delay_buffer->GetRenderBuffer(),
E2_main, Y2, output, y);
@ -69,6 +71,7 @@ TEST(AecState, NormalUsage) {
// Verify that linear AEC usability becomes false after an echo path change is
// reported
output.UpdatePowers(y);
state.HandleEchoPathChange(EchoPathVariability(
false, EchoPathVariability::DelayAdjustment::kBufferReadjustment, false));
state.Update(delay_estimate, converged_filter_frequency_response,
@ -79,6 +82,7 @@ TEST(AecState, NormalUsage) {
// Verify that the active render detection works as intended.
std::fill(x[0].begin(), x[0].end(), 101.f);
render_delay_buffer->Insert(x);
output.UpdatePowers(y);
state.HandleEchoPathChange(EchoPathVariability(
true, EchoPathVariability::DelayAdjustment::kNewDetectedDelay, false));
state.Update(delay_estimate, converged_filter_frequency_response,
@ -88,6 +92,7 @@ TEST(AecState, NormalUsage) {
for (int k = 0; k < 1000; ++k) {
render_delay_buffer->Insert(x);
output.UpdatePowers(y);
state.Update(delay_estimate, converged_filter_frequency_response,
impulse_response, *render_delay_buffer->GetRenderBuffer(),
E2_main, Y2, output, y);
@ -111,6 +116,7 @@ TEST(AecState, NormalUsage) {
Y2.fill(10.f * 10000.f * 10000.f);
for (size_t k = 0; k < 1000; ++k) {
output.UpdatePowers(y);
state.Update(delay_estimate, converged_filter_frequency_response,
impulse_response, *render_delay_buffer->GetRenderBuffer(),
E2_main, Y2, output, y);
@ -128,6 +134,7 @@ TEST(AecState, NormalUsage) {
E2_main.fill(1.f * 10000.f * 10000.f);
Y2.fill(10.f * E2_main[0]);
for (size_t k = 0; k < 1000; ++k) {
output.UpdatePowers(y);
state.Update(delay_estimate, converged_filter_frequency_response,
impulse_response, *render_delay_buffer->GetRenderBuffer(),
E2_main, Y2, output, y);
@ -149,6 +156,7 @@ TEST(AecState, NormalUsage) {
E2_main.fill(1.f * 10000.f * 10000.f);
Y2.fill(5.f * E2_main[0]);
for (size_t k = 0; k < 1000; ++k) {
output.UpdatePowers(y);
state.Update(delay_estimate, converged_filter_frequency_response,
impulse_response, *render_delay_buffer->GetRenderBuffer(),
E2_main, Y2, output, y);
@ -203,6 +211,7 @@ TEST(AecState, ConvergedFilterDelay) {
impulse_response[k * kBlockSize + 1] = 1.f;
state.HandleEchoPathChange(echo_path_variability);
output.UpdatePowers(y);
state.Update(delay_estimate, frequency_response, impulse_response,
*render_delay_buffer->GetRenderBuffer(), E2_main, Y2, output,
y);

View File

@ -50,6 +50,10 @@ bool EnableSlowFilterAdaptation() {
return !field_trial::IsEnabled("WebRTC-Aec3SlowFilterAdaptationKillSwitch");
}
// Returns true unless the field trial
// "WebRTC-Aec3ShadowFilterJumpstartKillSwitch" is enabled.
bool EnableShadowFilterJumpstart() {
  const bool kill_switch_active =
      field_trial::IsEnabled("WebRTC-Aec3ShadowFilterJumpstartKillSwitch");
  return !kill_switch_active;
}
// Method for adjusting config parameter dependencies.
EchoCanceller3Config AdjustConfig(const EchoCanceller3Config& config) {
EchoCanceller3Config adjusted_cfg = config;
@ -103,12 +107,25 @@ EchoCanceller3Config AdjustConfig(const EchoCanceller3Config& config) {
}
if (!EnableSlowFilterAdaptation()) {
adjusted_cfg.filter.main.leakage_converged = 0.005f;
adjusted_cfg.filter.main.leakage_diverged = 0.1f;
if (!EnableShadowFilterJumpstart()) {
adjusted_cfg.filter.main.leakage_converged = 0.005f;
adjusted_cfg.filter.main.leakage_diverged = 0.1f;
}
adjusted_cfg.filter.main_initial.leakage_converged = 0.05f;
adjusted_cfg.filter.main_initial.leakage_diverged = 5.f;
}
if (!EnableShadowFilterJumpstart()) {
if (EnableSlowFilterAdaptation()) {
adjusted_cfg.filter.main.leakage_converged = 0.0005f;
adjusted_cfg.filter.main.leakage_diverged = 0.01f;
} else {
adjusted_cfg.filter.main.leakage_converged = 0.005f;
adjusted_cfg.filter.main.leakage_diverged = 0.1f;
}
adjusted_cfg.filter.main.error_floor = 0.001f;
}
return adjusted_cfg;
}

View File

@ -36,6 +36,10 @@ bool EnableMisadjustmentEstimator() {
return !field_trial::IsEnabled("WebRTC-Aec3MisadjustmentEstimatorKillSwitch");
}
// Returns true unless the field trial
// "WebRTC-Aec3ShadowFilterJumpstartKillSwitch" is enabled.
bool EnableShadowFilterJumpstart() {
  const bool kill_switch_active =
      field_trial::IsEnabled("WebRTC-Aec3ShadowFilterJumpstartKillSwitch");
  return !kill_switch_active;
}
void PredictionError(const Aec3Fft& fft,
const FftData& S,
rtc::ArrayView<const float> y,
@ -95,6 +99,7 @@ Subtractor::Subtractor(const EchoCanceller3Config& config,
adaptation_during_saturation_(EnableAdaptationDuringSaturation()),
enable_misadjustment_estimator_(EnableMisadjustmentEstimator()),
enable_agc_gain_change_response_(EnableAgcGainChangeResponse()),
enable_shadow_filter_jumpstart_(EnableShadowFilterJumpstart()),
main_filter_(config_.filter.main.length_blocks,
config_.filter.main_initial.length_blocks,
config.filter.config_change_duration_blocks,
@ -180,10 +185,13 @@ void Subtractor::Process(const RenderBuffer& render_buffer,
PredictionError(fft_, S, y, &e_shadow, nullptr, adaptation_during_saturation_,
&shadow_saturation);
// Compute the signal powers in the subtractor output.
output->UpdatePowers(y);
// Adjust the filter if needed.
bool main_filter_adjusted = false;
if (enable_misadjustment_estimator_) {
filter_misadjustment_estimator_.Update(e_main, y);
filter_misadjustment_estimator_.Update(*output);
if (filter_misadjustment_estimator_.IsAdjustmentNeeded()) {
float scale = filter_misadjustment_estimator_.GetMisadjustment();
main_filter_.ScaleFilter(scale);
@ -216,13 +224,22 @@ void Subtractor::Process(const RenderBuffer& render_buffer,
data_dumper_->DumpRaw("aec3_subtractor_G_main", G.im);
// Update the shadow filter.
if (shadow_filter_.SizePartitions() != main_filter_.SizePartitions()) {
render_buffer.SpectralSum(shadow_filter_.SizePartitions(), &X2);
poor_shadow_filter_counter_ =
output->e2_main < output->e2_shadow ? poor_shadow_filter_counter_ + 1 : 0;
if (poor_shadow_filter_counter_ < 10 || !enable_shadow_filter_jumpstart_) {
if (shadow_filter_.SizePartitions() != main_filter_.SizePartitions()) {
render_buffer.SpectralSum(shadow_filter_.SizePartitions(), &X2);
}
G_shadow_.Compute(X2, render_signal_analyzer, E_shadow,
shadow_filter_.SizePartitions(),
aec_state.SaturatedCapture() || shadow_saturation, &G);
shadow_filter_.Adapt(render_buffer, G);
} else {
G.re.fill(0.f);
G.im.fill(0.f);
poor_shadow_filter_counter_ = 0;
shadow_filter_.SetFilter(main_filter_.GetFilter());
}
G_shadow_.Compute(X2, render_signal_analyzer, E_shadow,
shadow_filter_.SizePartitions(),
aec_state.SaturatedCapture() || shadow_saturation, &G);
shadow_filter_.Adapt(render_buffer, G);
data_dumper_->DumpRaw("aec3_subtractor_G_shadow", G.re);
data_dumper_->DumpRaw("aec3_subtractor_G_shadow", G.im);
@ -241,19 +258,15 @@ void Subtractor::Process(const RenderBuffer& render_buffer,
}
void Subtractor::FilterMisadjustmentEstimator::Update(
rtc::ArrayView<const float> e,
rtc::ArrayView<const float> y) {
const auto sum_of_squares = [](float a, float b) { return a + b * b; };
const float y2 = std::accumulate(y.begin(), y.end(), 0.f, sum_of_squares);
const float e2 = std::accumulate(e.begin(), e.end(), 0.f, sum_of_squares);
e2_acum_ += e2;
y2_acum_ += y2;
const SubtractorOutput& output) {
e2_acum_ += output.e2_main;
y2_acum_ += output.y2;
if (++n_blocks_acum_ == n_blocks_) {
if (y2_acum_ > n_blocks_ * 200.f * 200.f * kBlockSize) {
float update = (e2_acum_ / y2_acum_);
if (e2_acum_ > n_blocks_ * 7500.f * 7500.f * kBlockSize) {
overhang_ = 4; // Duration equal to blockSizeMs * n_blocks_ * 4
// Duration equal to blockSizeMs * n_blocks_ * 4.
overhang_ = 4;
} else {
overhang_ = std::max(overhang_ - 1, 0);
}

View File

@ -74,7 +74,7 @@ class Subtractor {
FilterMisadjustmentEstimator() = default;
~FilterMisadjustmentEstimator() = default;
// Update the misadjustment estimator.
void Update(rtc::ArrayView<const float> e, rtc::ArrayView<const float> y);
void Update(const SubtractorOutput& output);
// GetMisadjustment() Returns a recommended scale for the filter so the
// prediction error energy gets closer to the energy that is seen at the
// microphone input.
@ -107,11 +107,13 @@ class Subtractor {
const bool adaptation_during_saturation_;
const bool enable_misadjustment_estimator_;
const bool enable_agc_gain_change_response_;
const bool enable_shadow_filter_jumpstart_;
AdaptiveFirFilter main_filter_;
AdaptiveFirFilter shadow_filter_;
MainFilterUpdateGain G_main_;
ShadowFilterUpdateGain G_shadow_;
FilterMisadjustmentEstimator filter_misadjustment_estimator_;
size_t poor_shadow_filter_counter_ = 0;
RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(Subtractor);
};

View File

@ -0,0 +1,41 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/aec3/subtractor_output.h"
#include <numeric>
namespace webrtc {
SubtractorOutput::SubtractorOutput() = default;
SubtractorOutput::~SubtractorOutput() = default;
// Resets the struct content: every signal buffer, spectrum and power value is
// set to zero.
void SubtractorOutput::Reset() {
  // Scalar signal powers.
  e2_main = e2_shadow = y2 = 0.f;

  // Time-domain signals.
  s_main.fill(0.f);
  e_main.fill(0.f);
  e_shadow.fill(0.f);

  // Frequency-domain main error signal.
  E_main.re.fill(0.f);
  E_main.im.fill(0.f);

  // Error power spectra.
  E2_main.fill(0.f);
  E2_shadow.fill(0.f);
}
// Updates the powers of the signals: y2 becomes the sum of squares of the
// samples in y, and e2_main / e2_shadow become the sums of squares of e_main
// and e_shadow respectively. Each sum is accumulated in sample order.
void SubtractorOutput::UpdatePowers(rtc::ArrayView<const float> y) {
  float y_power = 0.f;
  for (const float sample : y) {
    y_power = y_power + sample * sample;
  }
  y2 = y_power;

  float main_error_power = 0.f;
  for (const float sample : e_main) {
    main_error_power = main_error_power + sample * sample;
  }
  e2_main = main_error_power;

  float shadow_error_power = 0.f;
  for (const float sample : e_shadow) {
    shadow_error_power = shadow_error_power + sample * sample;
  }
  e2_shadow = shadow_error_power;
}
} // namespace webrtc

View File

@ -13,6 +13,7 @@
#include <array>
#include "api/array_view.h"
#include "modules/audio_processing/aec3/aec3_common.h"
#include "modules/audio_processing/aec3/fft_data.h"
@ -20,22 +21,24 @@ namespace webrtc {
// Stores the values being returned from the echo subtractor.
struct SubtractorOutput {
SubtractorOutput();
~SubtractorOutput();
std::array<float, kBlockSize> s_main;
std::array<float, kBlockSize> e_main;
std::array<float, kBlockSize> e_shadow;
FftData E_main;
std::array<float, kFftLengthBy2Plus1> E2_main;
std::array<float, kFftLengthBy2Plus1> E2_shadow;
float e2_main = 0.f;
float e2_shadow = 0.f;
float y2 = 0.f;
void Reset() {
s_main.fill(0.f);
e_main.fill(0.f);
e_shadow.fill(0.f);
E_main.re.fill(0.f);
E_main.im.fill(0.f);
E2_main.fill(0.f);
E2_shadow.fill(0.f);
}
// Reset the struct content.
void Reset();
// Updates the powers of the signals.
void UpdatePowers(rtc::ArrayView<const float> y);
};
} // namespace webrtc

View File

@ -16,17 +16,10 @@
namespace webrtc {
void SubtractorOutputAnalyzer::Update(
rtc::ArrayView<const float> y,
const SubtractorOutput& subtractor_output) {
const auto& e_main = subtractor_output.e_main;
const auto& e_shadow = subtractor_output.e_shadow;
const auto sum_of_squares = [](float a, float b) { return a + b * b; };
const float y2 = std::accumulate(y.begin(), y.end(), 0.f, sum_of_squares);
const float e2_main =
std::accumulate(e_main.begin(), e_main.end(), 0.f, sum_of_squares);
const float e2_shadow =
std::accumulate(e_shadow.begin(), e_shadow.end(), 0.f, sum_of_squares);
const float y2 = subtractor_output.y2;
const float e2_main = subtractor_output.e2_main;
const float e2_shadow = subtractor_output.e2_shadow;
constexpr float kConvergenceThreshold = 50 * 50 * kBlockSize;
main_filter_converged_ = e2_main < 0.5f * y2 && y2 > kConvergenceThreshold;

View File

@ -23,8 +23,7 @@ class SubtractorOutputAnalyzer {
~SubtractorOutputAnalyzer() = default;
// Analyses the subtractor output.
void Update(rtc::ArrayView<const float> y,
const SubtractorOutput& subtractor_output);
void Update(const SubtractorOutput& subtractor_output);
bool ConvergedFilter() const {
return main_filter_converged_ || shadow_filter_converged_;