Add RNN-VAD to AGC2.

* Move 'VadWithLevel' to AGC2 where it belongs.
* Remove the vectors from VadWithLevel. They were there to make it work
  with modules/audio_processing/vad, which we don't need any longer.
* Remove the vector handling from AGC2. It was spread out across
  AdaptiveDigitalGainApplier, AdaptiveAGC and their unit tests.
* Hack the RNN VAD into VadWithLevel. The main issue is the resampling.


Bug: webrtc:9076
Change-Id: I13056c985d0ec41269735150caf4aaeb6ff9281e
Reviewed-on: https://webrtc-review.googlesource.com/77364
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Commit-Queue: Alex Loiko <aleloi@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#23688}
This commit is contained in:
Alex Loiko
2018-06-20 14:14:18 +02:00
committed by Commit Bot
parent 87a9353cc9
commit db6af36979
13 changed files with 145 additions and 85 deletions

View File

@ -33,6 +33,7 @@ rtc_source_set("adaptive_digital") {
":common", ":common",
":gain_applier", ":gain_applier",
":noise_level_estimator", ":noise_level_estimator",
":rnn_vad_with_level",
"..:aec_core", "..:aec_core",
"..:apm_logging", "..:apm_logging",
"..:audio_frame_view", "..:audio_frame_view",
@ -41,9 +42,6 @@ rtc_source_set("adaptive_digital") {
"../../../rtc_base:checks", "../../../rtc_base:checks",
"../../../rtc_base:rtc_base_approved", "../../../rtc_base:rtc_base_approved",
"../../../rtc_base:safe_minmax", "../../../rtc_base:safe_minmax",
"../vad",
"../vad:vad_with_level",
"rnn_vad",
] ]
} }
@ -133,6 +131,20 @@ rtc_source_set("noise_level_estimator") {
configs += [ "..:apm_debug_dump" ] configs += [ "..:apm_debug_dump" ]
} }
# Build target for VadWithLevel (vad_with_level.cc/.h), the AGC2 voice
# activity detector that also reports signal levels. It depends on the RNN
# VAD library ("rnn_vad:lib") and on common_audio (presumably for the push
# resampler and audio utility headers that vad_with_level includes).
rtc_source_set("rnn_vad_with_level") {
sources = [
"vad_with_level.cc",
"vad_with_level.h",
]
deps = [
"..:audio_frame_view",
"../../../api:array_view",
"../../../common_audio",
"../../../rtc_base:checks",
"rnn_vad:lib",
]
}
rtc_source_set("adaptive_digital_unittests") { rtc_source_set("adaptive_digital_unittests") {
testonly = true testonly = true
configs += [ "..:apm_debug_dump" ] configs += [ "..:apm_debug_dump" ]
@ -155,7 +167,6 @@ rtc_source_set("adaptive_digital_unittests") {
"../../../rtc_base:checks", "../../../rtc_base:checks",
"../../../rtc_base:rtc_base_approved", "../../../rtc_base:rtc_base_approved",
"../../../rtc_base:rtc_base_tests_utils", "../../../rtc_base:rtc_base_tests_utils",
"../vad:vad_with_level",
] ]
} }

View File

