AGC2: VAD moved into GainController2

Bit exactness verified with audioproc_f on a collection of AEC dumps
and Wav files (42 recordings in total).

Bug: webrtc:7494
Change-Id: Id9849c4463791f5a203afe31efc163efb4d4458e
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/234583
Reviewed-by: Hanna Silen <silen@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35248}
This commit is contained in:
Alessio Bazzica
2021-10-15 13:57:56 +02:00
committed by WebRTC LUCI CQ
parent 5544364161
commit b4d4ae2c23
9 changed files with 94 additions and 69 deletions

View File

@ -132,8 +132,10 @@ rtc_library("gain_controller2") {
"../../rtc_base:stringutils", "../../rtc_base:stringutils",
"../../system_wrappers:field_trial", "../../system_wrappers:field_trial",
"agc2:adaptive_digital", "agc2:adaptive_digital",
"agc2:cpu_features",
"agc2:fixed_digital", "agc2:fixed_digital",
"agc2:gain_applier", "agc2:gain_applier",
"agc2:vad_wrapper",
] ]
} }

View File

@ -52,7 +52,6 @@ rtc_library("adaptive_digital") {
"../../../rtc_base:rtc_base_approved", "../../../rtc_base:rtc_base_approved",
"../../../rtc_base:safe_compare", "../../../rtc_base:safe_compare",
"../../../rtc_base:safe_minmax", "../../../rtc_base:safe_minmax",
"../../../system_wrappers:field_trial",
"../../../system_wrappers:metrics", "../../../system_wrappers:metrics",
] ]
@ -150,7 +149,11 @@ rtc_library("vad_wrapper") {
"vad_wrapper.cc", "vad_wrapper.cc",
"vad_wrapper.h", "vad_wrapper.h",
] ]
visibility = [ "./*" ]
visibility = [
"..:gain_controller2",
"./*",
]
defines = [] defines = []
if (rtc_build_with_neon && current_cpu != "arm64") { if (rtc_build_with_neon && current_cpu != "arm64") {

View File

@ -11,31 +11,14 @@
#include "modules/audio_processing/agc2/adaptive_agc.h" #include "modules/audio_processing/agc2/adaptive_agc.h"
#include "common_audio/include/audio_util.h" #include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/vad_wrapper.h" #include "modules/audio_processing/agc2/vad_wrapper.h"
#include "modules/audio_processing/logging/apm_data_dumper.h" #include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h" #include "rtc_base/checks.h"
#include "rtc_base/logging.h" #include "rtc_base/logging.h"
#include "system_wrappers/include/field_trial.h"
namespace webrtc { namespace webrtc {
namespace { namespace {
// Detects the available CPU features and applies any kill-switches.
AvailableCpuFeatures GetAllowedCpuFeatures() {
AvailableCpuFeatures features = GetAvailableCpuFeatures();
if (field_trial::IsEnabled("WebRTC-Agc2SimdSse2KillSwitch")) {
features.sse2 = false;
}
if (field_trial::IsEnabled("WebRTC-Agc2SimdAvx2KillSwitch")) {
features.avx2 = false;
}
if (field_trial::IsEnabled("WebRTC-Agc2SimdNeonKillSwitch")) {
features.neon = false;
}
return features;
}
// Peak and RMS audio levels in dBFS. // Peak and RMS audio levels in dBFS.
struct AudioLevels { struct AudioLevels {
float peak_dbfs; float peak_dbfs;
@ -60,7 +43,6 @@ AdaptiveAgc::AdaptiveAgc(
ApmDataDumper* apm_data_dumper, ApmDataDumper* apm_data_dumper,
const AudioProcessing::Config::GainController2::AdaptiveDigital& config) const AudioProcessing::Config::GainController2::AdaptiveDigital& config)
: speech_level_estimator_(apm_data_dumper, config), : speech_level_estimator_(apm_data_dumper, config),
vad_(config.vad_reset_period_ms, GetAllowedCpuFeatures()),
gain_controller_(apm_data_dumper, config), gain_controller_(apm_data_dumper, config),
apm_data_dumper_(apm_data_dumper), apm_data_dumper_(apm_data_dumper),
noise_level_estimator_(CreateNoiseFloorEstimator(apm_data_dumper)), noise_level_estimator_(CreateNoiseFloorEstimator(apm_data_dumper)),
@ -77,18 +59,18 @@ AdaptiveAgc::~AdaptiveAgc() = default;
void AdaptiveAgc::Initialize(int sample_rate_hz, int num_channels) { void AdaptiveAgc::Initialize(int sample_rate_hz, int num_channels) {
gain_controller_.Initialize(sample_rate_hz, num_channels); gain_controller_.Initialize(sample_rate_hz, num_channels);
vad_.Initialize(sample_rate_hz);
} }
void AdaptiveAgc::Process(AudioFrameView<float> frame, float limiter_envelope) { void AdaptiveAgc::Process(AudioFrameView<float> frame,
float speech_probability,
float limiter_envelope) {
AudioLevels levels = ComputeAudioLevels(frame); AudioLevels levels = ComputeAudioLevels(frame);
apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", levels.rms_dbfs);
apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", levels.peak_dbfs);
AdaptiveDigitalGainApplier::FrameInfo info; AdaptiveDigitalGainApplier::FrameInfo info;
info.speech_probability = vad_.Analyze(frame); info.speech_probability = speech_probability;
apm_data_dumper_->DumpRaw("agc2_speech_probability", info.speech_probability);
apm_data_dumper_->DumpRaw("agc2_input_rms_dbfs", levels.rms_dbfs);
apm_data_dumper_->DumpRaw("agc2_input_peak_dbfs", levels.peak_dbfs);
speech_level_estimator_.Update(levels.rms_dbfs, levels.peak_dbfs, speech_level_estimator_.Update(levels.rms_dbfs, levels.peak_dbfs,
info.speech_probability); info.speech_probability);

View File

@ -17,7 +17,6 @@
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h" #include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
#include "modules/audio_processing/agc2/noise_level_estimator.h" #include "modules/audio_processing/agc2/noise_level_estimator.h"
#include "modules/audio_processing/agc2/saturation_protector.h" #include "modules/audio_processing/agc2/saturation_protector.h"
#include "modules/audio_processing/agc2/vad_wrapper.h"
#include "modules/audio_processing/include/audio_frame_view.h" #include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/include/audio_processing.h" #include "modules/audio_processing/include/audio_processing.h"
@ -38,16 +37,17 @@ class AdaptiveAgc {
// TODO(crbug.com/webrtc/7494): Add `SetLimiterEnvelope()`. // TODO(crbug.com/webrtc/7494): Add `SetLimiterEnvelope()`.
// Analyzes `frame` and applies a digital adaptive gain to it. Takes into // Analyzes `frame` and applies a digital adaptive gain to it. Takes into
// account the envelope measured by the limiter. // account the speech probability and the envelope measured by the limiter.
// TODO(crbug.com/webrtc/7494): Remove `limiter_envelope`. // TODO(crbug.com/webrtc/7494): Remove `limiter_envelope`.
void Process(AudioFrameView<float> frame, float limiter_envelope); void Process(AudioFrameView<float> frame,
float speech_probability,
float limiter_envelope);
// Handles a gain change applied to the input signal (e.g., analog gain). // Handles a gain change applied to the input signal (e.g., analog gain).
void HandleInputGainChange(); void HandleInputGainChange();
private: private:
AdaptiveModeLevelEstimator speech_level_estimator_; AdaptiveModeLevelEstimator speech_level_estimator_;
VoiceActivityDetectorWrapper vad_;
AdaptiveDigitalGainApplier gain_controller_; AdaptiveDigitalGainApplier gain_controller_;
ApmDataDumper* const apm_data_dumper_; ApmDataDumper* const apm_data_dumper_;
std::unique_ptr<NoiseLevelEstimator> noise_level_estimator_; std::unique_ptr<NoiseLevelEstimator> noise_level_estimator_;

View File

@ -54,24 +54,25 @@ class MonoVadImpl : public VoiceActivityDetectorWrapper::MonoVad {
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper( VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
int vad_reset_period_ms, int vad_reset_period_ms,
const AvailableCpuFeatures& cpu_features) const AvailableCpuFeatures& cpu_features,
: VoiceActivityDetectorWrapper( int sample_rate_hz)
vad_reset_period_ms, : VoiceActivityDetectorWrapper(vad_reset_period_ms,
std::make_unique<MonoVadImpl>(cpu_features)) {} std::make_unique<MonoVadImpl>(cpu_features),
sample_rate_hz) {}
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper( VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
int vad_reset_period_ms, int vad_reset_period_ms,
std::unique_ptr<MonoVad> vad) std::unique_ptr<MonoVad> vad,
int sample_rate_hz)
: vad_reset_period_frames_( : vad_reset_period_frames_(
rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)), rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)),
initialized_(false),
frame_size_(0),
time_to_vad_reset_(vad_reset_period_frames_), time_to_vad_reset_(vad_reset_period_frames_),
vad_(std::move(vad)) { vad_(std::move(vad)) {
RTC_DCHECK(vad_); RTC_DCHECK(vad_);
RTC_DCHECK_GT(vad_reset_period_frames_, 1); RTC_DCHECK_GT(vad_reset_period_frames_, 1);
resampled_buffer_.resize( resampled_buffer_.resize(
rtc::CheckedDivExact(vad_->SampleRateHz(), kNumFramesPerSecond)); rtc::CheckedDivExact(vad_->SampleRateHz(), kNumFramesPerSecond));
Initialize(sample_rate_hz);
} }
VoiceActivityDetectorWrapper::~VoiceActivityDetectorWrapper() = default; VoiceActivityDetectorWrapper::~VoiceActivityDetectorWrapper() = default;
@ -85,11 +86,9 @@ void VoiceActivityDetectorWrapper::Initialize(int sample_rate_hz) {
constexpr int kStatusOk = 0; constexpr int kStatusOk = 0;
RTC_DCHECK_EQ(status, kStatusOk); RTC_DCHECK_EQ(status, kStatusOk);
vad_->Reset(); vad_->Reset();
initialized_ = true;
} }
float VoiceActivityDetectorWrapper::Analyze(AudioFrameView<const float> frame) { float VoiceActivityDetectorWrapper::Analyze(AudioFrameView<const float> frame) {
RTC_DCHECK(initialized_);
// Periodically reset the VAD. // Periodically reset the VAD.
time_to_vad_reset_--; time_to_vad_reset_--;
if (time_to_vad_reset_ <= 0) { if (time_to_vad_reset_ <= 0) {

View File

@ -43,20 +43,20 @@ class VoiceActivityDetectorWrapper {
// Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call // Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call
// `MonoVad::Reset()`; it must be equal to or greater than the duration of two // `MonoVad::Reset()`; it must be equal to or greater than the duration of two
// frames. Uses `cpu_features` to instantiate the default VAD. // frames. Uses `cpu_features` to instantiate the default VAD.
// TODO(bugs.webrtc.org/7494): Pass sample rate.
VoiceActivityDetectorWrapper(int vad_reset_period_ms, VoiceActivityDetectorWrapper(int vad_reset_period_ms,
const AvailableCpuFeatures& cpu_features); const AvailableCpuFeatures& cpu_features,
int sample_rate_hz);
// Ctor. Uses a custom `vad`. // Ctor. Uses a custom `vad`.
VoiceActivityDetectorWrapper(int vad_reset_period_ms, VoiceActivityDetectorWrapper(int vad_reset_period_ms,
std::unique_ptr<MonoVad> vad); std::unique_ptr<MonoVad> vad,
int sample_rate_hz);
VoiceActivityDetectorWrapper(const VoiceActivityDetectorWrapper&) = delete; VoiceActivityDetectorWrapper(const VoiceActivityDetectorWrapper&) = delete;
VoiceActivityDetectorWrapper& operator=(const VoiceActivityDetectorWrapper&) = VoiceActivityDetectorWrapper& operator=(const VoiceActivityDetectorWrapper&) =
delete; delete;
~VoiceActivityDetectorWrapper(); ~VoiceActivityDetectorWrapper();
// TODO(bugs.webrtc.org/7494): Call initialize in the ctor. // Initializes the VAD wrapper.
// Initializes the VAD wrapper. Must be called before `Analyze()`.
void Initialize(int sample_rate_hz); void Initialize(int sample_rate_hz);
// Analyzes the first channel of `frame` and returns the speech probability. // Analyzes the first channel of `frame` and returns the speech probability.
@ -66,8 +66,6 @@ class VoiceActivityDetectorWrapper {
private: private:
const int vad_reset_period_frames_; const int vad_reset_period_frames_;
// TODO(bugs.webrtc.org/7494): Remove `initialized_`.
bool initialized_;
int frame_size_; int frame_size_;
int time_to_vad_reset_; int time_to_vad_reset_;
PushResampler<float> resampler_; PushResampler<float> resampler_;

View File

@ -31,6 +31,8 @@ using ::testing::Return;
using ::testing::ReturnRoundRobin; using ::testing::ReturnRoundRobin;
using ::testing::Truly; using ::testing::Truly;
constexpr int kNumFramesPerSecond = 100;
constexpr int kNoVadPeriodicReset = constexpr int kNoVadPeriodicReset =
kFrameDurationMs * (std::numeric_limits<int>::max() / kFrameDurationMs); kFrameDurationMs * (std::numeric_limits<int>::max() / kFrameDurationMs);
@ -52,8 +54,7 @@ TEST(GainController2VoiceActivityDetectorWrapper, CtorAndInitReadSampleRate) {
.WillRepeatedly(Return(kSampleRate8kHz)); .WillRepeatedly(Return(kSampleRate8kHz));
EXPECT_CALL(*vad, Reset).Times(AnyNumber()); EXPECT_CALL(*vad, Reset).Times(AnyNumber());
auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>( auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
kNoVadPeriodicReset, std::move(vad)); kNoVadPeriodicReset, std::move(vad), kSampleRate8kHz);
vad_wrapper->Initialize(kSampleRate8kHz);
} }
// Creates a `VoiceActivityDetectorWrapper` injecting a mock VAD that // Creates a `VoiceActivityDetectorWrapper` injecting a mock VAD that
@ -61,27 +62,29 @@ TEST(GainController2VoiceActivityDetectorWrapper, CtorAndInitReadSampleRate) {
// restarts from the beginning when after the last element is returned. // restarts from the beginning when after the last element is returned.
std::unique_ptr<VoiceActivityDetectorWrapper> CreateMockVadWrapper( std::unique_ptr<VoiceActivityDetectorWrapper> CreateMockVadWrapper(
int vad_reset_period_ms, int vad_reset_period_ms,
int sample_rate_hz,
const std::vector<float>& speech_probabilities, const std::vector<float>& speech_probabilities,
int expected_vad_reset_calls) { int expected_vad_reset_calls) {
auto vad = std::make_unique<MockVad>(); auto vad = std::make_unique<MockVad>();
EXPECT_CALL(*vad, SampleRateHz) EXPECT_CALL(*vad, SampleRateHz)
.Times(AnyNumber()) .Times(AnyNumber())
.WillRepeatedly(Return(kSampleRate8kHz)); .WillRepeatedly(Return(sample_rate_hz));
if (expected_vad_reset_calls >= 0) { if (expected_vad_reset_calls >= 0) {
EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls); EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls);
} }
EXPECT_CALL(*vad, Analyze) EXPECT_CALL(*vad, Analyze)
.Times(AnyNumber()) .Times(AnyNumber())
.WillRepeatedly(ReturnRoundRobin(speech_probabilities)); .WillRepeatedly(ReturnRoundRobin(speech_probabilities));
return std::make_unique<VoiceActivityDetectorWrapper>(vad_reset_period_ms, return std::make_unique<VoiceActivityDetectorWrapper>(
std::move(vad)); vad_reset_period_ms, std::move(vad), kSampleRate8kHz);
} }
// 10 ms mono frame. // 10 ms mono frame.
struct FrameWithView { struct FrameWithView {
// Ctor. Initializes the frame samples with `value`. // Ctor. Initializes the frame samples with `value`.
explicit FrameWithView(int sample_rate_hz) explicit FrameWithView(int sample_rate_hz)
: samples(rtc::CheckedDivExact(sample_rate_hz, 100), 0.0f), : samples(rtc::CheckedDivExact(sample_rate_hz, kNumFramesPerSecond),
0.0f),
channel0(samples.data()), channel0(samples.data()),
view(&channel0, /*num_channels=*/1, samples.size()) {} view(&channel0, /*num_channels=*/1, samples.size()) {}
std::vector<float> samples; std::vector<float> samples;
@ -94,10 +97,9 @@ TEST(GainController2VoiceActivityDetectorWrapper, CheckSpeechProbabilities) {
const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f, const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f,
0.44f, 0.525f, 0.858f, 0.314f, 0.44f, 0.525f, 0.858f, 0.314f,
0.653f, 0.965f, 0.413f, 0.0f}; 0.653f, 0.965f, 0.413f, 0.0f};
auto vad_wrapper = auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, kSampleRate8kHz,
CreateMockVadWrapper(kNoVadPeriodicReset, speech_probabilities, speech_probabilities,
/*expected_vad_reset_calls=*/1); /*expected_vad_reset_calls=*/1);
vad_wrapper->Initialize(kSampleRate8kHz);
FrameWithView frame(kSampleRate8kHz); FrameWithView frame(kSampleRate8kHz);
for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) { for (int i = 0; rtc::SafeLt(i, speech_probabilities.size()); ++i) {
SCOPED_TRACE(i); SCOPED_TRACE(i);
@ -108,10 +110,9 @@ TEST(GainController2VoiceActivityDetectorWrapper, CheckSpeechProbabilities) {
// Checks that the VAD is not periodically reset. // Checks that the VAD is not periodically reset.
TEST(GainController2VoiceActivityDetectorWrapper, VadNoPeriodicReset) { TEST(GainController2VoiceActivityDetectorWrapper, VadNoPeriodicReset) {
constexpr int kNumFrames = 19; constexpr int kNumFrames = 19;
auto vad_wrapper = auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, kSampleRate8kHz,
CreateMockVadWrapper(kNoVadPeriodicReset, /*speech_probabilities=*/{1.0f}, /*speech_probabilities=*/{1.0f},
/*expected_vad_reset_calls=*/1); /*expected_vad_reset_calls=*/1);
vad_wrapper->Initialize(kSampleRate8kHz);
FrameWithView frame(kSampleRate8kHz); FrameWithView frame(kSampleRate8kHz);
for (int i = 0; i < kNumFrames; ++i) { for (int i = 0; i < kNumFrames; ++i) {
vad_wrapper->Analyze(frame.view); vad_wrapper->Analyze(frame.view);
@ -129,10 +130,10 @@ class VadPeriodResetParametrization
TEST_P(VadPeriodResetParametrization, VadPeriodicReset) { TEST_P(VadPeriodResetParametrization, VadPeriodicReset) {
auto vad_wrapper = CreateMockVadWrapper( auto vad_wrapper = CreateMockVadWrapper(
/*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs, /*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs,
kSampleRate8kHz,
/*speech_probabilities=*/{1.0f}, /*speech_probabilities=*/{1.0f},
/*expected_vad_reset_calls=*/1 + /*expected_vad_reset_calls=*/1 +
num_frames() / vad_reset_period_frames()); num_frames() / vad_reset_period_frames());
vad_wrapper->Initialize(kSampleRate8kHz);
FrameWithView frame(kSampleRate8kHz); FrameWithView frame(kSampleRate8kHz);
for (int i = 0; i < num_frames(); ++i) { for (int i = 0; i < num_frames(); ++i) {
vad_wrapper->Analyze(frame.view); vad_wrapper->Analyze(frame.view);
@ -161,13 +162,12 @@ TEST_P(VadResamplingParametrization, CheckResampledFrameSize) {
.WillRepeatedly(Return(vad_sample_rate_hz())); .WillRepeatedly(Return(vad_sample_rate_hz()));
EXPECT_CALL(*vad, Reset).Times(1); EXPECT_CALL(*vad, Reset).Times(1);
EXPECT_CALL(*vad, Analyze(Truly([this](rtc::ArrayView<const float> frame) { EXPECT_CALL(*vad, Analyze(Truly([this](rtc::ArrayView<const float> frame) {
return rtc::SafeEq(frame.size(), return rtc::SafeEq(frame.size(), rtc::CheckedDivExact(vad_sample_rate_hz(),
rtc::CheckedDivExact(vad_sample_rate_hz(), 100)); kNumFramesPerSecond));
}))).Times(1); }))).Times(1);
auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>( auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
kNoVadPeriodicReset, std::move(vad)); kNoVadPeriodicReset, std::move(vad), input_sample_rate_hz());
FrameWithView frame(input_sample_rate_hz()); FrameWithView frame(input_sample_rate_hz());
vad_wrapper->Initialize(input_sample_rate_hz());
vad_wrapper->Analyze(frame.view); vad_wrapper->Analyze(frame.view);
} }

View File

@ -14,6 +14,7 @@
#include <utility> #include <utility>
#include "common_audio/include/audio_util.h" #include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/audio_buffer.h" #include "modules/audio_processing/audio_buffer.h"
#include "modules/audio_processing/include/audio_frame_view.h" #include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/logging/apm_data_dumper.h" #include "modules/audio_processing/logging/apm_data_dumper.h"
@ -21,6 +22,7 @@
#include "rtc_base/checks.h" #include "rtc_base/checks.h"
#include "rtc_base/logging.h" #include "rtc_base/logging.h"
#include "rtc_base/strings/string_builder.h" #include "rtc_base/strings/string_builder.h"
#include "system_wrappers/include/field_trial.h"
namespace webrtc { namespace webrtc {
namespace { namespace {
@ -33,6 +35,21 @@ constexpr int kFrameLengthMs = 10;
constexpr int kLogLimiterStatsPeriodNumFrames = constexpr int kLogLimiterStatsPeriodNumFrames =
kLogLimiterStatsPeriodMs / kFrameLengthMs; kLogLimiterStatsPeriodMs / kFrameLengthMs;
// Detects the available CPU features and applies any kill-switches.
AvailableCpuFeatures GetAllowedCpuFeatures() {
AvailableCpuFeatures features = GetAvailableCpuFeatures();
if (field_trial::IsEnabled("WebRTC-Agc2SimdSse2KillSwitch")) {
features.sse2 = false;
}
if (field_trial::IsEnabled("WebRTC-Agc2SimdAvx2KillSwitch")) {
features.avx2 = false;
}
if (field_trial::IsEnabled("WebRTC-Agc2SimdNeonKillSwitch")) {
features.neon = false;
}
return features;
}
// Creates an adaptive digital gain controller if enabled. // Creates an adaptive digital gain controller if enabled.
std::unique_ptr<AdaptiveAgc> CreateAdaptiveDigitalController( std::unique_ptr<AdaptiveAgc> CreateAdaptiveDigitalController(
const Agc2Config::AdaptiveDigital& config, const Agc2Config::AdaptiveDigital& config,
@ -40,7 +57,8 @@ std::unique_ptr<AdaptiveAgc> CreateAdaptiveDigitalController(
int num_channels, int num_channels,
ApmDataDumper* data_dumper) { ApmDataDumper* data_dumper) {
if (config.enabled) { if (config.enabled) {
// TODO(bugs.webrtc.org/7494): Also init with sample rate and num channels. // TODO(bugs.webrtc.org/7494): Also init with sample rate and num
// channels.
auto controller = std::make_unique<AdaptiveAgc>(data_dumper, config); auto controller = std::make_unique<AdaptiveAgc>(data_dumper, config);
// TODO(bugs.webrtc.org/7494): Remove once passed to the ctor. // TODO(bugs.webrtc.org/7494): Remove once passed to the ctor.
controller->Initialize(sample_rate_hz, num_channels); controller->Initialize(sample_rate_hz, num_channels);
@ -56,7 +74,8 @@ int GainController2::instance_count_ = 0;
GainController2::GainController2(const Agc2Config& config, GainController2::GainController2(const Agc2Config& config,
int sample_rate_hz, int sample_rate_hz,
int num_channels) int num_channels)
: data_dumper_(rtc::AtomicOps::Increment(&instance_count_)), : cpu_features_(GetAllowedCpuFeatures()),
data_dumper_(rtc::AtomicOps::Increment(&instance_count_)),
fixed_gain_applier_(/*hard_clip_samples=*/false, fixed_gain_applier_(/*hard_clip_samples=*/false,
/*initial_gain_factor=*/0.0f), /*initial_gain_factor=*/0.0f),
adaptive_digital_controller_( adaptive_digital_controller_(
@ -71,6 +90,14 @@ GainController2::GainController2(const Agc2Config& config,
data_dumper_.InitiateNewSetOfRecordings(); data_dumper_.InitiateNewSetOfRecordings();
// TODO(bugs.webrtc.org/7494): Set gain when `fixed_gain_applier_` is init'd. // TODO(bugs.webrtc.org/7494): Set gain when `fixed_gain_applier_` is init'd.
fixed_gain_applier_.SetGainFactor(DbToRatio(config.fixed_digital.gain_db)); fixed_gain_applier_.SetGainFactor(DbToRatio(config.fixed_digital.gain_db));
const bool use_vad = config.adaptive_digital.enabled;
if (use_vad) {
// TODO(bugs.webrtc.org/7494): Move `vad_reset_period_ms` from adaptive
// digital to gain controller 2 config.
vad_ = std::make_unique<VoiceActivityDetectorWrapper>(
config.adaptive_digital.vad_reset_period_ms, cpu_features_,
sample_rate_hz);
}
} }
GainController2::~GainController2() = default; GainController2::~GainController2() = default;
@ -82,6 +109,9 @@ void GainController2::Initialize(int sample_rate_hz, int num_channels) {
sample_rate_hz == AudioProcessing::kSampleRate48kHz); sample_rate_hz == AudioProcessing::kSampleRate48kHz);
// TODO(bugs.webrtc.org/7494): Initialize `fixed_gain_applier_`. // TODO(bugs.webrtc.org/7494): Initialize `fixed_gain_applier_`.
limiter_.SetSampleRate(sample_rate_hz); limiter_.SetSampleRate(sample_rate_hz);
if (vad_) {
vad_->Initialize(sample_rate_hz);
}
if (adaptive_digital_controller_) { if (adaptive_digital_controller_) {
adaptive_digital_controller_->Initialize(sample_rate_hz, num_channels); adaptive_digital_controller_->Initialize(sample_rate_hz, num_channels);
} }
@ -104,10 +134,17 @@ void GainController2::Process(AudioBuffer* audio) {
data_dumper_.DumpRaw("agc2_notified_analog_level", analog_level_); data_dumper_.DumpRaw("agc2_notified_analog_level", analog_level_);
AudioFrameView<float> float_frame(audio->channels(), audio->num_channels(), AudioFrameView<float> float_frame(audio->channels(), audio->num_channels(),
audio->num_frames()); audio->num_frames());
absl::optional<float> speech_probability;
// TODO(bugs.webrtc.org/7494): Apply fixed digital gain after VAD.
fixed_gain_applier_.ApplyGain(float_frame); fixed_gain_applier_.ApplyGain(float_frame);
if (vad_) {
speech_probability = vad_->Analyze(float_frame);
data_dumper_.DumpRaw("agc2_speech_probability", speech_probability.value());
}
if (adaptive_digital_controller_) { if (adaptive_digital_controller_) {
adaptive_digital_controller_->Process(float_frame, RTC_DCHECK(speech_probability.has_value());
limiter_.LastAudioLevel()); adaptive_digital_controller_->Process(
float_frame, speech_probability.value(), limiter_.LastAudioLevel());
} }
limiter_.Process(float_frame); limiter_.Process(float_frame);

View File

@ -15,8 +15,10 @@
#include <string> #include <string>
#include "modules/audio_processing/agc2/adaptive_agc.h" #include "modules/audio_processing/agc2/adaptive_agc.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/gain_applier.h" #include "modules/audio_processing/agc2/gain_applier.h"
#include "modules/audio_processing/agc2/limiter.h" #include "modules/audio_processing/agc2/limiter.h"
#include "modules/audio_processing/agc2/vad_wrapper.h"
#include "modules/audio_processing/include/audio_processing.h" #include "modules/audio_processing/include/audio_processing.h"
#include "modules/audio_processing/logging/apm_data_dumper.h" #include "modules/audio_processing/logging/apm_data_dumper.h"
@ -51,8 +53,10 @@ class GainController2 {
private: private:
static int instance_count_; static int instance_count_;
const AvailableCpuFeatures cpu_features_;
ApmDataDumper data_dumper_; ApmDataDumper data_dumper_;
GainApplier fixed_gain_applier_; GainApplier fixed_gain_applier_;
std::unique_ptr<VoiceActivityDetectorWrapper> vad_;
std::unique_ptr<AdaptiveAgc> adaptive_digital_controller_; std::unique_ptr<AdaptiveAgc> adaptive_digital_controller_;
Limiter limiter_; Limiter limiter_;
int calls_since_last_limiter_log_; int calls_since_last_limiter_log_;