AGC2 AdaptiveModeLevelEstimator min consecutive speech frames (3/3)

This is the last CL needed to add a new `AdaptiveModeLevelEstimator`
feature that makes AGC2 more robus to VAD mistakes: the level estimator
discards estimation updates when too few consecutive speech frames are
observed.

This CL adds a second state property to hold temporary updates and a
counter for consecutive speech frames. When enough speech frames are
observed, the reliable state is updated; otherwise, the temporary state
is discarded.

The default for `AdaptiveModeLevelEstimator::min_consecutive_speech_frames_`
is 1, which means that the new feature is disabled.

Tested:
- Bit-exactness verified with audioproc_f
- Not bit-exact if `min_consecutive_speech_frames_` set to 10

Bug: webrtc:7494
No-Try: True
Change-Id: I0daa00e90c27c418c00baec39fb8eacd26eed858
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/185125
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32250}
This commit is contained in:
Alessio Bazzica
2020-09-30 13:19:22 +02:00
committed by Commit Bot
parent 77607f1e83
commit 8845f7e32b
4 changed files with 139 additions and 21 deletions

View File

@ -67,6 +67,7 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
: AdaptiveModeLevelEstimator(
apm_data_dumper,
AudioProcessing::Config::GainController2::LevelEstimator::kRms,
kDefaultAdjacentSpeechFramesThreshold,
kDefaultUseSaturationProtector,
kDefaultInitialSaturationMarginDb,
kDefaultExtraSaturationMarginDb) {}
@ -78,6 +79,7 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
float extra_saturation_margin_db)
: AdaptiveModeLevelEstimator(apm_data_dumper,
level_estimator,
kDefaultAdjacentSpeechFramesThreshold,
use_saturation_protector,
kDefaultInitialSaturationMarginDb,
extra_saturation_margin_db) {}
@ -85,11 +87,13 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
ApmDataDumper* apm_data_dumper,
AudioProcessing::Config::GainController2::LevelEstimator level_estimator,
int adjacent_speech_frames_threshold,
bool use_saturation_protector,
float initial_saturation_margin_db,
float extra_saturation_margin_db)
: apm_data_dumper_(apm_data_dumper),
level_estimator_type_(level_estimator),
adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold),
use_saturation_protector_(use_saturation_protector),
initial_saturation_margin_db_(initial_saturation_margin_db),
extra_saturation_margin_db_(extra_saturation_margin_db),
@ -98,6 +102,7 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
initial_saturation_margin_db_,
extra_saturation_margin_db_)) {
RTC_DCHECK(apm_data_dumper_);
RTC_DCHECK_GE(adjacent_speech_frames_threshold_, 1);
Reset();
}
@ -112,47 +117,83 @@ void AdaptiveModeLevelEstimator::Update(
DumpDebugData();
if (vad_level.speech_probability < kVadConfidenceThreshold) {
// Not a speech frame.
if (adjacent_speech_frames_threshold_ > 1) {
// When two or more adjacent speech frames are required in order to update
// the state, we need to decide whether to discard or confirm the updates
// based on the speech sequence length.
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
// First non-speech frame after a long enough sequence of speech frames.
// Update the reliable state.
reliable_state_ = preliminary_state_;
} else if (num_adjacent_speech_frames_ > 0) {
// First non-speech frame after a too short sequence of speech frames.
// Reset to the last reliable state.
preliminary_state_ = reliable_state_;
}
}
num_adjacent_speech_frames_ = 0;
return;
}
// Update level estimate.
RTC_DCHECK_GE(state_.time_to_full_buffer_ms, 0);
const bool buffer_is_full = state_.time_to_full_buffer_ms == 0;
// Speech frame observed.
num_adjacent_speech_frames_++;
// Update preliminary level estimate.
RTC_DCHECK_GE(preliminary_state_.time_to_full_buffer_ms, 0);
const bool buffer_is_full = preliminary_state_.time_to_full_buffer_ms == 0;
if (!buffer_is_full) {
state_.time_to_full_buffer_ms -= kFrameDurationMs;
preliminary_state_.time_to_full_buffer_ms -= kFrameDurationMs;
}
// Weighted average of levels with speech probability as weight.
RTC_DCHECK_GT(vad_level.speech_probability, 0.f);
const float leak_factor = buffer_is_full ? kFullBufferLeakFactor : 1.f;
state_.level_dbfs.numerator =
state_.level_dbfs.numerator * leak_factor +
preliminary_state_.level_dbfs.numerator =
preliminary_state_.level_dbfs.numerator * leak_factor +
GetLevel(vad_level, level_estimator_type_) * vad_level.speech_probability;
state_.level_dbfs.denominator = state_.level_dbfs.denominator * leak_factor +
vad_level.speech_probability;
preliminary_state_.level_dbfs.denominator =
preliminary_state_.level_dbfs.denominator * leak_factor +
vad_level.speech_probability;
const float level_dbfs = state_.level_dbfs.GetRatio();
const float level_dbfs = preliminary_state_.level_dbfs.GetRatio();
if (use_saturation_protector_) {
UpdateSaturationProtectorState(vad_level.peak_dbfs, level_dbfs,
state_.saturation_protector);
preliminary_state_.saturation_protector);
}
// Cache level estimation.
level_dbfs_ = ComputeLevelEstimateDbfs(level_dbfs, use_saturation_protector_,
state_.saturation_protector.margin_db,
extra_saturation_margin_db_);
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
// `preliminary_state_` is now reliable. Update the last level estimation.
level_dbfs_ = ComputeLevelEstimateDbfs(
level_dbfs, use_saturation_protector_,
preliminary_state_.saturation_protector.margin_db,
extra_saturation_margin_db_);
}
}
bool AdaptiveModeLevelEstimator::IsConfident() const {
// Returns true if enough speech frames have been observed.
return state_.time_to_full_buffer_ms == 0;
if (adjacent_speech_frames_threshold_ == 1) {
// Ignore `reliable_state_` when a single frame is enough to update the
// level estimate (because it is not used).
return preliminary_state_.time_to_full_buffer_ms == 0;
}
// Once confident, it remains confident.
RTC_DCHECK(reliable_state_.time_to_full_buffer_ms != 0 ||
preliminary_state_.time_to_full_buffer_ms == 0);
// During the first long enough speech sequence, `reliable_state_` must be
// ignored since `preliminary_state_` is used.
return reliable_state_.time_to_full_buffer_ms == 0 ||
(num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_ &&
preliminary_state_.time_to_full_buffer_ms == 0);
}
void AdaptiveModeLevelEstimator::Reset() {
ResetLevelEstimatorState(state_);
ResetLevelEstimatorState(preliminary_state_);
ResetLevelEstimatorState(reliable_state_);
level_dbfs_ = ComputeLevelEstimateDbfs(
kInitialSpeechLevelEstimateDbfs, use_saturation_protector_,
initial_saturation_margin_db_, extra_saturation_margin_db_);
num_adjacent_speech_frames_ = 0;
}
void AdaptiveModeLevelEstimator::ResetLevelEstimatorState(
@ -166,8 +207,14 @@ void AdaptiveModeLevelEstimator::ResetLevelEstimatorState(
void AdaptiveModeLevelEstimator::DumpDebugData() const {
apm_data_dumper_->DumpRaw("agc2_adaptive_level_estimate_dbfs", level_dbfs_);
apm_data_dumper_->DumpRaw("agc2_adaptive_saturation_margin_db",
state_.saturation_protector.margin_db);
apm_data_dumper_->DumpRaw("agc2_adaptive_num_adjacent_speech_frames_",
num_adjacent_speech_frames_);
apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_level_estimate_num",
preliminary_state_.level_dbfs.numerator);
apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_level_estimate_den",
preliminary_state_.level_dbfs.denominator);
apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_saturation_margin_db",
preliminary_state_.saturation_protector.margin_db);
}
} // namespace webrtc

