AGC2 AdaptiveModeLevelEstimator min consecutive speech frames (1/3)

This is the first CL needed to add a new `AdaptiveModeLevelEstimator`
feature that makes AGC2 more robus to VAD mistakes: the level estimator
discards estimation updates when too few consecutive speech frames are
observed.

In this CL, the state of the estimator is defined in a separate struct
so that in a follow-up CL a new member of that type can be added to
hold a temporary state (that can be either confirmed or discarded).

Tested: Bit-exactness verified with audioproc_f

Bug: webrtc:7494
Change-Id: Ic2ea5ed63c493b9f3a79f19e7f5eaecaa6808ace
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/184931
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Minyue Li <minyue@webrtc.org>
Reviewed-by: Ivo Creusen <ivoc@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32199}
This commit is contained in:
Alessio Bazzica
2020-09-25 14:34:26 +02:00
committed by Commit Bot
parent c1ece012cb
commit cd92b0be9a
5 changed files with 110 additions and 69 deletions

View File

@ -51,10 +51,9 @@ void AdaptiveAgc::Process(AudioFrameView<float> float_frame,
apm_data_dumper_->DumpRaw("agc2_vad_peak_dbfs",
signal_with_levels.vad_result.peak_dbfs);
speech_level_estimator_.UpdateEstimation(signal_with_levels.vad_result);
speech_level_estimator_.Update(signal_with_levels.vad_result);
signal_with_levels.input_level_dbfs =
speech_level_estimator_.LatestLevelEstimate();
signal_with_levels.input_level_dbfs = speech_level_estimator_.GetLevelDbfs();
signal_with_levels.input_noise_level_dbfs =
noise_level_estimator_.Analyze(float_frame);
@ -68,7 +67,7 @@ void AdaptiveAgc::Process(AudioFrameView<float> float_frame,
signal_with_levels.limiter_audio_level_dbfs);
signal_with_levels.estimate_is_confident =
speech_level_estimator_.LevelEstimationIsConfident();
speech_level_estimator_.IsConfident();
// The gain applier applies the gain.
gain_applier_.Process(signal_with_levels);

View File

@ -17,6 +17,11 @@
namespace webrtc {
float AdaptiveModeLevelEstimator::State::Ratio::GetRatio() const {
RTC_DCHECK_NE(denominator, 0.f);
return numerator / denominator;
}
AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
ApmDataDumper* apm_data_dumper)
: AdaptiveModeLevelEstimator(
@ -43,13 +48,16 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
bool use_saturation_protector,
float initial_saturation_margin_db,
float extra_saturation_margin_db)
: level_estimator_(level_estimator),
: apm_data_dumper_(apm_data_dumper),
saturation_protector_(apm_data_dumper, initial_saturation_margin_db),
level_estimator_type_(level_estimator),
use_saturation_protector_(use_saturation_protector),
extra_saturation_margin_db_(extra_saturation_margin_db),
saturation_protector_(apm_data_dumper, initial_saturation_margin_db),
apm_data_dumper_(apm_data_dumper) {}
last_level_dbfs_(absl::nullopt) {
Reset();
}
void AdaptiveModeLevelEstimator::UpdateEstimation(
void AdaptiveModeLevelEstimator::Update(
const VadLevelAnalyzer::Result& vad_level) {
RTC_DCHECK_GT(vad_level.rms_dbfs, -150.f);
RTC_DCHECK_LT(vad_level.rms_dbfs, 50.f);
@ -63,64 +71,80 @@ void AdaptiveModeLevelEstimator::UpdateEstimation(
return;
}
const bool buffer_is_full = buffer_size_ms_ >= kFullBufferSizeMs;
// Update the state.
RTC_DCHECK_GE(state_.time_to_full_buffer_ms, 0);
const bool buffer_is_full = state_.time_to_full_buffer_ms == 0;
if (!buffer_is_full) {
buffer_size_ms_ += kFrameDurationMs;
state_.time_to_full_buffer_ms -= kFrameDurationMs;
}
const float leak_factor = buffer_is_full ? kFullBufferLeakFactor : 1.f;
// Read speech level estimation.
float speech_level_dbfs = 0.f;
// Read level estimation.
float level_dbfs = 0.f;
using LevelEstimatorType =
AudioProcessing::Config::GainController2::LevelEstimator;
switch (level_estimator_) {
switch (level_estimator_type_) {
case LevelEstimatorType::kRms:
speech_level_dbfs = vad_level.rms_dbfs;
level_dbfs = vad_level.rms_dbfs;
break;
case LevelEstimatorType::kPeak:
speech_level_dbfs = vad_level.peak_dbfs;
level_dbfs = vad_level.peak_dbfs;
break;
}
// Update speech level estimation.
estimate_numerator_ = estimate_numerator_ * leak_factor +
speech_level_dbfs * vad_level.speech_probability;
estimate_denominator_ =
estimate_denominator_ * leak_factor + vad_level.speech_probability;
last_estimate_with_offset_dbfs_ = estimate_numerator_ / estimate_denominator_;
// Update level estimation (average level weighted by speech probability).
RTC_DCHECK_GT(vad_level.speech_probability, 0.f);
const float leak_factor = buffer_is_full ? kFullBufferLeakFactor : 1.f;
state_.level_dbfs.numerator = state_.level_dbfs.numerator * leak_factor +
level_dbfs * vad_level.speech_probability;
state_.level_dbfs.denominator = state_.level_dbfs.denominator * leak_factor +
vad_level.speech_probability;
// Cache level estimation.
last_level_dbfs_ = state_.level_dbfs.GetRatio();
// TODO(crbug.com/webrtc/7494): Update saturation protector state in `state`.
if (use_saturation_protector_) {
saturation_protector_.UpdateMargin(vad_level.peak_dbfs,
last_estimate_with_offset_dbfs_);
DebugDumpEstimate();
saturation_protector_.UpdateMargin(
/*speech_peak_dbfs=*/vad_level.peak_dbfs,
/*speech_level_dbfs=*/last_level_dbfs_.value());
}
DebugDumpEstimate();
}
float AdaptiveModeLevelEstimator::LatestLevelEstimate() const {
return rtc::SafeClamp<float>(
last_estimate_with_offset_dbfs_ +
(use_saturation_protector_ ? (saturation_protector_.margin_db() +
extra_saturation_margin_db_)
: 0.f),
-90.f, 30.f);
float AdaptiveModeLevelEstimator::GetLevelDbfs() const {
float level_dbfs = last_level_dbfs_.value_or(kInitialSpeechLevelEstimateDbfs);
if (use_saturation_protector_) {
level_dbfs += saturation_protector_.margin_db();
level_dbfs += extra_saturation_margin_db_;
}
return rtc::SafeClamp<float>(level_dbfs, -90.f, 30.f);
}
bool AdaptiveModeLevelEstimator::IsConfident() const {
// Returns true if enough speech frames have been observed.
return state_.time_to_full_buffer_ms == 0;
}
void AdaptiveModeLevelEstimator::Reset() {
buffer_size_ms_ = 0;
last_estimate_with_offset_dbfs_ = kInitialSpeechLevelEstimateDbfs;
estimate_numerator_ = 0.f;
estimate_denominator_ = 0.f;
saturation_protector_.Reset();
ResetState(state_);
last_level_dbfs_ = absl::nullopt;
}
void AdaptiveModeLevelEstimator::ResetState(State& state) {
state.time_to_full_buffer_ms = kFullBufferSizeMs;
state.level_dbfs.numerator = 0.f;
state.level_dbfs.denominator = 0.f;
// TODO(crbug.com/webrtc/7494): Reset saturation protector state in `state`.
}
void AdaptiveModeLevelEstimator::DebugDumpEstimate() {
if (apm_data_dumper_) {
apm_data_dumper_->DumpRaw("agc2_adaptive_level_estimate_with_offset_dbfs",
last_estimate_with_offset_dbfs_);
apm_data_dumper_->DumpRaw("agc2_adaptive_level_estimate_dbfs",
LatestLevelEstimate());
GetLevelDbfs());
}
saturation_protector_.DebugDumpEstimate();
}
} // namespace webrtc

View File

@ -13,6 +13,7 @@
#include <stddef.h>
#include "absl/types/optional.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/saturation_protector.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
@ -21,6 +22,7 @@
namespace webrtc {
class ApmDataDumper;
// Level estimator for the digital adaptive gain controller.
class AdaptiveModeLevelEstimator {
public:
explicit AdaptiveModeLevelEstimator(ApmDataDumper* apm_data_dumper);
@ -40,26 +42,42 @@ class AdaptiveModeLevelEstimator {
bool use_saturation_protector,
float initial_saturation_margin_db,
float extra_saturation_margin_db);
void UpdateEstimation(const VadLevelAnalyzer::Result& vad_level);
float LatestLevelEstimate() const;
// Updates the level estimation.
void Update(const VadLevelAnalyzer::Result& vad_data);
// Returns the estimated speech plus noise level.
float GetLevelDbfs() const;
// Returns true if the estimator is confident on its current estimate.
bool IsConfident() const;
void Reset();
bool LevelEstimationIsConfident() const {
return buffer_size_ms_ >= kFullBufferSizeMs;
}
private:
// Part of the level estimator state used for check-pointing and restore ops.
struct State {
struct Ratio {
float numerator;
float denominator;
float GetRatio() const;
};
int time_to_full_buffer_ms;
Ratio level_dbfs;
// TODO(crbug.com/webrtc/7494): Add saturation protector state.
};
void ResetState(State& state);
void DebugDumpEstimate();
ApmDataDumper* const apm_data_dumper_;
SaturationProtector saturation_protector_;
const AudioProcessing::Config::GainController2::LevelEstimator
level_estimator_;
level_estimator_type_;
const bool use_saturation_protector_;
const float extra_saturation_margin_db_;
size_t buffer_size_ms_ = 0;
float last_estimate_with_offset_dbfs_ = kInitialSpeechLevelEstimateDbfs;
float estimate_numerator_ = 0.f;
float estimate_denominator_ = 0.f;
SaturationProtector saturation_protector_;
ApmDataDumper* const apm_data_dumper_;
// TODO(crbug.com/webrtc/7494): Add temporary state.
State state_;
absl::optional<float> last_level_dbfs_;
};
} // namespace webrtc

View File

@ -38,7 +38,7 @@ void AdaptiveModeLevelEstimatorAgc::Process(const int16_t* audio,
if (latest_voice_probability_ > kVadConfidenceThreshold) {
time_in_ms_since_last_estimate_ += kFrameDurationMs;
}
level_estimator_.UpdateEstimation(vad_prob);
level_estimator_.Update(vad_prob);
}
// Retrieves the difference between the target RMS level and the current
@ -48,8 +48,8 @@ bool AdaptiveModeLevelEstimatorAgc::GetRmsErrorDb(int* error) {
if (time_in_ms_since_last_estimate_ <= kTimeUntilConfidentMs) {
return false;
}
*error = std::floor(target_level_dbfs() -
level_estimator_.LatestLevelEstimate() + 0.5f);
*error =
std::floor(target_level_dbfs() - level_estimator_.GetLevelDbfs() + 0.5f);
time_in_ms_since_last_estimate_ = 0;
return true;
}

View File

@ -29,7 +29,7 @@ void RunOnConstantLevel(int num_iterations,
const VadLevelAnalyzer::Result& vad_level,
AdaptiveModeLevelEstimator& level_estimator) {
for (int i = 0; i < num_iterations; ++i) {
level_estimator.UpdateEstimation(vad_level);
level_estimator.Update(vad_level);
}
}
@ -54,8 +54,8 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
VadLevelAnalyzer::Result vad_level{kMaxSpeechProbability, /*rms_dbfs=*/-20.f,
/*peak_dbfs=*/-10.f};
level_estimator.estimator->UpdateEstimation(vad_level);
static_cast<void>(level_estimator.estimator->LatestLevelEstimate());
level_estimator.estimator->Update(vad_level);
static_cast<void>(level_estimator.estimator->GetLevelDbfs());
}
TEST(AutomaticGainController2AdaptiveModeLevelEstimator, LevelShouldStabilize) {
@ -69,9 +69,9 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator, LevelShouldStabilize) {
kSpeechPeakDbfs},
*level_estimator.estimator);
EXPECT_NEAR(level_estimator.estimator->LatestLevelEstimate() -
kExtraSaturationMarginDb,
kSpeechPeakDbfs, 0.1f);
EXPECT_NEAR(
level_estimator.estimator->GetLevelDbfs() - kExtraSaturationMarginDb,
kSpeechPeakDbfs, 0.1f);
}
TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
@ -96,9 +96,9 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
*level_estimator.estimator);
// Level should not have changed.
EXPECT_NEAR(level_estimator.estimator->LatestLevelEstimate() -
kExtraSaturationMarginDb,
kSpeechRmsDbfs, 0.1f);
EXPECT_NEAR(
level_estimator.estimator->GetLevelDbfs() - kExtraSaturationMarginDb,
kSpeechRmsDbfs, 0.1f);
}
TEST(AutomaticGainController2AdaptiveModeLevelEstimator, TimeToAdapt) {
@ -128,7 +128,7 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator, TimeToAdapt) {
/*peak_dbfs=*/kDifferentSpeechRmsDbfs},
*level_estimator.estimator);
EXPECT_GT(std::abs(kDifferentSpeechRmsDbfs -
level_estimator.estimator->LatestLevelEstimate()),
level_estimator.estimator->GetLevelDbfs()),
kMaxDifferenceDb);
// Run for some more time. Afterwards, we should have adapted.
@ -139,9 +139,9 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator, TimeToAdapt) {
/*rms_dbfs=*/kDifferentSpeechRmsDbfs - kInitialSaturationMarginDb,
/*peak_dbfs=*/kDifferentSpeechRmsDbfs},
*level_estimator.estimator);
EXPECT_NEAR(level_estimator.estimator->LatestLevelEstimate() -
kExtraSaturationMarginDb,
kDifferentSpeechRmsDbfs, kMaxDifferenceDb * 0.5f);
EXPECT_NEAR(
level_estimator.estimator->GetLevelDbfs() - kExtraSaturationMarginDb,
kDifferentSpeechRmsDbfs, kMaxDifferenceDb * 0.5f);
}
TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
@ -175,7 +175,7 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
const float kMaxDifferenceDb =
0.1f * std::abs(kDifferentSpeechRmsDbfs - kInitialSpeechRmsDbfs);
EXPECT_LT(std::abs(kDifferentSpeechRmsDbfs -
(level_estimator.estimator->LatestLevelEstimate() -
(level_estimator.estimator->GetLevelDbfs() -
kExtraSaturationMarginDb)),
kMaxDifferenceDb);
}