From db6af36979b6417b8026c9ee47cc6ee2776ce760 Mon Sep 17 00:00:00 2001
From: Alex Loiko
Date: Wed, 20 Jun 2018 14:14:18 +0200
Subject: [PATCH] Add RNN-VAD to AGC2.

* Move 'VadWithLevel' to AGC2 where it belongs.
* Remove the vectors from VadWithLevel. They were there to make it
  work with modules/audio_processing/vad, which we don't need any longer.
* Remove the vector handling from AGC2. It was spread out across
  AdaptiveDigitalGainApplier, AdaptiveAGC and their unit tests.
* Hack the RNN VAD into VadWithLevel. The main issue is the resampling.

Bug: webrtc:9076
Change-Id: I13056c985d0ec41269735150caf4aaeb6ff9281e
Reviewed-on: https://webrtc-review.googlesource.com/77364
Reviewed-by: Sam Zackrisson
Commit-Queue: Alex Loiko
Cr-Commit-Position: refs/heads/master@{#23688}
---
 modules/audio_processing/agc2/BUILD.gn         | 19 ++++--
 modules/audio_processing/agc2/adaptive_agc.cc  | 19 +++---
 modules/audio_processing/agc2/adaptive_agc.h   |  2 +-
 .../agc2/adaptive_digital_gain_applier.cc      | 20 ++----
 .../agc2/adaptive_digital_gain_applier.h       | 11 ++-
 .../adaptive_digital_gain_applier_unittest.cc  | 36 ++++------
 .../agc2/adaptive_mode_level_estimator.h       |  2 +-
 modules/audio_processing/agc2/agc2_common.h    |  9 ++-
 .../agc2/saturation_protector.h                |  2 +-
 .../agc2/saturation_protector_unittest.cc      |  7 ++
 .../audio_processing/agc2/vad_with_level.cc    | 68 +++++++++++++++++++
 .../{vad => agc2}/vad_with_level.h             | 25 ++++---
 modules/audio_processing/vad/BUILD.gn          | 10 ---
 13 files changed, 145 insertions(+), 85 deletions(-)
 create mode 100644 modules/audio_processing/agc2/vad_with_level.cc
 rename modules/audio_processing/{vad => agc2}/vad_with_level.h (60%)

diff --git a/modules/audio_processing/agc2/BUILD.gn b/modules/audio_processing/agc2/BUILD.gn
index e0ed2bb765..8501dd9d5f 100644
--- a/modules/audio_processing/agc2/BUILD.gn
+++ b/modules/audio_processing/agc2/BUILD.gn
@@ -33,6 +33,7 @@ rtc_source_set("adaptive_digital") {
     ":common",
     ":gain_applier",
     ":noise_level_estimator",
+    ":rnn_vad_with_level",
     "..:aec_core",
     "..:apm_logging",
     "..:audio_frame_view",
@@ -41,9 +42,6 @@ rtc_source_set("adaptive_digital") {
     "../../../rtc_base:checks",
     "../../../rtc_base:rtc_base_approved",
     "../../../rtc_base:safe_minmax",
-    "../vad",
-    "../vad:vad_with_level",
-    "rnn_vad",
   ]
 }
@@ -133,6 +131,20 @@ rtc_source_set("noise_level_estimator") {
   configs += [ "..:apm_debug_dump" ]
 }
 
+rtc_source_set("rnn_vad_with_level") {
+  sources = [
+    "vad_with_level.cc",
+    "vad_with_level.h",
+  ]
+  deps = [
+    "..:audio_frame_view",
+    "../../../api:array_view",
+    "../../../common_audio",
+    "../../../rtc_base:checks",
+    "rnn_vad:lib",
+  ]
+}
+
 rtc_source_set("adaptive_digital_unittests") {
   testonly = true
   configs += [ "..:apm_debug_dump" ]
@@ -155,7 +167,6 @@ rtc_source_set("adaptive_digital_unittests") {
     "../../../rtc_base:checks",
     "../../../rtc_base:rtc_base_approved",
     "../../../rtc_base:rtc_base_tests_utils",
-    "../vad:vad_with_level",
   ]
 }
diff --git a/modules/audio_processing/agc2/adaptive_agc.cc b/modules/audio_processing/agc2/adaptive_agc.cc
index 45e88531d8..7b242445d5 100644
--- a/modules/audio_processing/agc2/adaptive_agc.cc
+++ b/modules/audio_processing/agc2/adaptive_agc.cc
@@ -14,8 +14,8 @@
 #include
 
 #include "common_audio/include/audio_util.h"
+#include "modules/audio_processing/agc2/vad_with_level.h"
 #include "modules/audio_processing/logging/apm_data_dumper.h"
-#include "modules/audio_processing/vad/voice_activity_detector.h"
 
 namespace webrtc {
 
@@ -38,17 +38,14 @@ void AdaptiveAgc::Process(AudioFrameView<float> float_frame) {
  // frames, and no
  // estimates for other frames. We want to feed all to
  // the level estimator, but only care about the last level it
  // produces.
-  rtc::ArrayView<const VadWithLevel::LevelAndProbability> vad_results =
+  const VadWithLevel::LevelAndProbability vad_result =
       vad_.AnalyzeFrame(float_frame);
-  for (const auto& vad_result : vad_results) {
-    apm_data_dumper_->DumpRaw("agc2_vad_probability",
-                              vad_result.speech_probability);
-    apm_data_dumper_->DumpRaw("agc2_vad_rms_dbfs", vad_result.speech_rms_dbfs);
+  apm_data_dumper_->DumpRaw("agc2_vad_probability",
+                            vad_result.speech_probability);
+  apm_data_dumper_->DumpRaw("agc2_vad_rms_dbfs", vad_result.speech_rms_dbfs);
 
-    apm_data_dumper_->DumpRaw("agc2_vad_peak_dbfs",
-                              vad_result.speech_peak_dbfs);
-    speech_level_estimator_.UpdateEstimation(vad_result);
-  }
+  apm_data_dumper_->DumpRaw("agc2_vad_peak_dbfs", vad_result.speech_peak_dbfs);
+  speech_level_estimator_.UpdateEstimation(vad_result);
 
   const float speech_level_dbfs =
       speech_level_estimator_.LatestLevelEstimate();
@@ -57,7 +54,7 @@ void AdaptiveAgc::Process(AudioFrameView<float> float_frame) {
   apm_data_dumper_->DumpRaw("agc2_noise_estimate_dbfs", noise_level_dbfs);
 
   // The gain applier applies the gain.
-  gain_applier_.Process(speech_level_dbfs, noise_level_dbfs, vad_results,
+  gain_applier_.Process(speech_level_dbfs, noise_level_dbfs, vad_result,
                         float_frame);
 }
diff --git a/modules/audio_processing/agc2/adaptive_agc.h b/modules/audio_processing/agc2/adaptive_agc.h
index a91aa2ab86..dabe783f44 100644
--- a/modules/audio_processing/agc2/adaptive_agc.h
+++ b/modules/audio_processing/agc2/adaptive_agc.h
@@ -16,8 +16,8 @@
 #include "modules/audio_processing/agc2/adaptive_digital_gain_applier.h"
 #include "modules/audio_processing/agc2/adaptive_mode_level_estimator.h"
 #include "modules/audio_processing/agc2/noise_level_estimator.h"
+#include "modules/audio_processing/agc2/vad_with_level.h"
 #include "modules/audio_processing/include/audio_frame_view.h"
-#include "modules/audio_processing/vad/vad_with_level.h"
 
 namespace webrtc {
 class ApmDataDumper;
diff --git a/modules/audio_processing/agc2/adaptive_digital_gain_applier.cc b/modules/audio_processing/agc2/adaptive_digital_gain_applier.cc
index 20b5a27282..f5b6b91adf 100644
--- a/modules/audio_processing/agc2/adaptive_digital_gain_applier.cc
+++ b/modules/audio_processing/agc2/adaptive_digital_gain_applier.cc
@@ -74,7 +74,7 @@ AdaptiveDigitalGainApplier::AdaptiveDigitalGainApplier(
 void AdaptiveDigitalGainApplier::Process(
     float input_level_dbfs,
     float input_noise_level_dbfs,
-    rtc::ArrayView<const VadWithLevel::LevelAndProbability> vad_results,
+    const VadWithLevel::LevelAndProbability vad_result,
     AudioFrameView<float> float_frame) {
   RTC_DCHECK_GE(input_level_dbfs, -150.f);
   RTC_DCHECK_LE(input_level_dbfs, 0.f);
@@ -85,21 +85,9 @@ void AdaptiveDigitalGainApplier::Process(
   LimitGainByNoise(ComputeGainDb(input_level_dbfs), input_noise_level_dbfs,
                    apm_data_dumper_);
 
-  // TODO(webrtc:7494): Remove this construct. Remove the vectors from
-  // VadWithData after we move to a VAD that outputs an estimate every
-  // kFrameDurationMs ms.
-  //
-  // Forbid increasing the gain when there is no speech. For some
-  // VADs, 'vad_results' has either many or 0 results. If there are 0
-  // results, keep the old flag. If there are many results, and at
-  // least one is confident speech, we allow attenuation.
-  if (!vad_results.empty()) {
-    gain_increase_allowed_ = std::all_of(
-        vad_results.begin(), vad_results.end(),
-        [](const VadWithLevel::LevelAndProbability& vad_result) {
-          return vad_result.speech_probability > kVadConfidenceThreshold;
-        });
-  }
+  // Forbid increasing the gain when there is no speech.
+  gain_increase_allowed_ =
+      vad_result.speech_probability > kVadConfidenceThreshold;
 
   const float gain_change_this_frame_db = ComputeGainChangeThisFrameDb(
       target_gain_db, last_gain_db_, gain_increase_allowed_);
diff --git a/modules/audio_processing/agc2/adaptive_digital_gain_applier.h b/modules/audio_processing/agc2/adaptive_digital_gain_applier.h
index b06c65b97d..31f87f16e5 100644
--- a/modules/audio_processing/agc2/adaptive_digital_gain_applier.h
+++ b/modules/audio_processing/agc2/adaptive_digital_gain_applier.h
@@ -13,8 +13,8 @@
 #include "modules/audio_processing/agc2/agc2_common.h"
 #include "modules/audio_processing/agc2/gain_applier.h"
+#include "modules/audio_processing/agc2/vad_with_level.h"
 #include "modules/audio_processing/include/audio_frame_view.h"
-#include "modules/audio_processing/vad/vad_with_level.h"
 
 namespace webrtc {
 
@@ -24,11 +24,10 @@ class AdaptiveDigitalGainApplier {
  public:
   explicit AdaptiveDigitalGainApplier(ApmDataDumper* apm_data_dumper);
   // Decide what gain to apply.
-  void Process(
-      float input_level_dbfs,
-      float input_noise_level_dbfs,
-      rtc::ArrayView<const VadWithLevel::LevelAndProbability> vad_results,
-      AudioFrameView<float> float_frame);
+  void Process(float input_level_dbfs,
+               float input_noise_level_dbfs,
+               const VadWithLevel::LevelAndProbability vad_result,
+               AudioFrameView<float> float_frame);
 
  private:
   float last_gain_db_ = kInitialAdaptiveDigitalGainDb;
diff --git a/modules/audio_processing/agc2/adaptive_digital_gain_applier_unittest.cc b/modules/audio_processing/agc2/adaptive_digital_gain_applier_unittest.cc
index ebb040eb34..860da0034c 100644
--- a/modules/audio_processing/agc2/adaptive_digital_gain_applier_unittest.cc
+++ b/modules/audio_processing/agc2/adaptive_digital_gain_applier_unittest.cc
@@ -33,10 +33,8 @@ float RunOnConstantLevel(int num_iterations,
   for (int i = 0; i < num_iterations; ++i) {
     VectorFloatFrame fake_audio(1, 1, 1.f);
-    gain_applier->Process(
-        input_level_dbfs, kNoNoiseDbfs,
-        rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&vad_data, 1),
-        fake_audio.float_frame_view());
+    gain_applier->Process(input_level_dbfs, kNoNoiseDbfs, vad_data,
+                          fake_audio.float_frame_view());
     gain_linear = fake_audio.float_frame_view().channel(0)[0];
   }
   return gain_linear;
 }
@@ -54,10 +52,8 @@ TEST(AutomaticGainController2AdaptiveGainApplier, GainApplierShouldNotCrash) {
   // Make one call with reasonable audio level values and settings.
   VectorFloatFrame fake_audio(2, 480, 10000.f);
-  gain_applier.Process(
-      -5.0, kNoNoiseDbfs,
-      rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&kVadSpeech, 1),
-      fake_audio.float_frame_view());
+  gain_applier.Process(-5.0, kNoNoiseDbfs, kVadSpeech,
+                       fake_audio.float_frame_view());
 }
 
 // Check that the output is -kHeadroom dBFS.
@@ -107,10 +103,8 @@ TEST(AutomaticGainController2AdaptiveGainApplier, GainDoesNotChangeFast) {
   for (int i = 0; i < kNumFramesToAdapt; ++i) {
     SCOPED_TRACE(i);
     VectorFloatFrame fake_audio(1, 1, 1.f);
-    gain_applier.Process(
-        initial_level_dbfs, kNoNoiseDbfs,
-        rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&kVadSpeech, 1),
-        fake_audio.float_frame_view());
+    gain_applier.Process(initial_level_dbfs, kNoNoiseDbfs, kVadSpeech,
+                         fake_audio.float_frame_view());
     float current_gain_linear = fake_audio.float_frame_view().channel(0)[0];
     EXPECT_LE(std::abs(current_gain_linear - last_gain_linear),
               kMaxChangePerFrameLinear);
@@ -121,10 +115,8 @@ TEST(AutomaticGainController2AdaptiveGainApplier, GainDoesNotChangeFast) {
   for (int i = 0; i < kNumFramesToAdapt; ++i) {
     SCOPED_TRACE(i);
     VectorFloatFrame fake_audio(1, 1, 1.f);
-    gain_applier.Process(
-        0.f, kNoNoiseDbfs,
-        rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&kVadSpeech, 1),
-        fake_audio.float_frame_view());
+    gain_applier.Process(0.f, kNoNoiseDbfs, kVadSpeech,
+                         fake_audio.float_frame_view());
     float current_gain_linear = fake_audio.float_frame_view().channel(0)[0];
     EXPECT_LE(std::abs(current_gain_linear - last_gain_linear),
               kMaxChangePerFrameLinear);
@@ -140,10 +132,8 @@ TEST(AutomaticGainController2AdaptiveGainApplier, GainIsRampedInAFrame) {
   constexpr int num_samples = 480;
 
   VectorFloatFrame fake_audio(1, num_samples, 1.f);
-  gain_applier.Process(
-      initial_level_dbfs, kNoNoiseDbfs,
-      rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&kVadSpeech, 1),
-      fake_audio.float_frame_view());
+  gain_applier.Process(initial_level_dbfs, kNoNoiseDbfs, kVadSpeech,
+                       fake_audio.float_frame_view());
   float maximal_difference = 0.f;
   float current_value = 1.f * DbToRatio(kInitialAdaptiveDigitalGainDb);
   for (const auto& x : fake_audio.float_frame_view().channel(0)) {
@@ -172,10 +162,8 @@ TEST(AutomaticGainController2AdaptiveGainApplier, NoiseLimitsGain) {
   for (int i = 0; i < num_initial_frames + num_frames; ++i) {
     VectorFloatFrame fake_audio(1, num_samples, 1.f);
-    gain_applier.Process(
-        initial_level_dbfs, kWithNoiseDbfs,
-        rtc::ArrayView<const VadWithLevel::LevelAndProbability>(&kVadSpeech, 1),
-        fake_audio.float_frame_view());
+    gain_applier.Process(initial_level_dbfs, kWithNoiseDbfs, kVadSpeech,
+                         fake_audio.float_frame_view());
 
     // Wait so that the adaptive gain applier has time to lower the gain.
     if (i > num_initial_frames) {
diff --git a/modules/audio_processing/agc2/adaptive_mode_level_estimator.h b/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
index 9762f1fc55..186c59bf97 100644
--- a/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
+++ b/modules/audio_processing/agc2/adaptive_mode_level_estimator.h
@@ -12,7 +12,7 @@
 #define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_MODE_LEVEL_ESTIMATOR_H_
 
 #include "modules/audio_processing/agc2/saturation_protector.h"
-#include "modules/audio_processing/vad/vad_with_level.h"
+#include "modules/audio_processing/agc2/vad_with_level.h"
 
 namespace webrtc {
 class ApmDataDumper;
diff --git a/modules/audio_processing/agc2/agc2_common.h b/modules/audio_processing/agc2/agc2_common.h
index 3ed88a36b8..7300653644 100644
--- a/modules/audio_processing/agc2/agc2_common.h
+++ b/modules/audio_processing/agc2/agc2_common.h
@@ -41,10 +41,10 @@ constexpr float kMaxNoiseLevelDbfs = -50.f;
 // Used in the Level Estimator for deciding when to update the speech
 // level estimate. Also used in the adaptive digital gain applier to
 // decide when to allow target gain reduction.
-constexpr float kVadConfidenceThreshold = 0.9f;
+constexpr float kVadConfidenceThreshold = 0.4f;
 
 // The amount of 'memory' of the Level Estimator. Decides leak factors.
-constexpr size_t kFullBufferSizeMs = 1000;
+constexpr size_t kFullBufferSizeMs = 1600;
 constexpr float kFullBufferLeakFactor = 1.f - 1.f / kFullBufferSizeMs;
 
 constexpr float kInitialSpeechLevelEstimateDbfs = -30.f;
@@ -52,7 +52,10 @@ constexpr float kInitialSpeechLevelEstimateDbfs = -30.f;
 // Saturation Protector settings.
 constexpr float kInitialSaturationMarginDb = 17.f;
 
-constexpr size_t kPeakEnveloperSuperFrameLengthMs = 500;
+constexpr size_t kPeakEnveloperSuperFrameLengthMs = 400;
+static_assert(kFullBufferSizeMs % kPeakEnveloperSuperFrameLengthMs == 0,
+              "Full buffer size should be a multiple of super frame length for "
+              "optimal Saturation Protector performance.");
 constexpr size_t kPeakEnveloperBufferSize =
     kFullBufferSizeMs / kPeakEnveloperSuperFrameLengthMs + 1;
diff --git a/modules/audio_processing/agc2/saturation_protector.h b/modules/audio_processing/agc2/saturation_protector.h
index d330c1514b..3a796fa65f 100644
--- a/modules/audio_processing/agc2/saturation_protector.h
+++ b/modules/audio_processing/agc2/saturation_protector.h
@@ -14,7 +14,7 @@
 #include
 
 #include "modules/audio_processing/agc2/agc2_common.h"
-#include "modules/audio_processing/vad/vad_with_level.h"
+#include "modules/audio_processing/agc2/vad_with_level.h"
 
 namespace webrtc {
 
diff --git a/modules/audio_processing/agc2/saturation_protector_unittest.cc b/modules/audio_processing/agc2/saturation_protector_unittest.cc
index 88da2a235d..6013e13827 100644
--- a/modules/audio_processing/agc2/saturation_protector_unittest.cc
+++ b/modules/audio_processing/agc2/saturation_protector_unittest.cc
@@ -30,6 +30,7 @@ float RunOnConstantLevel(int num_iterations,
     max_difference =
         std::max(max_difference, std::abs(new_margin - last_margin));
     last_margin = new_margin;
+    saturation_protector->DebugDumpEstimate();
   }
   return max_difference;
 }
@@ -127,6 +128,12 @@ TEST(AutomaticGainController2SaturationProtector,
                                 kLaterSpeechLevelDbfs, &saturation_protector),
             max_difference);
 
+  // The saturation protector expects that the RMS changes roughly
+  // 'kFullBufferSizeMs' after peaks change. This is to account for
+  // delay introduced by the level estimator. Therefore, the input
+  // above is 'normal' and 'expected', and shouldn't influence the
+  // margin by much.
+
   const float total_difference =
       std::abs(saturation_protector.LastMargin() - kInitialSaturationMarginDb);
diff --git a/modules/audio_processing/agc2/vad_with_level.cc b/modules/audio_processing/agc2/vad_with_level.cc
new file mode 100644
index 0000000000..decfacd0c1
--- /dev/null
+++ b/modules/audio_processing/agc2/vad_with_level.cc
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */ + +#include "modules/audio_processing/agc2/vad_with_level.h" + +#include + +#include "common_audio/include/audio_util.h" +#include "modules/audio_processing/agc2/rnn_vad/common.h" +#include "rtc_base/checks.h" + +namespace webrtc { + +namespace { +float ProcessForPeak(AudioFrameView frame) { + float current_max = 0; + for (const auto& x : frame.channel(0)) { + current_max = std::max(std::fabs(x), current_max); + } + return current_max; +} + +float ProcessForRms(AudioFrameView frame) { + float rms = 0; + for (const auto& x : frame.channel(0)) { + rms += x * x; + } + return sqrt(rms / frame.samples_per_channel()); +} +} // namespace + +VadWithLevel::VadWithLevel() = default; +VadWithLevel::~VadWithLevel() = default; + +VadWithLevel::LevelAndProbability VadWithLevel::AnalyzeFrame( + AudioFrameView frame) { + SetSampleRate(static_cast(frame.samples_per_channel() * 100)); + std::array work_frame; + // Feed the 1st channel to the resampler. + resampler_.Resample(frame.channel(0).data(), frame.samples_per_channel(), + work_frame.data(), rnn_vad::kFrameSize10ms24kHz); + + std::array feature_vector; + + const bool is_silence = features_extractor_.CheckSilenceComputeFeatures( + work_frame, feature_vector); + const float vad_probability = + rnn_vad_.ComputeVadProbability(feature_vector, is_silence); + return LevelAndProbability(vad_probability, + FloatS16ToDbfs(ProcessForRms(frame)), + FloatS16ToDbfs(ProcessForPeak(frame))); +} + +void VadWithLevel::SetSampleRate(int sample_rate_hz) { + // The source number of channels in 1, because we always use the 1st + // channel. + resampler_.InitializeIfNeeded(sample_rate_hz, rnn_vad::kSampleRate24kHz, + 1 /* num_channels */); +} + +} // namespace webrtc diff --git a/modules/audio_processing/vad/vad_with_level.h b/modules/audio_processing/agc2/vad_with_level.h similarity index 60% rename from modules/audio_processing/vad/vad_with_level.h rename to modules/audio_processing/agc2/vad_with_level.h index 9ad4d1701c..67a00ced6c 100644 --- a/modules/audio_processing/vad/vad_with_level.h +++ b/modules/audio_processing/agc2/vad_with_level.h @@ -8,10 +8,13 @@ * be found in the AUTHORS file in the root of the source tree. */ -#ifndef MODULES_AUDIO_PROCESSING_VAD_VAD_WITH_LEVEL_H_ -#define MODULES_AUDIO_PROCESSING_VAD_VAD_WITH_LEVEL_H_ +#ifndef MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_ #include "api/array_view.h" +#include "common_audio/resampler/include/push_resampler.h" +#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h" +#include "modules/audio_processing/agc2/rnn_vad/rnn.h" #include "modules/audio_processing/include/audio_frame_view.h" namespace webrtc { @@ -28,13 +31,19 @@ class VadWithLevel { float speech_peak_dbfs = 0; }; - // TODO(webrtc:7494): This is a stub. Add implementation. 
-  rtc::ArrayView<const LevelAndProbability> AnalyzeFrame(
-      AudioFrameView<const float> frame) {
-    return {nullptr, 0};
-  }
+  VadWithLevel();
+  ~VadWithLevel();
+
+  LevelAndProbability AnalyzeFrame(AudioFrameView<const float> frame);
+
+ private:
+  void SetSampleRate(int sample_rate_hz);
+
+  rnn_vad::RnnBasedVad rnn_vad_;
+  rnn_vad::FeaturesExtractor features_extractor_;
+  PushResampler<float> resampler_;
 };
 
 }  // namespace webrtc
 
-#endif  // MODULES_AUDIO_PROCESSING_VAD_VAD_WITH_LEVEL_H_
+#endif  // MODULES_AUDIO_PROCESSING_AGC2_VAD_WITH_LEVEL_H_
diff --git a/modules/audio_processing/vad/BUILD.gn b/modules/audio_processing/vad/BUILD.gn
index 9a57789290..ae2a84d57c 100644
--- a/modules/audio_processing/vad/BUILD.gn
+++ b/modules/audio_processing/vad/BUILD.gn
@@ -44,16 +44,6 @@ rtc_static_library("vad") {
   ]
 }
 
-rtc_source_set("vad_with_level") {
-  sources = [
-    "vad_with_level.h",
-  ]
-  deps = [
-    "..:audio_frame_view",
-    "../../../api:array_view",
-  ]
-}
-
 if (rtc_include_tests) {
   rtc_static_library("vad_unittests") {
     testonly = true
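
Editor's note (not part of the CL): the sketch below shows how the reworked
VadWithLevel can be exercised on a single 10 ms frame. It is a hypothetical
test, not code from this patch; it assumes the VectorFloatFrame helper already
used by the AGC2 unit tests above (header path assumed), WebRTC's test/gtest.h
wrapper, and illustrative test name and sample values.

  // Hypothetical usage sketch; names and values outside this CL are assumed.
  #include "modules/audio_processing/agc2/vad_with_level.h"
  #include "modules/audio_processing/agc2/vector_float_frame.h"
  #include "test/gtest.h"

  namespace webrtc {

  TEST(VadWithLevelSketch, ReportsLevelAndProbabilityPer10msFrame) {
    VadWithLevel vad;
    // One channel, 480 samples: a 10 ms frame at 48 kHz. AnalyzeFrame()
    // derives the sample rate as samples_per_channel * 100 and resamples
    // channel 0 to 24 kHz before running the RNN VAD.
    VectorFloatFrame fake_audio(1 /* num_channels */, 480 /* samples */, 1000.f);
    const VadWithLevel::LevelAndProbability result =
        vad.AnalyzeFrame(fake_audio.float_frame_view());
    // The RNN VAD emits a probability in [0, 1]; RMS and peak levels are
    // reported in dBFS, so they are <= 0 for in-range input.
    EXPECT_GE(result.speech_probability, 0.f);
    EXPECT_LE(result.speech_probability, 1.f);
    EXPECT_LE(result.speech_rms_dbfs, 0.f);
    EXPECT_LE(result.speech_peak_dbfs, 0.f);
  }

  }  // namespace webrtc

The per-frame contract matches the new AdaptiveAgc::Process() flow: one
LevelAndProbability result per 10 ms frame, fed directly to the level
estimator and the adaptive digital gain applier.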