AGC2 VAD probability: instant decay / slow attack
Feature added to gain robustness to occasional VAD speech probability spikes. In such a case, the attack process reduces the chance that the smoothed values are greater than the speech threshold. Bug: webrtc:7494 Change-Id: I6babe5afe30ea3dea021181a19d86bb74b33a98c Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/185046 Commit-Queue: Alessio Bazzica <alessiob@webrtc.org> Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org> Cr-Commit-Position: refs/heads/master@{#32198}
This commit is contained in:

committed by
Commit Bot

parent
f2969fa868
commit
c1ece012cb
@ -95,6 +95,7 @@ rtc_library("common") {
|
|||||||
"../../../rtc_base:rtc_base_approved",
|
"../../../rtc_base:rtc_base_approved",
|
||||||
"../../../system_wrappers:field_trial",
|
"../../../system_wrappers:field_trial",
|
||||||
]
|
]
|
||||||
|
absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ]
|
||||||
}
|
}
|
||||||
|
|
||||||
rtc_library("fixed_digital") {
|
rtc_library("fixed_digital") {
|
||||||
@ -168,6 +169,7 @@ rtc_library("rnn_vad_with_level") {
|
|||||||
"vad_with_level.h",
|
"vad_with_level.h",
|
||||||
]
|
]
|
||||||
deps = [
|
deps = [
|
||||||
|
":common",
|
||||||
"..:audio_frame_view",
|
"..:audio_frame_view",
|
||||||
"../../../api:array_view",
|
"../../../api:array_view",
|
||||||
"../../../common_audio",
|
"../../../common_audio",
|
||||||
@ -265,9 +267,12 @@ rtc_library("rnn_vad_with_level_unittests") {
|
|||||||
testonly = true
|
testonly = true
|
||||||
sources = [ "vad_with_level_unittest.cc" ]
|
sources = [ "vad_with_level_unittest.cc" ]
|
||||||
deps = [
|
deps = [
|
||||||
|
":common",
|
||||||
":rnn_vad_with_level",
|
":rnn_vad_with_level",
|
||||||
"..:audio_frame_view",
|
"..:audio_frame_view",
|
||||||
"../../../rtc_base:gunit_helpers",
|
"../../../rtc_base:gunit_helpers",
|
||||||
|
"../../../rtc_base:safe_compare",
|
||||||
|
"../../../test:test_support",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,6 +49,9 @@ constexpr float kFullBufferLeakFactor = 1.f - 1.f / kFullBufferSizeMs;
|
|||||||
|
|
||||||
constexpr float kInitialSpeechLevelEstimateDbfs = -30.f;
|
constexpr float kInitialSpeechLevelEstimateDbfs = -30.f;
|
||||||
|
|
||||||
|
// Robust VAD probability and speech decisions.
|
||||||
|
constexpr float kDefaultSmoothedVadProbabilityAttack = 1.f;
|
||||||
|
|
||||||
// Saturation Protector settings.
|
// Saturation Protector settings.
|
||||||
float GetInitialSaturationMarginDb();
|
float GetInitialSaturationMarginDb();
|
||||||
float GetExtraSaturationMarginOffsetDb();
|
float GetExtraSaturationMarginOffsetDb();
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
#include "api/array_view.h"
|
#include "api/array_view.h"
|
||||||
#include "common_audio/include/audio_util.h"
|
#include "common_audio/include/audio_util.h"
|
||||||
#include "common_audio/resampler/include/push_resampler.h"
|
#include "common_audio/resampler/include/push_resampler.h"
|
||||||
|
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
|
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
|
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
|
||||||
@ -61,12 +62,32 @@ class Vad : public VoiceActivityDetector {
|
|||||||
rnn_vad::RnnBasedVad rnn_vad_;
|
rnn_vad::RnnBasedVad rnn_vad_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Returns an updated version of `p_old` by using instant decay and the given
|
||||||
|
// `attack` on a new VAD probability value `p_new`.
|
||||||
|
float SmoothedVadProbability(float p_old, float p_new, float attack) {
|
||||||
|
RTC_DCHECK_GT(attack, 0.f);
|
||||||
|
RTC_DCHECK_LE(attack, 1.f);
|
||||||
|
if (p_new < p_old || attack == 1.f) {
|
||||||
|
// Instant decay (or no smoothing).
|
||||||
|
return p_new;
|
||||||
|
} else {
|
||||||
|
// Attack phase.
|
||||||
|
return attack * p_new + (1.f - attack) * p_old;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
VadLevelAnalyzer::VadLevelAnalyzer() : vad_(std::make_unique<Vad>()) {}
|
VadLevelAnalyzer::VadLevelAnalyzer()
|
||||||
|
: VadLevelAnalyzer(kDefaultSmoothedVadProbabilityAttack,
|
||||||
|
std::make_unique<Vad>()) {}
|
||||||
|
|
||||||
VadLevelAnalyzer::VadLevelAnalyzer(std::unique_ptr<VoiceActivityDetector> vad)
|
VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack)
|
||||||
: vad_(std::move(vad)) {
|
: VadLevelAnalyzer(vad_probability_attack, std::make_unique<Vad>()) {}
|
||||||
|
|
||||||
|
VadLevelAnalyzer::VadLevelAnalyzer(float vad_probability_attack,
|
||||||
|
std::unique_ptr<VoiceActivityDetector> vad)
|
||||||
|
: vad_(std::move(vad)), vad_probability_attack_(vad_probability_attack) {
|
||||||
RTC_DCHECK(vad_);
|
RTC_DCHECK(vad_);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,13 +95,18 @@ VadLevelAnalyzer::~VadLevelAnalyzer() = default;
|
|||||||
|
|
||||||
VadLevelAnalyzer::Result VadLevelAnalyzer::AnalyzeFrame(
|
VadLevelAnalyzer::Result VadLevelAnalyzer::AnalyzeFrame(
|
||||||
AudioFrameView<const float> frame) {
|
AudioFrameView<const float> frame) {
|
||||||
|
// Compute levels.
|
||||||
float peak = 0.f;
|
float peak = 0.f;
|
||||||
float rms = 0.f;
|
float rms = 0.f;
|
||||||
for (const auto& x : frame.channel(0)) {
|
for (const auto& x : frame.channel(0)) {
|
||||||
peak = std::max(std::fabs(x), peak);
|
peak = std::max(std::fabs(x), peak);
|
||||||
rms += x * x;
|
rms += x * x;
|
||||||
}
|
}
|
||||||
return {vad_->ComputeProbability(frame),
|
// Compute smoothed speech probability.
|
||||||
|
vad_probability_ = SmoothedVadProbability(
|
||||||
|
/*p_old=*/vad_probability_, /*p_new=*/vad_->ComputeProbability(frame),
|
||||||
|
vad_probability_attack_);
|
||||||
|
return {vad_probability_,
|
||||||
FloatS16ToDbfs(std::sqrt(rms / frame.samples_per_channel())),
|
FloatS16ToDbfs(std::sqrt(rms / frame.samples_per_channel())),
|
||||||
FloatS16ToDbfs(peak)};
|
FloatS16ToDbfs(peak)};
|
||||||
}
|
}
|
||||||
|
@ -36,8 +36,10 @@ class VadLevelAnalyzer {
|
|||||||
|
|
||||||
// Ctor. Uses the default VAD.
|
// Ctor. Uses the default VAD.
|
||||||
VadLevelAnalyzer();
|
VadLevelAnalyzer();
|
||||||
|
explicit VadLevelAnalyzer(float vad_probability_attack);
|
||||||
// Ctor. Uses a custom `vad`.
|
// Ctor. Uses a custom `vad`.
|
||||||
explicit VadLevelAnalyzer(std::unique_ptr<VoiceActivityDetector> vad);
|
VadLevelAnalyzer(float vad_probability_attack,
|
||||||
|
std::unique_ptr<VoiceActivityDetector> vad);
|
||||||
VadLevelAnalyzer(const VadLevelAnalyzer&) = delete;
|
VadLevelAnalyzer(const VadLevelAnalyzer&) = delete;
|
||||||
VadLevelAnalyzer& operator=(const VadLevelAnalyzer&) = delete;
|
VadLevelAnalyzer& operator=(const VadLevelAnalyzer&) = delete;
|
||||||
~VadLevelAnalyzer();
|
~VadLevelAnalyzer();
|
||||||
@ -47,6 +49,8 @@ class VadLevelAnalyzer {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<VoiceActivityDetector> vad_;
|
std::unique_ptr<VoiceActivityDetector> vad_;
|
||||||
|
const float vad_probability_attack_;
|
||||||
|
float vad_probability_ = 0.f;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace webrtc
|
} // namespace webrtc
|
||||||
|
@ -10,30 +10,121 @@
|
|||||||
|
|
||||||
#include "modules/audio_processing/agc2/vad_with_level.h"
|
#include "modules/audio_processing/agc2/vad_with_level.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||||
|
#include "modules/audio_processing/include/audio_frame_view.h"
|
||||||
#include "rtc_base/gunit.h"
|
#include "rtc_base/gunit.h"
|
||||||
|
#include "rtc_base/numerics/safe_compare.h"
|
||||||
|
#include "test/gmock.h"
|
||||||
|
|
||||||
namespace webrtc {
|
namespace webrtc {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
TEST(AutomaticGainController2VadWithLevelEstimator,
|
using ::testing::AnyNumber;
|
||||||
PeakLevelGreaterThanRmsLevel) {
|
using ::testing::ReturnRoundRobin;
|
||||||
constexpr size_t kSampleRateHz = 8000;
|
|
||||||
|
|
||||||
// 10 ms input frame, constant except for one peak value.
|
constexpr float kInstantAttack = 1.f;
|
||||||
// Handcrafted so that the average is lower than the peak value.
|
constexpr float kSlowAttack = 0.1f;
|
||||||
std::array<float, kSampleRateHz / 100> frame;
|
|
||||||
frame.fill(1000.f);
|
constexpr int kSampleRateHz = 8000;
|
||||||
frame[10] = 2000.f;
|
|
||||||
float* const channel0 = frame.data();
|
class MockVad : public VadLevelAnalyzer::VoiceActivityDetector {
|
||||||
AudioFrameView<float> frame_view(&channel0, 1, frame.size());
|
public:
|
||||||
|
MOCK_METHOD(float,
|
||||||
|
ComputeProbability,
|
||||||
|
(AudioFrameView<const float> frame),
|
||||||
|
(override));
|
||||||
|
};
|
||||||
|
|
||||||
|
// Creates a `VadLevelAnalyzer` injecting a mock VAD which repeatedly returns
|
||||||
|
// the next value from `speech_probabilities` until it reaches the end and will
|
||||||
|
// restart from the beginning.
|
||||||
|
std::unique_ptr<VadLevelAnalyzer> CreateVadLevelAnalyzerWithMockVad(
|
||||||
|
float vad_probability_attack,
|
||||||
|
const std::vector<float>& speech_probabilities) {
|
||||||
|
auto vad = std::make_unique<MockVad>();
|
||||||
|
EXPECT_CALL(*vad, ComputeProbability)
|
||||||
|
.Times(AnyNumber())
|
||||||
|
.WillRepeatedly(ReturnRoundRobin(speech_probabilities));
|
||||||
|
return std::make_unique<VadLevelAnalyzer>(vad_probability_attack,
|
||||||
|
std::move(vad));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 10 ms mono frame.
|
||||||
|
struct FrameWithView {
|
||||||
|
// Ctor. Initializes the frame samples with `value`.
|
||||||
|
FrameWithView(float value = 0.f)
|
||||||
|
: channel0(samples.data()),
|
||||||
|
view(&channel0, /*num_channels=*/1, samples.size()) {
|
||||||
|
samples.fill(value);
|
||||||
|
}
|
||||||
|
std::array<float, kSampleRateHz / 100> samples;
|
||||||
|
const float* const channel0;
|
||||||
|
const AudioFrameView<const float> view;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(AutomaticGainController2VadLevelAnalyzer, PeakLevelGreaterThanRmsLevel) {
|
||||||
|
// Handcrafted frame so that the average is lower than the peak value.
|
||||||
|
FrameWithView frame(1000.f); // Constant frame.
|
||||||
|
frame.samples[10] = 2000.f; // Except for one peak value.
|
||||||
|
|
||||||
// Compute audio frame levels (the VAD result is ignored).
|
// Compute audio frame levels (the VAD result is ignored).
|
||||||
VadLevelAnalyzer analyzer;
|
VadLevelAnalyzer analyzer;
|
||||||
auto levels_and_vad_prob = analyzer.AnalyzeFrame(frame_view);
|
auto levels_and_vad_prob = analyzer.AnalyzeFrame(frame.view);
|
||||||
|
|
||||||
// Compare peak and RMS levels.
|
// Compare peak and RMS levels.
|
||||||
EXPECT_LT(levels_and_vad_prob.rms_dbfs, levels_and_vad_prob.peak_dbfs);
|
EXPECT_LT(levels_and_vad_prob.rms_dbfs, levels_and_vad_prob.peak_dbfs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks that the unprocessed and the smoothed speech probabilities match when
|
||||||
|
// instant attack is used.
|
||||||
|
TEST(AutomaticGainController2VadLevelAnalyzer, NoSpeechProbabilitySmoothing) {
|
||||||
|
const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f,
|
||||||
|
0.44f, 0.525f, 0.858f, 0.314f,
|
||||||
|
0.653f, 0.965f, 0.413f, 0.f};
|
||||||
|
auto analyzer =
|
||||||
|
CreateVadLevelAnalyzerWithMockVad(kInstantAttack, speech_probabilities);
|
||||||
|
FrameWithView frame;
|
||||||
|
for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
|
||||||
|
SCOPED_TRACE(i);
|
||||||
|
EXPECT_EQ(speech_probabilities[i],
|
||||||
|
analyzer->AnalyzeFrame(frame.view).speech_probability);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks that the smoothed speech probability does not instantly converge to
|
||||||
|
// the unprocessed one when slow attack is used.
|
||||||
|
TEST(AutomaticGainController2VadLevelAnalyzer,
|
||||||
|
SlowAttackSpeechProbabilitySmoothing) {
|
||||||
|
const std::vector<float> speech_probabilities{0.f, 0.f, 1.f, 1.f, 1.f, 1.f};
|
||||||
|
auto analyzer =
|
||||||
|
CreateVadLevelAnalyzerWithMockVad(kSlowAttack, speech_probabilities);
|
||||||
|
FrameWithView frame;
|
||||||
|
float prev_probability = 0.f;
|
||||||
|
for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
|
||||||
|
SCOPED_TRACE(i);
|
||||||
|
const float smoothed_probability =
|
||||||
|
analyzer->AnalyzeFrame(frame.view).speech_probability;
|
||||||
|
EXPECT_LT(smoothed_probability, 1.f); // Not enough time to reach 1.
|
||||||
|
EXPECT_LE(prev_probability, smoothed_probability); // Converge towards 1.
|
||||||
|
prev_probability = smoothed_probability;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks that the smoothed speech probability instantly decays to the
|
||||||
|
// unprocessed one when slow attack is used.
|
||||||
|
TEST(AutomaticGainController2VadLevelAnalyzer, SpeechProbabilityInstantDecay) {
|
||||||
|
const std::vector<float> speech_probabilities{1.f, 1.f, 1.f, 1.f, 1.f, 0.f};
|
||||||
|
auto analyzer =
|
||||||
|
CreateVadLevelAnalyzerWithMockVad(kSlowAttack, speech_probabilities);
|
||||||
|
FrameWithView frame;
|
||||||
|
for (int i = 0; rtc::SafeLt(i, speech_probabilities.size() - 1); ++i) {
|
||||||
|
analyzer->AnalyzeFrame(frame.view);
|
||||||
|
}
|
||||||
|
EXPECT_EQ(0.f, analyzer->AnalyzeFrame(frame.view).speech_probability);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace webrtc
|
} // namespace webrtc
|
||||||
|
Reference in New Issue
Block a user