@ -14,8 +14,8 @@
#include <numeric> #include <numeric>
#include "common_audio/include/audio_util.h" #include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/logging/apm_data_dumper.h" #include "modules/audio_processing/logging/apm_data_dumper.h"
#include "modules/audio_processing/vad/voice_activity_detector.h"
namespace webrtc { namespace webrtc {
@ -38,17 +38,14 @@ void AdaptiveAgc::Process(AudioFrameView<float> float_frame) {
// frames, and no estimates for other frames. We want to feed all to // frames, and no estimates for other frames. We want to feed all to
// the level estimator, but only care about the last level it // the level estimator, but only care about the last level it
// produces. // produces.
rtc::ArrayView<const VadWithLevel::LevelAndProbability> vad_results = const VadWithLevel::LevelAndProbability vad_result =
vad_.AnalyzeFrame(float_frame); vad_.AnalyzeFrame(float_frame);
for (const auto& vad_result : vad_results) { apm_data_dumper_->DumpRaw("agc2_vad_probability",
apm_data_dumper_->DumpRaw("agc2_vad_probability", vad_result.speech_probability);
vad_result.speech_probability); apm_data_dumper_->DumpRaw("agc2_vad_rms_dbfs", vad_result.speech_rms_dbfs);
apm_data_dumper_->DumpRaw("agc2_vad_rms_dbfs", vad_result.speech_rms_dbfs);
apm_data_dumper_->DumpRaw("agc2_vad_peak_dbfs", apm_data_dumper_->DumpRaw("agc2_vad_peak_dbfs", vad_result.speech_peak_dbfs);
vad_result.speech_peak_dbfs); speech_level_estimator_.UpdateEstimation(vad_result);
speech_level_estimator_.UpdateEstimation(vad_result);
}
const float speech_level_dbfs = speech_level_estimator_.LatestLevelEstimate(); const float speech_level_dbfs = speech_level_estimator_.LatestLevelEstimate();
@ -57,7 +54,7 @@ void AdaptiveAgc::Process(AudioFrameView<float> float_frame) {
apm_data_dumper_->DumpRaw("agc2_noise_estimate_dbfs", noise_level_dbfs); apm_data_dumper_->DumpRaw("agc2_noise_estimate_dbfs", noise_level_dbfs);
// The gain applier applies the gain. // The gain applier applies the gain.
gain_applier_.Process(speech_level_dbfs, noise_level_dbfs, vad_results, gain_applier_.Process(speech_level_dbfs, noise_level_dbfs, vad_result,
float_frame); float_frame);
} }

View File

@ -16,8 +16,8 @@
#include "modules/audio_processing/agc2/adaptive_digital_gain_applier.h" #include "modules/audio_processing/agc2/adaptive_digital_gain_applier.h"
#include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h" #include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
#include "modules/audio_processing/agc2/noise_level_estimator.h" #include "modules/audio_processing/agc2/noise_level_estimator.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/include/audio_frame_view.h" #include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/vad/vad_with_level.h"
namespace webrtc { namespace webrtc {
class ApmDataDumper; class ApmDataDumper;

View File

@ -74,7 +74,7 @@ AdaptiveDigitalGainApplier::AdaptiveDigitalGainApplier(
void AdaptiveDigitalGainApplier::Process( void AdaptiveDigitalGainApplier::Process(
float input_level_dbfs, float input_level_dbfs,
float input_noise_level_dbfs, float input_noise_level_dbfs,
rtc::ArrayView<const VadWithLevel::LevelAndProbability> vad_results, const VadWithLevel::LevelAndProbability vad_result,
AudioFrameView<float> float_frame) { AudioFrameView<float> float_frame) {
RTC_DCHECK_GE(input_level_dbfs, -150.f); RTC_DCHECK_GE(input_level_dbfs, -150.f);
RTC_DCHECK_LE(input_level_dbfs, 0.f); RTC_DCHECK_LE(input_level_dbfs, 0.f);
@ -85,21 +85,9 @@ void AdaptiveDigitalGainApplier::Process(
LimitGainByNoise(ComputeGainDb(input_level_dbfs), input_noise_level_dbfs, LimitGainByNoise(ComputeGainDb(input_level_dbfs), input_noise_level_dbfs,
apm_data_dumper_); apm_data_dumper_);
// TODO(webrtc:7494): Remove this construct. Remove the vectors from // Forbid increasing the gain when there is no speech.
// VadWithData after we move to a VAD that outputs an estimate every gain_increase_allowed_ =
// kFrameDurationMs ms. vad_result.speech_probability > kVadConfidenceThreshold;
//
// Forbid increasing the gain when there is no speech. For some
// VADs, 'vad_results' has either many or 0 results. If there are 0
// results, keep the old flag. If there are many results, and at
// least one is confident speech, we allow attenuation.
if (!vad_results.empty()) {
gain_increase_allowed_ = std::all_of(
vad_results.begin(), vad_results.end(),
[](const VadWithLevel::LevelAndProbability& vad_result) {
return vad_result.speech_probability > kVadConfidenceThreshold;
});
}
const float gain_change_this_frame_db = ComputeGainChangeThisFrameDb( const float gain_change_this_frame_db = ComputeGainChangeThisFrameDb(
target_gain_db, last_gain_db_, gain_increase_allowed_); target_gain_db, last_gain_db_, gain_increase_allowed_);

View File

@ -13,8 +13,8 @@
#include "modules/audio_processing/agc2/agc2_common.h" #include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/gain_applier.h" #include "modules/audio_processing/agc2/gain_applier.h"
#include "modules/audio_processing/agc2/vad_with_level.h"
#include "modules/audio_processing/include/audio_frame_view.h" #include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/vad/vad_with_level.h"
namespace webrtc { namespace webrtc {
@ -24,11 +24,10 @@ class AdaptiveDigitalGainApplier {
public: public:
explicit AdaptiveDigitalGainApplier(ApmDataDumper* apm_data_dumper); explicit AdaptiveDigitalGainApplier(ApmDataDumper* apm_data_dumper);
// Decide what gain to apply. // Decide what gain to apply.
void Process( void Process(float input_level_dbfs,
float input_level_dbfs, float input_noise_level_dbfs,
float input_noise_level_dbfs, const VadWithLevel::LevelAndProbability vad_result,
rtc::ArrayView<const VadWithLevel::LevelAndProbability> vad_results, AudioFrameView<float> float_frame);
AudioFrameView<float> float_frame);
private: private:
float last_gain_db_ = kInitialAdaptiveDigitalGainDb; float last_gain_db_ = kInitialAdaptiveDigitalGainDb;

View File

@ -33,10 +33,8 @@ float RunOnConstantLevel(int num_iterations,
for (int i = 0; i < num_iterations; ++i) { for (int i = 0; i < num_iterations; ++i) {
VectorFloatFrame fake_audio(1, 1, 1.f); VectorFloatFrame fake_audio(1, 1, 1.f);
gain_applier->Process( gain_applier->Process(input_level_dbfs, kNoNoiseDbfs, vad_data,
input_level_dbfs, kNoNoiseDbfs, fake_audio.float_frame_view());
rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&vad_data, 1),
fake_audio.float_frame_view());
gain_linear = fake_audio.float_frame_view().channel(0)[0]; gain_linear = fake_audio.float_frame_view().channel(0)[0];
} }
return gain_linear; return gain_linear;
@ -54,10 +52,8 @@ TEST(AutomaticGainController2AdaptiveGainApplier, GainApplierShouldNotCrash) {
// Make one call with reasonable audio level values and settings. // Make one call with reasonable audio level values and settings.
VectorFloatFrame fake_audio(2, 480, 10000.f); VectorFloatFrame fake_audio(2, 480, 10000.f);
gain_applier.Process( gain_applier.Process(-5.0, kNoNoiseDbfs, kVadSpeech,
-5.0, kNoNoiseDbfs, fake_audio.float_frame_view());
rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&kVadSpeech, 1),
fake_audio.float_frame_view());
} }
// Check that the output is -kHeadroom dBFS. // Check that the output is -kHeadroom dBFS.
@ -107,10 +103,8 @@ TEST(AutomaticGainController2AdaptiveGainApplier, GainDoesNotChangeFast) {
for (int i = 0; i < kNumFramesToAdapt; ++i) { for (int i = 0; i < kNumFramesToAdapt; ++i) {
SCOPED_TRACE(i); SCOPED_TRACE(i);
VectorFloatFrame fake_audio(1, 1, 1.f); VectorFloatFrame fake_audio(1, 1, 1.f);
gain_applier.Process( gain_applier.Process(initial_level_dbfs, kNoNoiseDbfs, kVadSpeech,
initial_level_dbfs, kNoNoiseDbfs, fake_audio.float_frame_view());
rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&kVadSpeech, 1),
fake_audio.float_frame_view());
float current_gain_linear = fake_audio.float_frame_view().channel(0)[0]; float current_gain_linear = fake_audio.float_frame_view().channel(0)[0];
EXPECT_LE(std::abs(current_gain_linear - last_gain_linear), EXPECT_LE(std::abs(current_gain_linear - last_gain_linear),
kMaxChangePerFrameLinear); kMaxChangePerFrameLinear);
@ -121,10 +115,8 @@ TEST(AutomaticGainController2AdaptiveGainApplier, GainDoesNotChangeFast) {
for (int i = 0; i < kNumFramesToAdapt; ++i) { for (int i = 0; i < kNumFramesToAdapt; ++i) {
SCOPED_TRACE(i); SCOPED_TRACE(i);
VectorFloatFrame fake_audio(1, 1, 1.f); VectorFloatFrame fake_audio(1, 1, 1.f);
gain_applier.Process( gain_applier.Process(0.f, kNoNoiseDbfs, kVadSpeech,
0.f, kNoNoiseDbfs, fake_audio.float_frame_view());
rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&kVadSpeech, 1),
fake_audio.float_frame_view());
float current_gain_linear = fake_audio.float_frame_view().channel(0)[0]; float current_gain_linear = fake_audio.float_frame_view().channel(0)[0];
EXPECT_LE(std::abs(current_gain_linear - last_gain_linear), EXPECT_LE(std::abs(current_gain_linear - last_gain_linear),
kMaxChangePerFrameLinear); kMaxChangePerFrameLinear);
@ -140,10 +132,8 @@ TEST(AutomaticGainController2AdaptiveGainApplier, GainIsRampedInAFrame) {
constexpr int num_samples = 480; constexpr int num_samples = 480;
VectorFloatFrame fake_audio(1, num_samples, 1.f); VectorFloatFrame fake_audio(1, num_samples, 1.f);
gain_applier.Process( gain_applier.Process(initial_level_dbfs, kNoNoiseDbfs, kVadSpeech,
initial_level_dbfs, kNoNoiseDbfs, fake_audio.float_frame_view());
rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&kVadSpeech, 1),
fake_audio.float_frame_view());
float maximal_difference = 0.f; float maximal_difference = 0.f;
float current_value = 1.f * DbToRatio(kInitialAdaptiveDigitalGainDb); float current_value = 1.f * DbToRatio(kInitialAdaptiveDigitalGainDb);
for (const auto& x : fake_audio.float_frame_view().channel(0)) { for (const auto& x : fake_audio.float_frame_view().channel(0)) {
@ -172,10 +162,8 @@ TEST(AutomaticGainController2AdaptiveGainApplier, NoiseLimitsGain) {
for (int i = 0; i < num_initial_frames + num_frames; ++i) { for (int i = 0; i < num_initial_frames + num_frames; ++i) {
VectorFloatFrame fake_audio(1, num_samples, 1.f); VectorFloatFrame fake_audio(1, num_samples, 1.f);
gain_applier.Process( gain_applier.Process(initial_level_dbfs, kWithNoiseDbfs, kVadSpeech,
initial_level_dbfs, kWithNoiseDbfs, fake_audio.float_frame_view());
rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&kVadSpeech, 1),
fake_audio.float_frame_view());
// Wait so that the adaptive gain applier has time to lower the gain. // Wait so that the adaptive gain applier has time to lower the gain.
if (i > num_initial_frames) { if (i > num_initial_frames) {

View File

@ -12,7 +12,7 @@
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_H_ #define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_H_
#include "modules/audio_processing/agc2/saturation_protector.h" #include "modules/audio_processing/agc2/saturation_protector.h"
#include "modules/audio_processing/vad/vad_with_level.h" #include "modules/audio_processing/agc2/vad_with_level.h"
namespace webrtc { namespace webrtc {
class ApmDataDumper; class ApmDataDumper;

View File

@ -41,10 +41,10 @@ constexpr float kMaxNoiseLevelDbfs = -50.f;
// Used in the Level Estimator for deciding when to update the speech // Used in the Level Estimator for deciding when to update the speech
// level estimate. Also used in the adaptive digital gain applier to // level estimate. Also used in the adaptive digital gain applier to
// decide when to allow target gain reduction. // decide when to allow target gain reduction.
constexpr float kVadConfidenceThreshold = 0.9f; constexpr float kVadConfidenceThreshold = 0.4f;
// The amount of 'memory' of the Level Estimator. Decides leak factors. // The amount of 'memory' of the Level Estimator. Decides leak factors.
constexpr size_t kFullBufferSizeMs = 1000; constexpr size_t kFullBufferSizeMs = 1600;
constexpr float kFullBufferLeakFactor = 1.f - 1.f / kFullBufferSizeMs; constexpr float kFullBufferLeakFactor = 1.f - 1.f / kFullBufferSizeMs;
constexpr float kInitialSpeechLevelEstimateDbfs = -30.f; constexpr float kInitialSpeechLevelEstimateDbfs = -30.f;
@ -52,7 +52,10 @@ constexpr float kInitialSpeechLevelEstimateDbfs = -30.f;
// Saturation Protector settings. // Saturation Protector settings.
constexpr float kInitialSaturationMarginDb = 17.f; constexpr float kInitialSaturationMarginDb = 17.f;
constexpr size_t kPeakEnveloperSuperFrameLengthMs = 500; constexpr size_t kPeakEnveloperSuperFrameLengthMs = 400;
static_assert(kFullBufferSizeMs % kPeakEnveloperSuperFrameLengthMs == 0,
"Full buffer size should be a multiple of super frame length for "
"optimal Saturation Protector performance.");
constexpr size_t kPeakEnveloperBufferSize = constexpr size_t kPeakEnveloperBufferSize =
kFullBufferSizeMs / kPeakEnveloperSuperFrameLengthMs + 1; kFullBufferSizeMs / kPeakEnveloperSuperFrameLengthMs + 1;

View File

@ -14,7 +14,7 @@
#include <array> #include <array>
#include "modules/audio_processing/agc2/agc2_common.h" #include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/vad/vad_with_level.h" #include "modules/audio_processing/agc2/vad_with_level.h"
namespace webrtc { namespace webrtc {

View File

@ -30,6 +30,7 @@ float RunOnConstantLevel(int num_iterations,
max_difference = max_difference =
std::max(max_difference, std::abs(new_margin - last_margin)); std::max(max_difference, std::abs(new_margin - last_margin));
last_margin = new_margin; last_margin = new_margin;
saturation_protector->DebugDumpEstimate();
} }
return max_difference; return max_difference;
} }
@ -127,6 +128,12 @@ TEST(AutomaticGainController2SaturationProtector,
kLaterSpeechLevelDbfs, &saturation_protector), kLaterSpeechLevelDbfs, &saturation_protector),
max_difference); max_difference);
// The saturation protector expects that the RMS changes roughly
// 'kFullBufferSizeMs' after peaks change. This is to account for
// delay introduced by the level estimator. Therefore, the input
// above is 'normal' and 'expected', and shouldn't influence the
// margin by much.
const float total_difference = const float total_difference =
std::abs(saturation_protector.LastMargin() - kInitialSaturationMarginDb); std::abs(saturation_protector.LastMargin() - kInitialSaturationMarginDb);

View File

@ -0,0 +1,68 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/vad_with_level.h"
#include <algorithm>
#include <array>
#include <cmath>
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
// Returns the largest absolute sample value found in the first channel of
// |frame|. The remaining channels are not inspected. An empty channel
// yields 0.
float ProcessForPeak(AudioFrameView<const float> frame) {
  const auto first_channel = frame.channel(0);
  float peak = 0;
  std::for_each(first_channel.begin(), first_channel.end(),
                [&peak](const float& sample) {
                  // Keep the argument order: the new magnitude goes first.
                  peak = std::max(std::fabs(sample), peak);
                });
  return peak;
}
// Returns the root-mean-square of the samples in the first channel of
// |frame|. The remaining channels are not inspected.
//
// An empty frame yields 0 rather than the NaN that 0/0 would produce, so
// that no NaN level is propagated to the callers.
float ProcessForRms(AudioFrameView<const float> frame) {
  if (frame.samples_per_channel() == 0) {
    return 0.f;
  }
  float sum_of_squares = 0.f;
  for (const float sample : frame.channel(0)) {
    sum_of_squares += sample * sample;
  }
  // Use std::sqrt so that the float overload is selected explicitly instead
  // of depending on whichever global C `sqrt` declaration happens to be
  // visible through transitive includes.
  return std::sqrt(sum_of_squares / frame.samples_per_channel());
}
} // namespace
// Construction and destruction need no custom logic; both are defaulted
// here in the .cc file rather than inline in the header.
VadWithLevel::VadWithLevel() = default;
VadWithLevel::~VadWithLevel() = default;
// Analyzes one frame: runs the RNN VAD on the first channel of |frame|
// (resampled to 24 kHz) and returns the resulting speech probability together
// with the RMS and peak levels of that channel, both converted to dBFS.
VadWithLevel::LevelAndProbability VadWithLevel::AnalyzeFrame(
    AudioFrameView<const float> frame) {
  // The input rate is derived from the frame length; this assumes |frame|
  // holds 10 ms of audio.
  const int input_sample_rate_hz =
      static_cast<int>(frame.samples_per_channel() * 100);
  SetSampleRate(input_sample_rate_hz);

  // Resample the 1st channel into a 10 ms buffer at 24 kHz.
  std::array<float, rnn_vad::kFrameSize10ms24kHz> resampled_frame;
  resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(),
                      resampled_frame.data(), rnn_vad::kFrameSize10ms24kHz);

  // Extract the features, then let the RNN turn them into a probability.
  std::array<float, rnn_vad::kFeatureVectorSize> features;
  const bool silence_detected =
      features_extractor_.CheckSilenceComputeFeatures(resampled_frame,
                                                      features);
  const float speech_probability =
      rnn_vad_.ComputeVadProbability(features, silence_detected);

  // The levels are measured on the input frame, not on the resampled copy.
  const float rms_dbfs = FloatS16ToDbfs(ProcessForRms(frame));
  const float peak_dbfs = FloatS16ToDbfs(ProcessForPeak(frame));
  return LevelAndProbability(speech_probability, rms_dbfs, peak_dbfs);
}
// Prepares the resampler to convert audio from |sample_rate_hz| to the
// 24 kHz rate of the RNN VAD. Initialization is delegated to
// InitializeIfNeeded().
void VadWithLevel::SetSampleRate(int sample_rate_hz) {
  // The source number of channels is 1, because only the 1st channel is
  // ever fed to the resampler.
  constexpr size_t kNumChannels = 1;
  resampler_.InitializeIfNeeded(sample_rate_hz, rnn_vad::kSampleRate24kHz,
                                kNumChannels);
}
} // namespace webrtc

View File

@ -8,10 +8,13 @@
* be found in the AUTHORS file in the root of the source tree. * be found in the AUTHORS file in the root of the source tree.
*/ */
#ifndef MODULES_AUDIO_PROCESSING_VAD_VAD_WITH_LEVEL_H_ #ifndef MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_
#define MODULES_AUDIO_PROCESSING_VAD_VAD_WITH_LEVEL_H_ #define MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_
#include "api/array_view.h" #include "api/array_view.h"
#include "common_audio/resampler/include/push_resampler.h"
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "modules/audio_processing/include/audio_frame_view.h" #include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc { namespace webrtc {
@ -28,13 +31,19 @@ class VadWithLevel {
float speech_peak_dbfs = 0; float speech_peak_dbfs = 0;
}; };
// TODO(webrtc:7494): This is a stub. Add implementation. VadWithLevel();
rtc::ArrayView<const LevelAndProbability> AnalyzeFrame( ~VadWithLevel();
AudioFrameView<const float> frame) {
return {nullptr, 0}; LevelAndProbability AnalyzeFrame(AudioFrameView<const float> frame);
}
private:
void SetSampleRate(int sample_rate_hz);
rnn_vad::RnnBasedVad rnn_vad_;
rnn_vad::FeaturesExtractor features_extractor_;
PushResampler<float> resampler_;
}; };
} // namespace webrtc } // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_VAD_VAD_WITH_LEVEL_H_ #endif // MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_

View File

@ -44,16 +44,6 @@ rtc_static_library("vad") {
] ]
} }
rtc_source_set("vad_with_level") {
sources = [
"vad_with_level.h",
]
deps = [
"..:audio_frame_view",
"../../../api:array_view",
]
}
if (rtc_include_tests) { if (rtc_include_tests) {
rtc_static_library("vad_unittests") { rtc_static_library("vad_unittests") {
testonly = true testonly = true