AGC2 AdaptiveModeLevelEstimator: cache last level estimate

`AdaptiveModeLevelEstimator::last_level_dbfs_` doesn't need to be optional.

Note: this CL breaks the chain of 3 CLs titled
"AGC2 AdaptiveModeLevelEstimator min consecutive speech frames".

Bug: webrtc:7494
Change-Id: Id5b409ca5cb5f11ed132c861b7995b9721e167bb
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/185809
Reviewed-by: Minyue Li <minyue@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32237}
This commit is contained in:
Alessio Bazzica
2020-09-29 14:08:15 +02:00
committed by Commit Bot
parent c5152e893e
commit 307fab9e02
5 changed files with 61 additions and 45 deletions

View File

@ -53,7 +53,7 @@ void AdaptiveAgc::Process(AudioFrameView<float> float_frame,
speech_level_estimator_.Update(signal_with_levels.vad_result); speech_level_estimator_.Update(signal_with_levels.vad_result);
signal_with_levels.input_level_dbfs = speech_level_estimator_.GetLevelDbfs(); signal_with_levels.input_level_dbfs = speech_level_estimator_.level_dbfs();
signal_with_levels.input_noise_level_dbfs = signal_with_levels.input_noise_level_dbfs =
noise_level_estimator_.Analyze(float_frame); noise_level_estimator_.Analyze(float_frame);

View File

@ -16,6 +16,38 @@
#include "rtc_base/numerics/safe_minmax.h" #include "rtc_base/numerics/safe_minmax.h"
namespace webrtc { namespace webrtc {
namespace {
using LevelEstimatorType =
AudioProcessing::Config::GainController2::LevelEstimator;
// Returns `level_estimate_dbfs` adjusted by the saturation protector margins.
// When `use_saturation_protector` is true, the sum of `saturation_margin_db`
// and `extra_saturation_margin_db` is added to the estimate; otherwise nothing
// is added. The result is clamped to the range [-90, 30].
float ComputeLevelEstimateDbfs(float level_estimate_dbfs,
                               bool use_saturation_protector,
                               float saturation_margin_db,
                               float extra_saturation_margin_db) {
  // The two margins are summed first and only then added to the estimate.
  const float total_margin_db =
      use_saturation_protector
          ? (saturation_margin_db + extra_saturation_margin_db)
          : 0.f;
  return rtc::SafeClamp<float>(level_estimate_dbfs + total_margin_db, -90.f,
                               30.f);
}
// Returns the level of the given type from `vad_level`: `rms_dbfs` for
// `LevelEstimatorType::kRms` and `peak_dbfs` for `LevelEstimatorType::kPeak`.
float GetLevel(const VadLevelAnalyzer::Result& vad_level,
               LevelEstimatorType type) {
  // Deliberately no `default:` label, so that the compiler can still warn
  // (-Wswitch) if an enumerator is not handled here.
  switch (type) {
    case LevelEstimatorType::kRms:
      return vad_level.rms_dbfs;
    case LevelEstimatorType::kPeak:
      return vad_level.peak_dbfs;
  }
  // Only reached if `type` holds a value not handled above. Returning a
  // defined value avoids flowing off the end of a non-void function, which is
  // undefined behaviour.
  // NOTE(review): the RMS level is an arbitrary fallback - confirm, or replace
  // with the project's not-reached check if one is preferred here.
  return vad_level.rms_dbfs;
}
} // namespace
float AdaptiveModeLevelEstimator::State::Ratio::GetRatio() const { float AdaptiveModeLevelEstimator::State::Ratio::GetRatio() const {
RTC_DCHECK_NE(denominator, 0.f); RTC_DCHECK_NE(denominator, 0.f);
@ -53,7 +85,10 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
use_saturation_protector_(use_saturation_protector), use_saturation_protector_(use_saturation_protector),
initial_saturation_margin_db_(initial_saturation_margin_db), initial_saturation_margin_db_(initial_saturation_margin_db),
extra_saturation_margin_db_(extra_saturation_margin_db), extra_saturation_margin_db_(extra_saturation_margin_db),
last_level_dbfs_(absl::nullopt) { level_dbfs_(ComputeLevelEstimateDbfs(kInitialSpeechLevelEstimateDbfs,
use_saturation_protector_,
initial_saturation_margin_db_,
extra_saturation_margin_db_)) {
Reset(); Reset();
} }
@ -78,49 +113,30 @@ void AdaptiveModeLevelEstimator::Update(
state_.time_to_full_buffer_ms -= kFrameDurationMs; state_.time_to_full_buffer_ms -= kFrameDurationMs;
} }
// Read level estimation.
float level_dbfs = 0.f;
using LevelEstimatorType =
AudioProcessing::Config::GainController2::LevelEstimator;
switch (level_estimator_type_) {
case LevelEstimatorType::kRms:
level_dbfs = vad_level.rms_dbfs;
break;
case LevelEstimatorType::kPeak:
level_dbfs = vad_level.peak_dbfs;
break;
}
// Update level estimation (average level weighted by speech probability). // Update level estimation (average level weighted by speech probability).
RTC_DCHECK_GT(vad_level.speech_probability, 0.f); RTC_DCHECK_GT(vad_level.speech_probability, 0.f);
const float leak_factor = buffer_is_full ? kFullBufferLeakFactor : 1.f; const float leak_factor = buffer_is_full ? kFullBufferLeakFactor : 1.f;
state_.level_dbfs.numerator = state_.level_dbfs.numerator * leak_factor + state_.level_dbfs.numerator =
level_dbfs * vad_level.speech_probability; state_.level_dbfs.numerator * leak_factor +
GetLevel(vad_level, level_estimator_type_) * vad_level.speech_probability;
state_.level_dbfs.denominator = state_.level_dbfs.denominator * leak_factor + state_.level_dbfs.denominator = state_.level_dbfs.denominator * leak_factor +
vad_level.speech_probability; vad_level.speech_probability;
// Cache level estimation. const float level_dbfs = state_.level_dbfs.GetRatio();
last_level_dbfs_ = state_.level_dbfs.GetRatio();
if (use_saturation_protector_) { if (use_saturation_protector_) {
UpdateSaturationProtectorState( UpdateSaturationProtectorState(vad_level.peak_dbfs, level_dbfs,
/*speech_peak_dbfs=*/vad_level.peak_dbfs, state_.saturation_protector);
/*speech_level_dbfs=*/last_level_dbfs_.value(),
state_.saturation_protector);
} }
// Cache level estimation.
level_dbfs_ = ComputeLevelEstimateDbfs(level_dbfs, use_saturation_protector_,
state_.saturation_protector.margin_db,
extra_saturation_margin_db_);
DebugDumpEstimate(); DebugDumpEstimate();
} }
float AdaptiveModeLevelEstimator::GetLevelDbfs() const {
float level_dbfs = last_level_dbfs_.value_or(kInitialSpeechLevelEstimateDbfs);
if (use_saturation_protector_) {
level_dbfs += state_.saturation_protector.margin_db;
level_dbfs += extra_saturation_margin_db_;
}
return rtc::SafeClamp<float>(level_dbfs, -90.f, 30.f);
}
bool AdaptiveModeLevelEstimator::IsConfident() const { bool AdaptiveModeLevelEstimator::IsConfident() const {
// Returns true if enough speech frames have been observed. // Returns true if enough speech frames have been observed.
return state_.time_to_full_buffer_ms == 0; return state_.time_to_full_buffer_ms == 0;
@ -128,7 +144,9 @@ bool AdaptiveModeLevelEstimator::IsConfident() const {
void AdaptiveModeLevelEstimator::Reset() { void AdaptiveModeLevelEstimator::Reset() {
ResetState(state_); ResetState(state_);
last_level_dbfs_ = absl::nullopt; level_dbfs_ = ComputeLevelEstimateDbfs(
kInitialSpeechLevelEstimateDbfs, use_saturation_protector_,
initial_saturation_margin_db_, extra_saturation_margin_db_);
} }
void AdaptiveModeLevelEstimator::ResetState(State& state) { void AdaptiveModeLevelEstimator::ResetState(State& state) {
@ -141,8 +159,7 @@ void AdaptiveModeLevelEstimator::ResetState(State& state) {
void AdaptiveModeLevelEstimator::DebugDumpEstimate() { void AdaptiveModeLevelEstimator::DebugDumpEstimate() {
if (apm_data_dumper_) { if (apm_data_dumper_) {
apm_data_dumper_->DumpRaw("agc2_adaptive_level_estimate_dbfs", apm_data_dumper_->DumpRaw("agc2_adaptive_level_estimate_dbfs", level_dbfs_);
GetLevelDbfs());
apm_data_dumper_->DumpRaw("agc2_adaptive_saturation_margin_db", apm_data_dumper_->DumpRaw("agc2_adaptive_saturation_margin_db",
state_.saturation_protector.margin_db); state_.saturation_protector.margin_db);
} }

View File

@ -13,7 +13,6 @@
#include <stddef.h> #include <stddef.h>
#include "absl/types/optional.h"
#include "modules/audio_processing/agc2/agc2_common.h" #include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/saturation_protector.h" #include "modules/audio_processing/agc2/saturation_protector.h"
#include "modules/audio_processing/agc2/vad_with_level.h" #include "modules/audio_processing/agc2/vad_with_level.h"
@ -46,7 +45,7 @@ class AdaptiveModeLevelEstimator {
// Updates the level estimation. // Updates the level estimation.
void Update(const VadLevelAnalyzer::Result& vad_data); void Update(const VadLevelAnalyzer::Result& vad_data);
// Returns the estimated speech plus noise level. // Returns the estimated speech plus noise level.
float GetLevelDbfs() const; float level_dbfs() const { return level_dbfs_; }
// Returns true if the estimator is confident on its current estimate. // Returns true if the estimator is confident on its current estimate.
bool IsConfident() const; bool IsConfident() const;
@ -77,7 +76,7 @@ class AdaptiveModeLevelEstimator {
const float extra_saturation_margin_db_; const float extra_saturation_margin_db_;
// TODO(crbug.com/webrtc/7494): Add temporary state. // TODO(crbug.com/webrtc/7494): Add temporary state.
State state_; State state_;
absl::optional<float> last_level_dbfs_; float level_dbfs_;
}; };
} // namespace webrtc } // namespace webrtc

View File

@ -49,7 +49,7 @@ bool AdaptiveModeLevelEstimatorAgc::GetRmsErrorDb(int* error) {
return false; return false;
} }
*error = *error =
std::floor(target_level_dbfs() - level_estimator_.GetLevelDbfs() + 0.5f); std::floor(target_level_dbfs() - level_estimator_.level_dbfs() + 0.5f);
time_in_ms_since_last_estimate_ = 0; time_in_ms_since_last_estimate_ = 0;
return true; return true;
} }

View File

@ -53,7 +53,7 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
VadLevelAnalyzer::Result vad_level{kMaxSpeechProbability, /*rms_dbfs=*/-20.f, VadLevelAnalyzer::Result vad_level{kMaxSpeechProbability, /*rms_dbfs=*/-20.f,
/*peak_dbfs=*/-10.f}; /*peak_dbfs=*/-10.f};
level_estimator.estimator->Update(vad_level); level_estimator.estimator->Update(vad_level);
static_cast<void>(level_estimator.estimator->GetLevelDbfs()); static_cast<void>(level_estimator.estimator->level_dbfs());
} }
TEST(AutomaticGainController2AdaptiveModeLevelEstimator, LevelShouldStabilize) { TEST(AutomaticGainController2AdaptiveModeLevelEstimator, LevelShouldStabilize) {
@ -68,7 +68,7 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator, LevelShouldStabilize) {
*level_estimator.estimator); *level_estimator.estimator);
EXPECT_NEAR( EXPECT_NEAR(
level_estimator.estimator->GetLevelDbfs() - kExtraSaturationMarginDb, level_estimator.estimator->level_dbfs() - kExtraSaturationMarginDb,
kSpeechPeakDbfs, 0.1f); kSpeechPeakDbfs, 0.1f);
} }
@ -95,7 +95,7 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
// Level should not have changed. // Level should not have changed.
EXPECT_NEAR( EXPECT_NEAR(
level_estimator.estimator->GetLevelDbfs() - kExtraSaturationMarginDb, level_estimator.estimator->level_dbfs() - kExtraSaturationMarginDb,
kSpeechRmsDbfs, 0.1f); kSpeechRmsDbfs, 0.1f);
} }
@ -126,7 +126,7 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator, TimeToAdapt) {
/*peak_dbfs=*/kDifferentSpeechRmsDbfs}, /*peak_dbfs=*/kDifferentSpeechRmsDbfs},
*level_estimator.estimator); *level_estimator.estimator);
EXPECT_GT(std::abs(kDifferentSpeechRmsDbfs - EXPECT_GT(std::abs(kDifferentSpeechRmsDbfs -
level_estimator.estimator->GetLevelDbfs()), level_estimator.estimator->level_dbfs()),
kMaxDifferenceDb); kMaxDifferenceDb);
// Run for some more time. Afterwards, we should have adapted. // Run for some more time. Afterwards, we should have adapted.
@ -138,7 +138,7 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator, TimeToAdapt) {
/*peak_dbfs=*/kDifferentSpeechRmsDbfs}, /*peak_dbfs=*/kDifferentSpeechRmsDbfs},
*level_estimator.estimator); *level_estimator.estimator);
EXPECT_NEAR( EXPECT_NEAR(
level_estimator.estimator->GetLevelDbfs() - kExtraSaturationMarginDb, level_estimator.estimator->level_dbfs() - kExtraSaturationMarginDb,
kDifferentSpeechRmsDbfs, kMaxDifferenceDb * 0.5f); kDifferentSpeechRmsDbfs, kMaxDifferenceDb * 0.5f);
} }
@ -173,7 +173,7 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
const float kMaxDifferenceDb = const float kMaxDifferenceDb =
0.1f * std::abs(kDifferentSpeechRmsDbfs - kInitialSpeechRmsDbfs); 0.1f * std::abs(kDifferentSpeechRmsDbfs - kInitialSpeechRmsDbfs);
EXPECT_LT(std::abs(kDifferentSpeechRmsDbfs - EXPECT_LT(std::abs(kDifferentSpeechRmsDbfs -
(level_estimator.estimator->GetLevelDbfs() - (level_estimator.estimator->level_dbfs() -
kExtraSaturationMarginDb)), kExtraSaturationMarginDb)),
kMaxDifferenceDb); kMaxDifferenceDb);
} }