From 8845f7e32bdea16fee2e00348055dda5639c41d1 Mon Sep 17 00:00:00 2001 From: Alessio Bazzica Date: Wed, 30 Sep 2020 13:19:22 +0200 Subject: [PATCH] AGC2 AdaptiveModeLevelEstimator min consecutive speech frames (3/3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is the last CL needed to add a new `AdaptiveModeLevelEstimator` feature that makes AGC2 more robust to VAD mistakes: the level estimator discards estimation updates when too few consecutive speech frames are observed. This CL adds a second state property to hold temporary updates and a counter for consecutive speech frames. When enough speech frames are observed, the reliable state is updated; otherwise, the temporary state is discarded. The default for `AdaptiveModeLevelEstimator::min_consecutive_speech_frames_` is 1, which means that the new feature is disabled. Tested: - Bit-exactness verified with audioproc_f - Not bit-exact if `min_consecutive_speech_frames_` set to 10 Bug: webrtc:7494 No-Try: True Change-Id: I0daa00e90c27c418c00baec39fb8eacd26eed858 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/185125 Commit-Queue: Alessio Bazzica Reviewed-by: Per Åhgren Cr-Commit-Position: refs/heads/master@{#32250} --- .../agc2/adaptive_mode_level_estimator.cc | 85 ++++++++++++++----- .../agc2/adaptive_mode_level_estimator.h | 10 ++- .../adaptive_mode_level_estimator_unittest.cc | 64 ++++++++++++++ modules/audio_processing/agc2/agc2_common.h | 1 + 4 files changed, 139 insertions(+), 21 deletions(-) diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc b/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc index 7517599e38..273b3fdb3f 100644 --- a/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc +++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator.cc @@ -67,6 +67,7 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator( : AdaptiveModeLevelEstimator( apm_data_dumper, 
AudioProcessing::Config::GainController2::LevelEstimator::kRms, + kDefaultAdjacentSpeechFramesThreshold, kDefaultUseSaturationProtector, kDefaultInitialSaturationMarginDb, kDefaultExtraSaturationMarginDb) {} @@ -78,6 +79,7 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator( float extra_saturation_margin_db) : AdaptiveModeLevelEstimator(apm_data_dumper, level_estimator, + kDefaultAdjacentSpeechFramesThreshold, use_saturation_protector, kDefaultInitialSaturationMarginDb, extra_saturation_margin_db) {} @@ -85,11 +87,13 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator( AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator( ApmDataDumper* apm_data_dumper, AudioProcessing::Config::GainController2::LevelEstimator level_estimator, + int adjacent_speech_frames_threshold, bool use_saturation_protector, float initial_saturation_margin_db, float extra_saturation_margin_db) : apm_data_dumper_(apm_data_dumper), level_estimator_type_(level_estimator), + adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold), use_saturation_protector_(use_saturation_protector), initial_saturation_margin_db_(initial_saturation_margin_db), extra_saturation_margin_db_(extra_saturation_margin_db), @@ -98,6 +102,7 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator( initial_saturation_margin_db_, extra_saturation_margin_db_)) { RTC_DCHECK(apm_data_dumper_); + RTC_DCHECK_GE(adjacent_speech_frames_threshold_, 1); Reset(); } @@ -112,47 +117,83 @@ void AdaptiveModeLevelEstimator::Update( DumpDebugData(); if (vad_level.speech_probability < kVadConfidenceThreshold) { + // Not a speech frame. + if (adjacent_speech_frames_threshold_ > 1) { + // When two or more adjacent speech frames are required in order to update + // the state, we need to decide whether to discard or confirm the updates + // based on the speech sequence length. 
+ if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) { + // First non-speech frame after a long enough sequence of speech frames. + // Update the reliable state. + reliable_state_ = preliminary_state_; + } else if (num_adjacent_speech_frames_ > 0) { + // First non-speech frame after a too short sequence of speech frames. + // Reset to the last reliable state. + preliminary_state_ = reliable_state_; + } + } + num_adjacent_speech_frames_ = 0; return; } - // Update level estimate. - RTC_DCHECK_GE(state_.time_to_full_buffer_ms, 0); - const bool buffer_is_full = state_.time_to_full_buffer_ms == 0; + // Speech frame observed. + num_adjacent_speech_frames_++; + + // Update preliminary level estimate. + RTC_DCHECK_GE(preliminary_state_.time_to_full_buffer_ms, 0); + const bool buffer_is_full = preliminary_state_.time_to_full_buffer_ms == 0; if (!buffer_is_full) { - state_.time_to_full_buffer_ms -= kFrameDurationMs; + preliminary_state_.time_to_full_buffer_ms -= kFrameDurationMs; } // Weighted average of levels with speech probability as weight. RTC_DCHECK_GT(vad_level.speech_probability, 0.f); const float leak_factor = buffer_is_full ? 
kFullBufferLeakFactor : 1.f; - state_.level_dbfs.numerator = - state_.level_dbfs.numerator * leak_factor + + preliminary_state_.level_dbfs.numerator = + preliminary_state_.level_dbfs.numerator * leak_factor + GetLevel(vad_level, level_estimator_type_) * vad_level.speech_probability; - state_.level_dbfs.denominator = state_.level_dbfs.denominator * leak_factor + - vad_level.speech_probability; + preliminary_state_.level_dbfs.denominator = + preliminary_state_.level_dbfs.denominator * leak_factor + + vad_level.speech_probability; - const float level_dbfs = state_.level_dbfs.GetRatio(); + const float level_dbfs = preliminary_state_.level_dbfs.GetRatio(); if (use_saturation_protector_) { UpdateSaturationProtectorState(vad_level.peak_dbfs, level_dbfs, - state_.saturation_protector); + preliminary_state_.saturation_protector); } - // Cache level estimation. - level_dbfs_ = ComputeLevelEstimateDbfs(level_dbfs, use_saturation_protector_, - state_.saturation_protector.margin_db, - extra_saturation_margin_db_); + if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) { + // `preliminary_state_` is now reliable. Update the last level estimation. + level_dbfs_ = ComputeLevelEstimateDbfs( + level_dbfs, use_saturation_protector_, + preliminary_state_.saturation_protector.margin_db, + extra_saturation_margin_db_); + } } bool AdaptiveModeLevelEstimator::IsConfident() const { - // Returns true if enough speech frames have been observed. - return state_.time_to_full_buffer_ms == 0; + if (adjacent_speech_frames_threshold_ == 1) { + // Ignore `reliable_state_` when a single frame is enough to update the + // level estimate (because it is not used). + return preliminary_state_.time_to_full_buffer_ms == 0; + } + // Once confident, it remains confident. 
+ RTC_DCHECK(reliable_state_.time_to_full_buffer_ms != 0 || + preliminary_state_.time_to_full_buffer_ms == 0); + // During the first long enough speech sequence, `reliable_state_` must be + // ignored since `preliminary_state_` is used. + return reliable_state_.time_to_full_buffer_ms == 0 || + (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_ && + preliminary_state_.time_to_full_buffer_ms == 0); } void AdaptiveModeLevelEstimator::Reset() { - ResetLevelEstimatorState(state_); + ResetLevelEstimatorState(preliminary_state_); + ResetLevelEstimatorState(reliable_state_); level_dbfs_ = ComputeLevelEstimateDbfs( kInitialSpeechLevelEstimateDbfs, use_saturation_protector_, initial_saturation_margin_db_, extra_saturation_margin_db_); + num_adjacent_speech_frames_ = 0; } void AdaptiveModeLevelEstimator::ResetLevelEstimatorState( @@ -166,8 +207,14 @@ void AdaptiveModeLevelEstimator::ResetLevelEstimatorState( void AdaptiveModeLevelEstimator::DumpDebugData() const { apm_data_dumper_->DumpRaw("agc2_adaptive_level_estimate_dbfs", level_dbfs_); - apm_data_dumper_->DumpRaw("agc2_adaptive_saturation_margin_db", - state_.saturation_protector.margin_db); + apm_data_dumper_->DumpRaw("agc2_adaptive_num_adjacent_speech_frames_", + num_adjacent_speech_frames_); + apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_level_estimate_num", + preliminary_state_.level_dbfs.numerator); + apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_level_estimate_den", + preliminary_state_.level_dbfs.denominator); + apm_data_dumper_->DumpRaw("agc2_adaptive_preliminary_saturation_margin_db", + preliminary_state_.saturation_protector.margin_db); } } // namespace webrtc diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator.h b/modules/audio_processing/agc2/adaptive_mode_level_estimator.h index d5cf6f1087..99b9fc9c3c 100644 --- a/modules/audio_processing/agc2/adaptive_mode_level_estimator.h +++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator.h @@ -12,6 +12,7 @@ 
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_H_ #include +#include #include "modules/audio_processing/agc2/agc2_common.h" #include "modules/audio_processing/agc2/saturation_protector.h" @@ -38,6 +39,7 @@ class AdaptiveModeLevelEstimator { AdaptiveModeLevelEstimator( ApmDataDumper* apm_data_dumper, AudioProcessing::Config::GainController2::LevelEstimator level_estimator, + int adjacent_speech_frames_threshold, bool use_saturation_protector, float initial_saturation_margin_db, float extra_saturation_margin_db); @@ -63,10 +65,12 @@ class AdaptiveModeLevelEstimator { float denominator; float GetRatio() const; }; + // TODO(crbug.com/webrtc/7494): Remove if saturation protector always used. int time_to_full_buffer_ms; Ratio level_dbfs; SaturationProtectorState saturation_protector; }; + static_assert(std::is_trivially_copyable::value, ""); void ResetLevelEstimatorState(LevelEstimatorState& state) const; @@ -76,12 +80,14 @@ class AdaptiveModeLevelEstimator { const AudioProcessing::Config::GainController2::LevelEstimator level_estimator_type_; + const int adjacent_speech_frames_threshold_; const bool use_saturation_protector_; const float initial_saturation_margin_db_; const float extra_saturation_margin_db_; - // TODO(crbug.com/webrtc/7494): Add temporary state. 
- LevelEstimatorState state_; + LevelEstimatorState preliminary_state_; + LevelEstimatorState reliable_state_; float level_dbfs_; + int num_adjacent_speech_frames_; }; } // namespace webrtc diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc b/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc index 01cc089315..0b5b2041e0 100644 --- a/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc +++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator_unittest.cc @@ -22,6 +22,16 @@ namespace { constexpr float kInitialSaturationMarginDb = 20.f; constexpr float kExtraSaturationMarginDb = 2.f; +static_assert(kInitialSpeechLevelEstimateDbfs < 0.f, ""); +constexpr float kVadLevelRms = kInitialSpeechLevelEstimateDbfs / 2.f; +constexpr float kVadLevelPeak = kInitialSpeechLevelEstimateDbfs / 3.f; + +constexpr VadLevelAnalyzer::Result kVadDataSpeech{/*speech_probability=*/1.f, + kVadLevelRms, kVadLevelPeak}; +constexpr VadLevelAnalyzer::Result kVadDataNonSpeech{ + /*speech_probability=*/kVadConfidenceThreshold / 2.f, kVadLevelRms, + kVadLevelPeak}; + constexpr float kMinSpeechProbability = 0.f; constexpr float kMaxSpeechProbability = 1.f; @@ -39,6 +49,7 @@ struct TestLevelEstimator { estimator(std::make_unique( &data_dumper, AudioProcessing::Config::GainController2::LevelEstimator::kRms, + /*min_consecutive_speech_frames=*/1, /*use_saturation_protector=*/true, kInitialSaturationMarginDb, kExtraSaturationMarginDb)) {} @@ -116,6 +127,7 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator, TimeToAdapt) { // adapt. constexpr float kDifferentSpeechRmsDbfs = -10.f; // It should at most differ by 25% after one half 'window size' interval. + // TODO(crbug.com/webrtc/7494): Add constexpr for repeated expressions. 
const float kMaxDifferenceDb = 0.25f * std::abs(kDifferentSpeechRmsDbfs - kInitialSpeechRmsDbfs); RunOnConstantLevel( @@ -178,5 +190,57 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator, kMaxDifferenceDb); } +struct TestConfig { + int min_consecutive_speech_frames; + bool use_saturation_protector; + float initial_saturation_margin_db; + float extra_saturation_margin_db; +}; + +class AdaptiveModeLevelEstimatorTest + : public ::testing::TestWithParam {}; + +TEST_P(AdaptiveModeLevelEstimatorTest, DoNotAdaptToShortSpeechSegments) { + const auto params = GetParam(); + ApmDataDumper apm_data_dumper(0); + AdaptiveModeLevelEstimator level_estimator( + &apm_data_dumper, + AudioProcessing::Config::GainController2::LevelEstimator::kRms, + params.min_consecutive_speech_frames, params.use_saturation_protector, + params.initial_saturation_margin_db, params.extra_saturation_margin_db); + const float initial_level = level_estimator.level_dbfs(); + ASSERT_LT(initial_level, kVadDataSpeech.rms_dbfs); + for (int i = 0; i < params.min_consecutive_speech_frames - 1; ++i) { + SCOPED_TRACE(i); + level_estimator.Update(kVadDataSpeech); + EXPECT_EQ(initial_level, level_estimator.level_dbfs()); + } + level_estimator.Update(kVadDataNonSpeech); + EXPECT_EQ(initial_level, level_estimator.level_dbfs()); +} + +TEST_P(AdaptiveModeLevelEstimatorTest, AdaptToEnoughSpeechSegments) { + const auto params = GetParam(); + ApmDataDumper apm_data_dumper(0); + AdaptiveModeLevelEstimator level_estimator( + &apm_data_dumper, + AudioProcessing::Config::GainController2::LevelEstimator::kRms, + params.min_consecutive_speech_frames, params.use_saturation_protector, + params.initial_saturation_margin_db, params.extra_saturation_margin_db); + const float initial_level = level_estimator.level_dbfs(); + ASSERT_LT(initial_level, kVadDataSpeech.rms_dbfs); + for (int i = 0; i < params.min_consecutive_speech_frames; ++i) { + level_estimator.Update(kVadDataSpeech); + } + EXPECT_LT(initial_level, 
level_estimator.level_dbfs()); +} + +INSTANTIATE_TEST_SUITE_P(AutomaticGainController2, + AdaptiveModeLevelEstimatorTest, + ::testing::Values(TestConfig{1, false, 0.f, 0.f}, + TestConfig{1, true, 0.f, 0.f}, + TestConfig{9, false, 0.f, 0.f}, + TestConfig{9, true, 0.f, 0.f})); + } // namespace } // namespace webrtc diff --git a/modules/audio_processing/agc2/agc2_common.h b/modules/audio_processing/agc2/agc2_common.h index c238b30881..4a99dd2f2e 100644 --- a/modules/audio_processing/agc2/agc2_common.h +++ b/modules/audio_processing/agc2/agc2_common.h @@ -51,6 +51,7 @@ constexpr float kInitialSpeechLevelEstimateDbfs = -30.f; // Robust VAD probability and speech decisions. constexpr float kDefaultSmoothedVadProbabilityAttack = 1.f; +constexpr int kDefaultAdjacentSpeechFramesThreshold = 1; // Saturation Protector settings. constexpr bool kDefaultUseSaturationProtector = true;