AGC2 VadWithLevel(::LevelAndProbability) renamed + injectable VAD

Refactoring CL to improve names and allow to inject a VAD into
`VadLevelAnalyzer` (new name for `VadWithLevel`).

The injectable VAD is needed to inject a mock VAD and write better
unit tests as new features are going to be added to the class.

Bug: webrtc:7494
Change-Id: Ic0cea1e86a19a82533bd40fa04c061be3c44f068
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/185180
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Minyue Li <minyue@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32195}
This commit is contained in:
Alessio Bazzica
2020-09-25 13:24:36 +02:00
committed by Commit Bot
parent 77baeee99e
commit 530781d03f
11 changed files with 158 additions and 123 deletions

View File

@ -47,9 +47,9 @@ void AdaptiveAgc::Process(AudioFrameView<float> float_frame,
apm_data_dumper_->DumpRaw("agc2_vad_probability",
signal_with_levels.vad_result.speech_probability);
apm_data_dumper_->DumpRaw("agc2_vad_rms_dbfs",
signal_with_levels.vad_result.speech_rms_dbfs);
signal_with_levels.vad_result.rms_dbfs);
apm_data_dumper_->DumpRaw("agc2_vad_peak_dbfs",
signal_with_levels.vad_result.speech_peak_dbfs);
signal_with_levels.vad_result.peak_dbfs);
speech_level_estimator_.UpdateEstimation(signal_with_levels.vad_result);

View File

