HMM based transparent mode classifier

This change introduces a new Hidden Markov Model based classifier for
AEC3's 'transparent mode'. Transparent mode is used with
headsets/headphones where the speaker signal does not leak into the
microphone signal.

The current classifier suffers from two problems:
1. It sometimes takes a long time to enter transparent mode.
2. Sometimes transparent mode is left (and it once again takes a long
time to re-enter).

Both problems have a severe effect on AEC transparency.

The new classifier enters transparent mode quicker and is less likely
to exit transparent mode when there is no echo. This improves the
audio experience when using headset/headphones.

Another (minor) benefit of this change is that when transparent mode
is disabled no classifier is run (or even created) saving some memory
and CPU cycles.

Bug: webrtc:10232
Change-Id: I509af0e22b59463aeaead53c78c35be1e97fe8c3
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/184500
Commit-Queue: Gustaf Ullberg <gustaf@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32182}
This commit is contained in:
Gustaf Ullberg
2020-09-24 09:21:49 +02:00
committed by Commit Bot
parent 936f1af3bb
commit afef7a74a7
6 changed files with 309 additions and 142 deletions

View File

@ -113,6 +113,8 @@ rtc_library("aec3") {
"suppression_filter.h",
"suppression_gain.cc",
"suppression_gain.h",
"transparent_mode.cc",
"transparent_mode.h",
]
defines = []

View File