View File

@ -12,6 +12,7 @@
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_H_
#include <stddef.h>
#include <type_traits>
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/saturation_protector.h"
@ -38,6 +39,7 @@ class AdaptiveModeLevelEstimator {
AdaptiveModeLevelEstimator(
ApmDataDumper* apm_data_dumper,
AudioProcessing::Config::GainController2::LevelEstimator level_estimator,
int adjacent_speech_frames_threshold,
bool use_saturation_protector,
float initial_saturation_margin_db,
float extra_saturation_margin_db);
@ -63,10 +65,12 @@ class AdaptiveModeLevelEstimator {
float denominator;
float GetRatio() const;
};
// TODO(crbug.com/webrtc/7494): Remove if saturation protector always used.
int time_to_full_buffer_ms;
Ratio level_dbfs;
SaturationProtectorState saturation_protector;
};
static_assert(std::is_trivially_copyable<LevelEstimatorState>::value, "");
void ResetLevelEstimatorState(LevelEstimatorState& state) const;
@ -76,12 +80,14 @@ class AdaptiveModeLevelEstimator {
const AudioProcessing::Config::GainController2::LevelEstimator
level_estimator_type_;
const int adjacent_speech_frames_threshold_;
const bool use_saturation_protector_;
const float initial_saturation_margin_db_;
const float extra_saturation_margin_db_;
// TODO(crbug.com/webrtc/7494): Add temporary state.
LevelEstimatorState state_;
LevelEstimatorState preliminary_state_;
LevelEstimatorState reliable_state_;
float level_dbfs_;
int num_adjacent_speech_frames_;
};
} // namespace webrtc

View File