@ -33,7 +33,7 @@ class AdaptiveAgc {
private:
AdaptiveModeLevelEstimator speech_level_estimator_;
VadWithLevel vad_;
VadLevelAnalyzer vad_;
AdaptiveDigitalGainApplier gain_applier_;
ApmDataDumper* const apm_data_dumper_;
NoiseLevelEstimator noise_level_estimator_;

View File

@ -26,7 +26,7 @@ struct SignalWithLevels {
float input_level_dbfs = -1.f;
float input_noise_level_dbfs = -1.f;
VadWithLevel::LevelAndProbability vad_result;
VadLevelAnalyzer::Result vad_result;
float limiter_audio_level_dbfs = -1.f;
bool estimate_is_confident = false;
AudioFrameView<float> float_frame;

View File

@ -23,11 +23,13 @@ namespace {
// Constants used in place of estimated noise levels.
constexpr float kNoNoiseDbfs = -90.f;
constexpr float kWithNoiseDbfs = -20.f;
constexpr VadWithLevel::LevelAndProbability kVadSpeech(1.f, -20.f, 0.f);
static_assert(std::is_trivially_destructible<VadLevelAnalyzer::Result>::value,
"");
constexpr VadLevelAnalyzer::Result kVadSpeech{1.f, -20.f, 0.f};
// Runs gain applier and returns the applied gain in linear scale.
float RunOnConstantLevel(int num_iterations,
VadWithLevel::LevelAndProbability vad_data,
VadLevelAnalyzer::Result vad_level,
float input_level_dbfs,
AdaptiveDigitalGainApplier* gain_applier) {
float gain_linear = 0.f;
@ -37,7 +39,7 @@ float RunOnConstantLevel(int num_iterations,
SignalWithLevels signal_with_levels(fake_audio.float_frame_view());
signal_with_levels.input_level_dbfs = input_level_dbfs;
signal_with_levels.input_noise_level_dbfs = kNoNoiseDbfs;
signal_with_levels.vad_result = vad_data;
signal_with_levels.vad_result = vad_level;
signal_with_levels.limiter_audio_level_dbfs = -2.f;
signal_with_levels.estimate_is_confident = true;
gain_applier->Process(signal_with_levels);
@ -61,9 +63,6 @@ SignalWithLevels TestSignalWithLevel(AudioFrameView<float> float_frame) {
} // namespace
TEST(AutomaticGainController2AdaptiveGainApplier, GainApplierShouldNotCrash) {
static_assert(
std::is_trivially_destructible<VadWithLevel::LevelAndProbability>::value,
"");
ApmDataDumper apm_data_dumper(0);
AdaptiveDigitalGainApplier gain_applier(&apm_data_dumper);

View File

@ -50,15 +50,15 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
apm_data_dumper_(apm_data_dumper) {}
void AdaptiveModeLevelEstimator::UpdateEstimation(
const VadWithLevel::LevelAndProbability& vad_data) {
RTC_DCHECK_GT(vad_data.speech_rms_dbfs, -150.f);
RTC_DCHECK_LT(vad_data.speech_rms_dbfs, 50.f);
RTC_DCHECK_GT(vad_data.speech_peak_dbfs, -150.f);
RTC_DCHECK_LT(vad_data.speech_peak_dbfs, 50.f);
RTC_DCHECK_GE(vad_data.speech_probability, 0.f);
RTC_DCHECK_LE(vad_data.speech_probability, 1.f);
const VadLevelAnalyzer::Result& vad_level) {
RTC_DCHECK_GT(vad_level.rms_dbfs, -150.f);
RTC_DCHECK_LT(vad_level.rms_dbfs, 50.f);
RTC_DCHECK_GT(vad_level.peak_dbfs, -150.f);
RTC_DCHECK_LT(vad_level.peak_dbfs, 50.f);
RTC_DCHECK_GE(vad_level.speech_probability, 0.f);
RTC_DCHECK_LE(vad_level.speech_probability, 1.f);
if (vad_data.speech_probability < kVadConfidenceThreshold) {
if (vad_level.speech_probability < kVadConfidenceThreshold) {
DebugDumpEstimate();
return;
}
@ -76,22 +76,22 @@ void AdaptiveModeLevelEstimator::UpdateEstimation(
AudioProcessing::Config::GainController2::LevelEstimator;
switch (level_estimator_) {
case LevelEstimatorType::kRms:
speech_level_dbfs = vad_data.speech_rms_dbfs;
speech_level_dbfs = vad_level.rms_dbfs;
break;
case LevelEstimatorType::kPeak:
speech_level_dbfs = vad_data.speech_peak_dbfs;
speech_level_dbfs = vad_level.peak_dbfs;
break;
}
// Update speech level estimation.
estimate_numerator_ = estimate_numerator_ * leak_factor +
speech_level_dbfs * vad_data.speech_probability;
speech_level_dbfs * vad_level.speech_probability;
estimate_denominator_ =
estimate_denominator_ * leak_factor + vad_data.speech_probability;
estimate_denominator_ * leak_factor + vad_level.speech_probability;
last_estimate_with_offset_dbfs_ = estimate_numerator_ / estimate_denominator_;
if (use_saturation_protector_) {
saturation_protector_.UpdateMargin(vad_data.speech_peak_dbfs,
saturation_protector_.UpdateMargin(vad_level.peak_dbfs,
last_estimate_with_offset_dbfs_);
DebugDumpEstimate();
}

View File

@ -40,7 +40,7 @@ class AdaptiveModeLevelEstimator {
bool use_saturation_protector,
float initial_saturation_margin_db,
float extra_saturation_margin_db);
void UpdateEstimation(const VadWithLevel::LevelAndProbability& vad_data);
void UpdateEstimation(const VadLevelAnalyzer::Result& vad_level);
float LatestLevelEstimate() const;
void Reset();
bool LevelEstimationIsConfident() const {

View File

@ -43,7 +43,7 @@ class AdaptiveModeLevelEstimatorAgc : public Agc {
static constexpr int kDefaultAgc2LevelHeadroomDbfs = -1;
int32_t time_in_ms_since_last_estimate_ = 0;
AdaptiveModeLevelEstimator level_estimator_;
VadWithLevel agc2_vad_;
VadLevelAnalyzer agc2_vad_;
float latest_voice_probability_ = 0.f;
};
} // namespace webrtc

View File

@ -22,11 +22,14 @@ namespace {
constexpr float kInitialSaturationMarginDb = 20.f;
constexpr float kExtraSaturationMarginDb = 2.f;
constexpr float kMinSpeechProbability = 0.f;
constexpr float kMaxSpeechProbability = 1.f;
void RunOnConstantLevel(int num_iterations,
VadWithLevel::LevelAndProbability vad_data,
const VadLevelAnalyzer::Result& vad_level,
AdaptiveModeLevelEstimator& level_estimator) {
for (int i = 0; i < num_iterations; ++i) {
level_estimator.UpdateEstimation(vad_data); // By copy
level_estimator.UpdateEstimation(vad_level);
}
}
@ -49,8 +52,9 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
EstimatorShouldNotCrash) {
TestLevelEstimator level_estimator;
VadWithLevel::LevelAndProbability vad_data(1.f, -20.f, -10.f);
level_estimator.estimator->UpdateEstimation(vad_data);
VadLevelAnalyzer::Result vad_level{kMaxSpeechProbability, /*rms_dbfs=*/-20.f,
/*peak_dbfs=*/-10.f};
level_estimator.estimator->UpdateEstimation(vad_level);
static_cast<void>(level_estimator.estimator->LatestLevelEstimate());
}
@ -58,11 +62,12 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator, LevelShouldStabilize) {
TestLevelEstimator level_estimator;
constexpr float kSpeechPeakDbfs = -15.f;
RunOnConstantLevel(
100,
VadWithLevel::LevelAndProbability(
1.f, kSpeechPeakDbfs - kInitialSaturationMarginDb, kSpeechPeakDbfs),
*level_estimator.estimator);
RunOnConstantLevel(100,
VadLevelAnalyzer::Result{kMaxSpeechProbability,
/*rms_dbfs=*/kSpeechPeakDbfs -
kInitialSaturationMarginDb,
kSpeechPeakDbfs},
*level_estimator.estimator);
EXPECT_NEAR(level_estimator.estimator->LatestLevelEstimate() -
kExtraSaturationMarginDb,
@ -75,17 +80,20 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
// Run for one second of fake audio.
constexpr float kSpeechRmsDbfs = -25.f;
RunOnConstantLevel(
100,
VadWithLevel::LevelAndProbability(
1.f, kSpeechRmsDbfs - kInitialSaturationMarginDb, kSpeechRmsDbfs),
*level_estimator.estimator);
RunOnConstantLevel(100,
VadLevelAnalyzer::Result{kMaxSpeechProbability,
/*rms_dbfs=*/kSpeechRmsDbfs -
kInitialSaturationMarginDb,
/*peak_dbfs=*/kSpeechRmsDbfs},
*level_estimator.estimator);
// Run for one more second, but mark as not speech.
constexpr float kNoiseRmsDbfs = 0.f;
RunOnConstantLevel(
100, VadWithLevel::LevelAndProbability(0.f, kNoiseRmsDbfs, kNoiseRmsDbfs),
*level_estimator.estimator);
RunOnConstantLevel(100,
VadLevelAnalyzer::Result{kMinSpeechProbability,
/*rms_dbfs=*/kNoiseRmsDbfs,
/*peak_dbfs=*/kNoiseRmsDbfs},
*level_estimator.estimator);
// Level should not have changed.
EXPECT_NEAR(level_estimator.estimator->LatestLevelEstimate() -
@ -100,9 +108,10 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator, TimeToAdapt) {
constexpr float kInitialSpeechRmsDbfs = -30.f;
RunOnConstantLevel(
kFullBufferSizeMs / kFrameDurationMs,
VadWithLevel::LevelAndProbability(
1.f, kInitialSpeechRmsDbfs - kInitialSaturationMarginDb,
kInitialSpeechRmsDbfs),
VadLevelAnalyzer::Result{
kMaxSpeechProbability,
/*rms_dbfs=*/kInitialSpeechRmsDbfs - kInitialSaturationMarginDb,
/*peak_dbfs=*/kInitialSpeechRmsDbfs},
*level_estimator.estimator);
// Run for one half 'window size' interval. This should not be enough to
@ -110,12 +119,13 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator, TimeToAdapt) {
constexpr float kDifferentSpeechRmsDbfs = -10.f;
// It should at most differ by 25% after one half 'window size' interval.
const float kMaxDifferenceDb =
0.25 * std::abs(kDifferentSpeechRmsDbfs - kInitialSpeechRmsDbfs);
0.25f * std::abs(kDifferentSpeechRmsDbfs - kInitialSpeechRmsDbfs);
RunOnConstantLevel(
static_cast<int>(kFullBufferSizeMs / kFrameDurationMs / 2),
VadWithLevel::LevelAndProbability(
1.f, kDifferentSpeechRmsDbfs - kInitialSaturationMarginDb,
kDifferentSpeechRmsDbfs),
VadLevelAnalyzer::Result{
kMaxSpeechProbability,
/*rms_dbfs=*/kDifferentSpeechRmsDbfs - kInitialSaturationMarginDb,
/*peak_dbfs=*/kDifferentSpeechRmsDbfs},
*level_estimator.estimator);
EXPECT_GT(std::abs(kDifferentSpeechRmsDbfs -
level_estimator.estimator->LatestLevelEstimate()),
@ -124,9 +134,10 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator, TimeToAdapt) {
// Run for some more time. Afterwards, we should have adapted.
RunOnConstantLevel(
static_cast<int>(3 * kFullBufferSizeMs / kFrameDurationMs),
VadWithLevel::LevelAndProbability(
1.f, kDifferentSpeechRmsDbfs - kInitialSaturationMarginDb,
kDifferentSpeechRmsDbfs),
VadLevelAnalyzer::Result{
kMaxSpeechProbability,
/*rms_dbfs=*/kDifferentSpeechRmsDbfs - kInitialSaturationMarginDb,
/*peak_dbfs=*/kDifferentSpeechRmsDbfs},
*level_estimator.estimator);
EXPECT_NEAR(level_estimator.estimator->LatestLevelEstimate() -
kExtraSaturationMarginDb,
@ -142,9 +153,10 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
constexpr float kInitialSpeechRmsDbfs = -30.f;
RunOnConstantLevel(
kFullBufferSizeMs / kFrameDurationMs,
VadWithLevel::LevelAndProbability(
1.f, kInitialSpeechRmsDbfs - kInitialSaturationMarginDb,
kInitialSpeechRmsDbfs),
VadLevelAnalyzer::Result{
kMaxSpeechProbability,
/*rms_dbfs=*/kInitialSpeechRmsDbfs - kInitialSaturationMarginDb,
/*peak_dbfs=*/kInitialSpeechRmsDbfs},
*level_estimator.estimator);
constexpr float kDifferentSpeechRmsDbfs = -10.f;
@ -153,9 +165,10 @@ TEST(AutomaticGainController2AdaptiveModeLevelEstimator,
RunOnConstantLevel(
kFullBufferSizeMs / kFrameDurationMs / 2,
VadWithLevel::LevelAndProbability(
1.f, kDifferentSpeechRmsDbfs - kInitialSaturationMarginDb,
kDifferentSpeechRmsDbfs),
VadLevelAnalyzer::Result{
kMaxSpeechProbability,
/*rms_dbfs=*/kDifferentSpeechRmsDbfs - kInitialSaturationMarginDb,
/*peak_dbfs=*/kDifferentSpeechRmsDbfs},
*level_estimator.estimator);
// The level should be close to 'kDifferentSpeechRmsDbfs'.

View File

@ -16,55 +16,73 @@
#include "api/array_view.h"
#include "common_audio/include/audio_util.h"
#include "common_audio/resampler/include/push_resampler.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
float ProcessForPeak(AudioFrameView<const float> frame) {
float current_max = 0;
for (const auto& x : frame.channel(0)) {
current_max = std::max(std::fabs(x), current_max);
}
return current_max;
}
float ProcessForRms(AudioFrameView<const float> frame) {
float rms = 0;
for (const auto& x : frame.channel(0)) {
rms += x * x;
using VoiceActivityDetector = VadLevelAnalyzer::VoiceActivityDetector;
// Default VAD that combines a resampler and the RNN VAD.
// Computes the speech probability on the first channel.
class Vad : public VoiceActivityDetector {
public:
Vad() = default;
Vad(const Vad&) = delete;
Vad& operator=(const Vad&) = delete;
~Vad() = default;
float ComputeProbability(AudioFrameView<const float> frame) override {
// The source number of channels is 1, because we always use the 1st
// channel.
resampler_.InitializeIfNeeded(
/*sample_rate_hz=*/static_cast<int>(frame.samples_per_channel() * 100),
rnn_vad::kSampleRate24kHz,
/*num_channels=*/1);
std::array<float, rnn_vad::kFrameSize10ms24kHz> work_frame;
// Feed the 1st channel to the resampler.
resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
work_frame.data(), rnn_vad::kFrameSize10ms24kHz);
std::array<float, rnn_vad::kFeatureVectorSize> feature_vector;
const bool is_silence = features_extractor_.CheckSilenceComputeFeatures(
work_frame, feature_vector);
return rnn_vad_.ComputeVadProbability(feature_vector, is_silence);
}
return std::sqrt(rms / frame.samples_per_channel());
}
private:
PushResampler<float> resampler_;
rnn_vad::FeaturesExtractor features_extractor_;
rnn_vad::RnnBasedVad rnn_vad_;
};
} // namespace
VadWithLevel::VadWithLevel() = default;
VadWithLevel::~VadWithLevel() = default;
VadLevelAnalyzer::VadLevelAnalyzer() : vad_(std::make_unique<Vad>()) {}
VadWithLevel::LevelAndProbability VadWithLevel::AnalyzeFrame(
AudioFrameView<const float> frame) {
SetSampleRate(static_cast<int>(frame.samples_per_channel() * 100));
std::array<float, rnn_vad::kFrameSize10ms24kHz> work_frame;
// Feed the 1st channel to the resampler.
resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
work_frame.data(), rnn_vad::kFrameSize10ms24kHz);
std::array<float, rnn_vad::kFeatureVectorSize> feature_vector;
const bool is_silence = features_extractor_.CheckSilenceComputeFeatures(
work_frame, feature_vector);
const float vad_probability =
rnn_vad_.ComputeVadProbability(feature_vector, is_silence);
return LevelAndProbability(vad_probability,
FloatS16ToDbfs(ProcessForRms(frame)),
FloatS16ToDbfs(ProcessForPeak(frame)));
VadLevelAnalyzer::VadLevelAnalyzer(std::unique_ptr<VoiceActivityDetector> vad)
: vad_(std::move(vad)) {
RTC_DCHECK(vad_);
}
void VadWithLevel::SetSampleRate(int sample_rate_hz) {
// The source number of channels in 1, because we always use the 1st
// channel.
resampler_.InitializeIfNeeded(sample_rate_hz, rnn_vad::kSampleRate24kHz,
1 /* num_channels */);
VadLevelAnalyzer::~VadLevelAnalyzer() = default;
VadLevelAnalyzer::Result VadLevelAnalyzer::AnalyzeFrame(
AudioFrameView<const float> frame) {
float peak = 0.f;
float rms = 0.f;
for (const auto& x : frame.channel(0)) {
peak = std::max(std::fabs(x), peak);
rms += x * x;
}
return {vad_->ComputeProbability(frame),
FloatS16ToDbfs(std::sqrt(rms / frame.samples_per_channel())),
FloatS16ToDbfs(peak)};
}
} // namespace webrtc

View File

@ -11,36 +11,42 @@
#ifndef MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_
#define MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_
#include "common_audio/resampler/include/push_resampler.h"
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include <memory>
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
class VadWithLevel {
// Class to analyze voice activity and audio levels.
class VadLevelAnalyzer {
public:
struct LevelAndProbability {
constexpr LevelAndProbability(float prob, float rms, float peak)
: speech_probability(prob),
speech_rms_dbfs(rms),
speech_peak_dbfs(peak) {}
LevelAndProbability() = default;
float speech_probability = 0;
float speech_rms_dbfs = 0; // Root mean square in decibels to full-scale.
float speech_peak_dbfs = 0;
struct Result {
float speech_probability; // Range: [0, 1].
float rms_dbfs; // Root mean square power (dBFS).
float peak_dbfs; // Peak power (dBFS).
};
VadWithLevel();
~VadWithLevel();
// Voice Activity Detector (VAD) interface.
class VoiceActivityDetector {
public:
virtual ~VoiceActivityDetector() = default;
// Analyzes an audio frame and returns the speech probability.
virtual float ComputeProbability(AudioFrameView<const float> frame) = 0;
};
LevelAndProbability AnalyzeFrame(AudioFrameView<const float> frame);
// Ctor. Uses the default VAD.
VadLevelAnalyzer();
// Ctor. Uses a custom `vad`.
explicit VadLevelAnalyzer(std::unique_ptr<VoiceActivityDetector> vad);
VadLevelAnalyzer(const VadLevelAnalyzer&) = delete;
VadLevelAnalyzer& operator=(const VadLevelAnalyzer&) = delete;
~VadLevelAnalyzer();
// Computes the speech probability and the level for `frame`.
Result AnalyzeFrame(AudioFrameView<const float> frame);
private:
void SetSampleRate(int sample_rate_hz);
rnn_vad::RnnBasedVad rnn_vad_;
rnn_vad::FeaturesExtractor features_extractor_;
PushResampler<float> resampler_;
std::unique_ptr<VoiceActivityDetector> vad_;
};
} // namespace webrtc

View File

@ -13,7 +13,7 @@
#include "rtc_base/gunit.h"
namespace webrtc {
namespace test {
namespace {
TEST(AutomaticGainController2VadWithLevelEstimator,
PeakLevelGreaterThanRmsLevel) {
@ -28,13 +28,12 @@ TEST(AutomaticGainController2VadWithLevelEstimator,
AudioFrameView<float> frame_view(&channel0, 1, frame.size());
// Compute audio frame levels (the VAD result is ignored).
VadWithLevel vad_with_level;
auto levels_and_vad_prob = vad_with_level.AnalyzeFrame(frame_view);
VadLevelAnalyzer analyzer;
auto levels_and_vad_prob = analyzer.AnalyzeFrame(frame_view);
// Compare peak and RMS levels.
EXPECT_LT(levels_and_vad_prob.speech_rms_dbfs,
levels_and_vad_prob.speech_peak_dbfs);
EXPECT_LT(levels_and_vad_prob.rms_dbfs, levels_and_vad_prob.peak_dbfs);
}
} // namespace test
} // namespace
} // namespace webrtc