@ -27,13 +27,6 @@
namespace webrtc {
namespace {
constexpr size_t kBlocksSinceConvergencedFilterInit = 10000;
constexpr size_t kBlocksSinceConsistentEstimateInit = 10000;
bool DeactivateTransparentMode() {
return field_trial::IsEnabled("WebRTC-Aec3TransparentModeKillSwitch");
}
bool DeactivateInitialStateResetAtEchoPathChange() {
return field_trial::IsEnabled(
"WebRTC-Aec3DeactivateInitialStateResetKillSwitch");
@ -134,7 +127,6 @@ AecState::AecState(const EchoCanceller3Config& config,
new ApmDataDumper(rtc::AtomicOps::Increment(&instance_count_))),
config_(config),
num_capture_channels_(num_capture_channels),
transparent_mode_activated_(!DeactivateTransparentMode()),
deactivate_initial_state_reset_at_echo_path_change_(
DeactivateInitialStateResetAtEchoPathChange()),
full_reset_at_echo_path_change_(FullResetAtEchoPathChange()),
@ -142,7 +134,7 @@ AecState::AecState(const EchoCanceller3Config& config,
SubtractorAnalyzerResetAtEchoPathChange()),
initial_state_(config_),
delay_state_(config_, num_capture_channels_),
transparent_state_(config_),
transparent_state_(TransparentMode::Create(config_)),
filter_quality_state_(config_, num_capture_channels_),
erl_estimator_(2 * kNumBlocksPerSecond),
erle_estimator_(2 * kNumBlocksPerSecond, config_, num_capture_channels_),
@ -164,7 +156,9 @@ void AecState::HandleEchoPathChange(
if (!deactivate_initial_state_reset_at_echo_path_change_) {
initial_state_.Reset();
}
transparent_state_.Reset();
if (transparent_state_) {
transparent_state_->Reset();
}
erle_estimator_.Reset(true);
erl_estimator_.Reset();
filter_quality_state_.Reset();
@ -277,13 +271,15 @@ void AecState::Update(
initial_state_.Update(active_render, SaturatedCapture());
// Detect whether the transparent mode should be activated.
transparent_state_.Update(delay_state_.MinDirectPathFilterDelay(),
any_filter_consistent, any_filter_converged,
all_filters_diverged, active_render,
SaturatedCapture());
if (transparent_state_) {
transparent_state_->Update(delay_state_.MinDirectPathFilterDelay(),
any_filter_consistent, any_filter_converged,
all_filters_diverged, active_render,
SaturatedCapture());
}
// Analyze the quality of the filter.
filter_quality_state_.Update(active_render, TransparentMode(),
filter_quality_state_.Update(active_render, TransparentModeActive(),
SaturatedCapture(), external_delay,
any_filter_converged);
@ -301,11 +297,12 @@ void AecState::Update(
erle_estimator_.Dump(data_dumper_);
reverb_model_estimator_.Dump(data_dumper_.get());
data_dumper_->DumpRaw("aec3_active_render", active_render);
data_dumper_->DumpRaw("aec3_erl", Erl());
data_dumper_->DumpRaw("aec3_erl_time_domain", ErlTimeDomain());
data_dumper_->DumpRaw("aec3_erle", Erle()[0]);
data_dumper_->DumpRaw("aec3_usable_linear_estimate", UsableLinearEstimate());
data_dumper_->DumpRaw("aec3_transparent_mode", TransparentMode());
data_dumper_->DumpRaw("aec3_transparent_mode", TransparentModeActive());
data_dumper_->DumpRaw("aec3_filter_delay",
filter_analyzer_.MinFilterDelayBlocks());
@ -387,92 +384,6 @@ void AecState::FilterDelay::Update(
filter_delays_blocks_.end());
}
AecState::TransparentMode::TransparentMode(const EchoCanceller3Config& config)
: bounded_erl_(config.ep_strength.bounded_erl),
linear_and_stable_echo_path_(
config.echo_removal_control.linear_and_stable_echo_path),
active_blocks_since_sane_filter_(kBlocksSinceConsistentEstimateInit),
non_converged_sequence_size_(kBlocksSinceConvergencedFilterInit) {}
void AecState::TransparentMode::Reset() {
non_converged_sequence_size_ = kBlocksSinceConvergencedFilterInit;
diverged_sequence_size_ = 0;
strong_not_saturated_render_blocks_ = 0;
if (linear_and_stable_echo_path_) {
recent_convergence_during_activity_ = false;
}
}
void AecState::TransparentMode::Update(int filter_delay_blocks,
bool any_filter_consistent,
bool any_filter_converged,
bool all_filters_diverged,
bool active_render,
bool saturated_capture) {
++capture_block_counter_;
strong_not_saturated_render_blocks_ +=
active_render && !saturated_capture ? 1 : 0;
if (any_filter_consistent && filter_delay_blocks < 5) {
sane_filter_observed_ = true;
active_blocks_since_sane_filter_ = 0;
} else if (active_render) {
++active_blocks_since_sane_filter_;
}
bool sane_filter_recently_seen;
if (!sane_filter_observed_) {
sane_filter_recently_seen =
capture_block_counter_ <= 5 * kNumBlocksPerSecond;
} else {
sane_filter_recently_seen =
active_blocks_since_sane_filter_ <= 30 * kNumBlocksPerSecond;
}
if (any_filter_converged) {
recent_convergence_during_activity_ = true;
active_non_converged_sequence_size_ = 0;
non_converged_sequence_size_ = 0;
++num_converged_blocks_;
} else {
if (++non_converged_sequence_size_ > 20 * kNumBlocksPerSecond) {
num_converged_blocks_ = 0;
}
if (active_render &&
++active_non_converged_sequence_size_ > 60 * kNumBlocksPerSecond) {
recent_convergence_during_activity_ = false;
}
}
if (!all_filters_diverged) {
diverged_sequence_size_ = 0;
} else if (++diverged_sequence_size_ >= 60) {
// TODO(peah): Change these lines to ensure proper triggering of usable
// filter.
non_converged_sequence_size_ = kBlocksSinceConvergencedFilterInit;
}
if (active_non_converged_sequence_size_ > 60 * kNumBlocksPerSecond) {
finite_erl_recently_detected_ = false;
}
if (num_converged_blocks_ > 50) {
finite_erl_recently_detected_ = true;
}
if (bounded_erl_) {
transparency_activated_ = false;
} else if (finite_erl_recently_detected_) {
transparency_activated_ = false;
} else if (sane_filter_recently_seen && recent_convergence_during_activity_) {
transparency_activated_ = false;
} else {
const bool filter_should_have_converged =
strong_not_saturated_render_blocks_ > 6 * kNumBlocksPerSecond;
transparency_activated_ = filter_should_have_converged;
}
}
AecState::FilteringQualityAnalyzer::FilteringQualityAnalyzer(
const EchoCanceller3Config& config,
size_t num_capture_channels)

View File

@ -31,6 +31,7 @@
#include "modules/audio_processing/aec3/reverb_model_estimator.h"
#include "modules/audio_processing/aec3/subtractor_output.h"
#include "modules/audio_processing/aec3/subtractor_output_analyzer.h"
#include "modules/audio_processing/aec3/transparent_mode.h"
namespace webrtc {
@ -107,8 +108,8 @@ class AecState {
}
// Returns whether the transparent mode is active
bool TransparentMode() const {
return transparent_mode_activated_ && transparent_state_.Active();
bool TransparentModeActive() const {
return transparent_state_ && transparent_state_->Active();
}
// Takes appropriate action at an echo path change.
@ -152,7 +153,6 @@ class AecState {
std::unique_ptr<ApmDataDumper> data_dumper_;
const EchoCanceller3Config config_;
const size_t num_capture_channels_;
const bool transparent_mode_activated_;
const bool deactivate_initial_state_reset_at_echo_path_change_;
const bool full_reset_at_echo_path_change_;
const bool subtractor_analyzer_reset_at_echo_path_change_;
@ -218,41 +218,8 @@ class AecState {
absl::optional<DelayEstimate> external_delay_;
} delay_state_;
// Class for detecting and toggling the transparent mode which causes the
// suppressor to apply no suppression.
class TransparentMode {
public:
explicit TransparentMode(const EchoCanceller3Config& config);
// Returns whether the transparent mode should be active.
bool Active() const { return transparency_activated_; }
// Resets the state of the detector.
void Reset();
// Updates the detection deciscion based on new data.
void Update(int filter_delay_blocks,
bool any_filter_consistent,
bool any_filter_converged,
bool all_filters_diverged,
bool active_render,
bool saturated_capture);
private:
const bool bounded_erl_;
const bool linear_and_stable_echo_path_;
size_t capture_block_counter_ = 0;
bool transparency_activated_ = false;
size_t active_blocks_since_sane_filter_;
bool sane_filter_observed_ = false;
bool finite_erl_recently_detected_ = false;
size_t non_converged_sequence_size_;
size_t diverged_sequence_size_ = 0;
size_t active_non_converged_sequence_size_ = 0;
size_t num_converged_blocks_ = 0;
bool recent_convergence_during_activity_ = false;
size_t strong_not_saturated_render_blocks_ = 0;
} transparent_state_;
// Classifier for toggling transparent mode when there is no echo.
std::unique_ptr<TransparentMode> transparent_state_;
// Class for analyzing how well the linear filter is, and can be expected to,
// perform on the current signals. The purpose of this is for using to

View File

@ -277,7 +277,7 @@ void ResidualEchoEstimator::Estimate(
NonLinearEstimate(echo_path_gain, X2, R2);
}
if (model_reverb_in_nonlinear_mode_ && !aec_state.TransparentMode()) {
if (model_reverb_in_nonlinear_mode_ && !aec_state.TransparentModeActive()) {
AddReverb(ReverbType::kNonLinear, aec_state, render_buffer, R2);
}
}
@ -395,7 +395,7 @@ float ResidualEchoEstimator::GetEchoPathGain(
const AecState& aec_state,
bool gain_for_early_reflections) const {
float gain_amplitude;
if (aec_state.TransparentMode()) {
if (aec_state.TransparentModeActive()) {
gain_amplitude = gain_for_early_reflections
? early_reflections_transparent_mode_gain_
: late_reflections_transparent_mode_gain_;

View File

@ -0,0 +1,241 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/aec3/transparent_mode.h"
#include "rtc_base/checks.h"
#include "system_wrappers/include/field_trial.h"
namespace webrtc {
namespace {
constexpr size_t kBlocksSinceConvergencedFilterInit = 10000;
constexpr size_t kBlocksSinceConsistentEstimateInit = 10000;
bool DeactivateTransparentMode() {
return field_trial::IsEnabled("WebRTC-Aec3TransparentModeKillSwitch");
}
bool DeactivateTransparentModeHmm() {
return field_trial::IsEnabled("WebRTC-Aec3TransparentModeHmmKillSwitch");
}
} // namespace
// Classifier that toggles transparent mode which reduces echo suppression when
// headsets are used.
class TransparentModeImpl : public TransparentMode {
public:
bool Active() const override { return transparency_activated_; }
void Reset() override {
// Determines if transparent mode is used.
transparency_activated_ = false;
// The estimated probability of being transparent mode.
prob_transparent_state_ = 0.f;
}
void Update(int filter_delay_blocks,
bool any_filter_consistent,
bool any_filter_converged,
bool all_filters_diverged,
bool active_render,
bool saturated_capture) override {
// The classifier is implemented as a Hidden Markov Model (HMM) with two
// hidden states: "normal" and "transparent". The estimated probabilities of
// the two states are updated by observing filter convergence during active
// render. The filters are less likely to be reported as converged when
// there is no echo present in the microphone signal.
// The constants have been obtained by observing active_render and
// any_filter_converged under varying call scenarios. They have further been
// hand tuned to prefer normal state during uncertain regions (to avoid echo
// leaks).
// The model is only updated during active render.
if (!active_render)
return;
// Probability of switching from one state to the other.
constexpr float kSwitch = 0.000001f;
// Probability of observing converged filters in states "normal" and
// "transparent" during active render.
constexpr float kConvergedNormal = 0.03f;
constexpr float kConvergedTransparent = 0.005f;
// Probability of transitioning to transparent state from normal state and
// transparent state respectively.
constexpr float kA[2] = {kSwitch, 1.f - kSwitch};
// Probability of the two observations (converged filter or not converged
// filter) in normal state and transparent state respectively.
constexpr float kB[2][2] = {
{1.f - kConvergedNormal, kConvergedNormal},
{1.f - kConvergedTransparent, kConvergedTransparent}};
// Probability of the two states before the update.
const float prob_transparent = prob_transparent_state_;
const float prob_normal = 1.f - prob_transparent;
// Probability of transitioning to transparent state.
const float prob_transition_transparent =
prob_normal * kA[0] + prob_transparent * kA[1];
const float prob_transition_normal = 1.f - prob_transition_transparent;
// Observed output.
const int out = any_filter_converged;
// Joint probabilites of the observed output and respective states.
const float prob_joint_normal = prob_transition_normal * kB[0][out];
const float prob_joint_transparent =
prob_transition_transparent * kB[1][out];
// Conditional probability of transparent state and the observed output.
RTC_DCHECK_GT(prob_joint_normal + prob_joint_transparent, 0.f);
prob_transparent_state_ =
prob_joint_transparent / (prob_joint_normal + prob_joint_transparent);
// Transparent mode is only activated when its state probability is high.
// Dead zone between activation/deactivation thresholds to avoid switching
// back and forth.
if (prob_transparent_state_ > 0.95f) {
transparency_activated_ = true;
} else if (prob_transparent_state_ < 0.5f) {
transparency_activated_ = false;
}
}
private:
bool transparency_activated_ = false;
float prob_transparent_state_ = 0.f;
};
// Legacy classifier for toggling transparent mode.
class LegacyTransparentModeImpl : public TransparentMode {
public:
explicit LegacyTransparentModeImpl(const EchoCanceller3Config& config)
: bounded_erl_(config.ep_strength.bounded_erl),
linear_and_stable_echo_path_(
config.echo_removal_control.linear_and_stable_echo_path),
active_blocks_since_sane_filter_(kBlocksSinceConsistentEstimateInit),
non_converged_sequence_size_(kBlocksSinceConvergencedFilterInit) {}
bool Active() const override { return transparency_activated_; }
void Reset() override {
non_converged_sequence_size_ = kBlocksSinceConvergencedFilterInit;
diverged_sequence_size_ = 0;
strong_not_saturated_render_blocks_ = 0;
if (linear_and_stable_echo_path_) {
recent_convergence_during_activity_ = false;
}
}
void Update(int filter_delay_blocks,
bool any_filter_consistent,
bool any_filter_converged,
bool all_filters_diverged,
bool active_render,
bool saturated_capture) override {
++capture_block_counter_;
strong_not_saturated_render_blocks_ +=
active_render && !saturated_capture ? 1 : 0;
if (any_filter_consistent && filter_delay_blocks < 5) {
sane_filter_observed_ = true;
active_blocks_since_sane_filter_ = 0;
} else if (active_render) {
++active_blocks_since_sane_filter_;
}
bool sane_filter_recently_seen;
if (!sane_filter_observed_) {
sane_filter_recently_seen =
capture_block_counter_ <= 5 * kNumBlocksPerSecond;
} else {
sane_filter_recently_seen =
active_blocks_since_sane_filter_ <= 30 * kNumBlocksPerSecond;
}
if (any_filter_converged) {
recent_convergence_during_activity_ = true;
active_non_converged_sequence_size_ = 0;
non_converged_sequence_size_ = 0;
++num_converged_blocks_;
} else {
if (++non_converged_sequence_size_ > 20 * kNumBlocksPerSecond) {
num_converged_blocks_ = 0;
}
if (active_render &&
++active_non_converged_sequence_size_ > 60 * kNumBlocksPerSecond) {
recent_convergence_during_activity_ = false;
}
}
if (!all_filters_diverged) {
diverged_sequence_size_ = 0;
} else if (++diverged_sequence_size_ >= 60) {
// TODO(peah): Change these lines to ensure proper triggering of usable
// filter.
non_converged_sequence_size_ = kBlocksSinceConvergencedFilterInit;
}
if (active_non_converged_sequence_size_ > 60 * kNumBlocksPerSecond) {
finite_erl_recently_detected_ = false;
}
if (num_converged_blocks_ > 50) {
finite_erl_recently_detected_ = true;
}
if (bounded_erl_) {
transparency_activated_ = false;
} else if (finite_erl_recently_detected_) {
transparency_activated_ = false;
} else if (sane_filter_recently_seen &&
recent_convergence_during_activity_) {
transparency_activated_ = false;
} else {
const bool filter_should_have_converged =
strong_not_saturated_render_blocks_ > 6 * kNumBlocksPerSecond;
transparency_activated_ = filter_should_have_converged;
}
}
private:
const bool bounded_erl_;
const bool linear_and_stable_echo_path_;
size_t capture_block_counter_ = 0;
bool transparency_activated_ = false;
size_t active_blocks_since_sane_filter_;
bool sane_filter_observed_ = false;
bool finite_erl_recently_detected_ = false;
size_t non_converged_sequence_size_;
size_t diverged_sequence_size_ = 0;
size_t active_non_converged_sequence_size_ = 0;
size_t num_converged_blocks_ = 0;
bool recent_convergence_during_activity_ = false;
size_t strong_not_saturated_render_blocks_ = 0;
};
std::unique_ptr<TransparentMode> TransparentMode::Create(
const EchoCanceller3Config& config) {
if (DeactivateTransparentMode()) {
return nullptr;
}
if (DeactivateTransparentModeHmm()) {
return std::make_unique<LegacyTransparentModeImpl>(config);
}
return std::make_unique<TransparentModeImpl>();
}
} // namespace webrtc

View File

@ -0,0 +1,46 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AEC3_TRANSPARENT_MODE_H_
#define MODULES_AUDIO_PROCESSING_AEC3_TRANSPARENT_MODE_H_
#include <memory>
#include "api/audio/echo_canceller3_config.h"
#include "modules/audio_processing/aec3/aec3_common.h"
namespace webrtc {
// Class for detecting and toggling the transparent mode which causes the
// suppressor to apply less suppression.
class TransparentMode {
public:
static std::unique_ptr<TransparentMode> Create(
const EchoCanceller3Config& config);
virtual ~TransparentMode() {}
// Returns whether the transparent mode should be active.
virtual bool Active() const = 0;
// Resets the state of the detector.
virtual void Reset() = 0;
// Updates the detection decision based on new data.
virtual void Update(int filter_delay_blocks,
bool any_filter_consistent,
bool any_filter_converged,
bool all_filters_diverged,
bool active_render,
bool saturated_capture) = 0;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AEC3_TRANSPARENT_MODE_H_