@ -22,6 +22,16 @@ namespace {
constexpr float kInitialSaturationMarginDb = 20.f;
constexpr float kExtraSaturationMarginDb = 2.f;
static_assert(kInitialSpeechLevelEstimateDbfs < 0.f, "");
constexpr float kVadLevelRms = kInitialSpeechLevelEstimateDbfs / 2.f;
constexpr float kVadLevelPeak = kInitialSpeechLevelEstimateDbfs / 3.f;
constexpr VadLevelAnalyzer::Result kVadDataSpeech{/*speech_probability=*/1.f,
kVadLevelRms, kVadLevelPeak};
constexpr VadLevelAnalyzer::Result kVadDataNonSpeech{
/*speech_probability=*/kVadConfidenceThreshold / 2.f, kVadLevelRms,
kVadLevelPeak};
constexpr float kMinSpeechProbability = 0.f;
constexpr float kMaxSpeechProbability = 1.f;
@ -39,6 +49,7 @@ struct TestLevelEstimator {
estimator(std::make_unique<AdaptiveModeLevelEstimator>(
&data_dumper,
AudioProcessing::Config::GainController2::LevelEstimator::kRms,
/*min_consecutive_speech_frames=*/1,
/*use_saturation_protector=*/true,
kInitialSaturationMarginDb,
kExtraSaturationMarginDb)) {}
@ -116,6 +127,7 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator, TimeToAdapt) {
// adapt.
constexpr float kDifferentSpeechRmsDbfs = -10.f;
// It should at most differ by 25% after one half 'window size' interval.
// TODO(crbug.com/webrtc/7494): Add constexpr for repeated expressions.
const float kMaxDifferenceDb =
0.25f * std::abs(kDifferentSpeechRmsDbfs - kInitialSpeechRmsDbfs);
RunOnConstantLevel(
@ -178,5 +190,57 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
kMaxDifferenceDb);
}
struct TestConfig {
int min_consecutive_speech_frames;
bool use_saturation_protector;
float initial_saturation_margin_db;
float extra_saturation_margin_db;
};
class AdaptiveModeLevelEstimatorTest
: public ::testing::TestWithParam<TestConfig> {};
TEST_P(AdaptiveModeLevelEstimatorTest, DoNotAdaptToShortSpeechSegments) {
const auto params = GetParam();
ApmDataDumper apm_data_dumper(0);
AdaptiveModeLevelEstimator level_estimator(
&apm_data_dumper,
AudioProcessing::Config::GainController2::LevelEstimator::kRms,
params.min_consecutive_speech_frames, params.use_saturation_protector,
params.initial_saturation_margin_db, params.extra_saturation_margin_db);
const float initial_level = level_estimator.level_dbfs();
ASSERT_LT(initial_level, kVadDataSpeech.rms_dbfs);
for (int i = 0; i < params.min_consecutive_speech_frames - 1; ++i) {
SCOPED_TRACE(i);
level_estimator.Update(kVadDataSpeech);
EXPECT_EQ(initial_level, level_estimator.level_dbfs());
}
level_estimator.Update(kVadDataNonSpeech);
EXPECT_EQ(initial_level, level_estimator.level_dbfs());
}
TEST_P(AdaptiveModeLevelEstimatorTest, AdaptToEnoughSpeechSegments) {
const auto params = GetParam();
ApmDataDumper apm_data_dumper(0);
AdaptiveModeLevelEstimator level_estimator(
&apm_data_dumper,
AudioProcessing::Config::GainController2::LevelEstimator::kRms,
params.min_consecutive_speech_frames, params.use_saturation_protector,
params.initial_saturation_margin_db, params.extra_saturation_margin_db);
const float initial_level = level_estimator.level_dbfs();
ASSERT_LT(initial_level, kVadDataSpeech.rms_dbfs);
for (int i = 0; i < params.min_consecutive_speech_frames; ++i) {
level_estimator.Update(kVadDataSpeech);
}
EXPECT_LT(initial_level, level_estimator.level_dbfs());
}
INSTANTIATE_TEST_SUITE_P(AutomaticGainController2,
AdaptiveModeLevelEstimatorTest,
::testing::Values(TestConfig{1, false, 0.f, 0.f},
TestConfig{1, true, 0.f, 0.f},
TestConfig{9, false, 0.f, 0.f},
TestConfig{9, true, 0.f, 0.f}));
} // namespace
} // namespace webrtc

View File

@ -51,6 +51,7 @@ constexpr float kInitialSpeechLevelEstimateDbfs = -30.f;
// Robust VAD probability and speech decisions.
constexpr float kDefaultSmoothedVadProbabilityAttack = 1.f;
constexpr int kDefaultAdjacentSpeechFramesThreshold = 1;
// Saturation Protector settings.
constexpr bool kDefaultUseSaturationProtector = true;