diff --git a/webrtc/modules/audio_processing/audio_processing_impl.cc b/webrtc/modules/audio_processing/audio_processing_impl.cc
index 34937f2095..b3f38f408e 100644
--- a/webrtc/modules/audio_processing/audio_processing_impl.cc
+++ b/webrtc/modules/audio_processing/audio_processing_impl.cc
@@ -771,12 +771,6 @@ int AudioProcessingImpl::ProcessStreamLocked() {
     ca->SplitIntoFrequencyBands();
   }
 
-  if (constants_.intelligibility_enabled) {
-    public_submodules_->intelligibility_enhancer->AnalyzeCaptureAudio(
-        ca->split_channels_f(kBand0To8kHz), capture_nonlocked_.split_rate,
-        ca->num_channels());
-  }
-
   if (capture_nonlocked_.beamformer_enabled) {
     private_submodules_->beamformer->ProcessChunk(*ca->split_data_f(),
                                                   ca->split_data_f());
@@ -793,6 +787,11 @@ int AudioProcessingImpl::ProcessStreamLocked() {
     ca->CopyLowPassToReference();
   }
   public_submodules_->noise_suppression->ProcessCaptureAudio(ca);
+  if (constants_.intelligibility_enabled) {
+    RTC_DCHECK(public_submodules_->noise_suppression->is_enabled());
+    public_submodules_->intelligibility_enhancer->SetCaptureNoiseEstimate(
+        public_submodules_->noise_suppression->NoiseEstimate());
+  }
   RETURN_ON_ERR(
       public_submodules_->echo_control_mobile->ProcessCaptureAudio(ca));
   public_submodules_->voice_detection->ProcessCaptureAudio(ca);
diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc
index fe964aba8c..c42a1731b4 100644
--- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc
+++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc
@@ -39,6 +39,26 @@ const float kKbdAlpha = 1.5f;
 const float kLambdaBot = -1.0f;     // Extreme values in bisection
 const float kLambdaTop = -10e-18f;  // search for lamda.
 
+// Returns dot product of vectors |a| and |b| with size |length|.
+float DotProduct(const float* a, const float* b, size_t length) {
+  float ret = 0.f;
+  for (size_t i = 0; i < length; ++i) {
+    ret = fmaf(a[i], b[i], ret);
+  }
+  return ret;
+}
+
+// Computes the power across ERB filters from the power spectral density |var|.
+// Stores it in |result|.
+void FilterVariance(const float* var,
+                    const std::vector<std::vector<float>>& filter_bank,
+                    float* result) {
+  for (size_t i = 0; i < filter_bank.size(); ++i) {
+    RTC_DCHECK_GT(filter_bank[i].size(), 0u);
+    result[i] = DotProduct(&filter_bank[i][0], var, filter_bank[i].size());
+  }
+}
+
 }  // namespace
 
 using std::complex;
@@ -47,9 +67,8 @@ using std::min;
 using VarianceType = intelligibility::VarianceArray::StepType;
 
 IntelligibilityEnhancer::TransformCallback::TransformCallback(
-    IntelligibilityEnhancer* parent,
-    IntelligibilityEnhancer::AudioSource source)
-    : parent_(parent), source_(source) {
+    IntelligibilityEnhancer* parent)
+    : parent_(parent) {
 }
 
 void IntelligibilityEnhancer::TransformCallback::ProcessAudioBlock(
@@ -60,7 +79,7 @@ void IntelligibilityEnhancer::TransformCallback::ProcessAudioBlock(
     complex<float>* const* out_block) {
   RTC_DCHECK_EQ(parent_->freqs_, frames);
   for (size_t i = 0; i < in_channels; ++i) {
-    parent_->DispatchAudio(source_, in_block[i], out_block[i]);
+    parent_->ProcessClearBlock(in_block[i], out_block[i]);
   }
 }
 
@@ -85,27 +104,26 @@ IntelligibilityEnhancer::IntelligibilityEnhancer(const Config& config)
                       config.var_type,
                       config.var_window_size,
                       config.var_decay_rate),
-      noise_variance_(freqs_,
-                      config.var_type,
-                      config.var_window_size,
-                      config.var_decay_rate),
       filtered_clear_var_(new float[bank_size_]),
       filtered_noise_var_(new float[bank_size_]),
-      filter_bank_(bank_size_),
       center_freqs_(new float[bank_size_]),
+      render_filter_bank_(CreateErbBank(freqs_)),
       rho_(new float[bank_size_]),
       gains_eq_(new float[bank_size_]),
       gain_applier_(freqs_, config.gain_change_limit),
       temp_render_out_buffer_(chunk_length_, num_render_channels_),
-      temp_capture_out_buffer_(chunk_length_, num_capture_channels_),
       kbd_window_(new float[window_size_]),
-      render_callback_(this, AudioSource::kRenderStream),
-      capture_callback_(this, AudioSource::kCaptureStream),
+      render_callback_(this),
       block_count_(0),
       analysis_step_(0) {
   RTC_DCHECK_LE(config.rho, 1.0f);
 
-  CreateErbBank();
+  memset(filtered_clear_var_.get(),
+         0,
+         bank_size_ * sizeof(filtered_clear_var_[0]));
+  memset(filtered_noise_var_.get(),
+         0,
+         bank_size_ * sizeof(filtered_noise_var_[0]));
 
   // Assumes all rho equal.
   for (size_t i = 0; i < bank_size_; ++i) {
@@ -122,9 +140,20 @@ IntelligibilityEnhancer::IntelligibilityEnhancer(const Config& config)
   render_mangler_.reset(new LappedTransform(
       num_render_channels_, num_render_channels_, chunk_length_,
       kbd_window_.get(), window_size_, window_size_ / 2, &render_callback_));
-  capture_mangler_.reset(new LappedTransform(
-      num_capture_channels_, num_capture_channels_, chunk_length_,
-      kbd_window_.get(), window_size_, window_size_ / 2, &capture_callback_));
+}
+
+void IntelligibilityEnhancer::SetCaptureNoiseEstimate(
+    std::vector<float> noise) {
+  if (capture_filter_bank_.size() != bank_size_ ||
+      capture_filter_bank_[0].size() != noise.size()) {
+    capture_filter_bank_ = CreateErbBank(noise.size());
+  }
+  if (noise.size() != noise_power_.size()) {
+    noise_power_.resize(noise.size());
+  }
+  for (size_t i = 0; i < noise.size(); ++i) {
+    noise_power_[i] = noise[i] * noise[i];
+  }
 }
 
 void IntelligibilityEnhancer::ProcessRenderAudio(float* const* audio,
@@ -145,29 +174,6 @@ void IntelligibilityEnhancer::ProcessRenderAudio(float* const* audio,
   }
 }
 
-void IntelligibilityEnhancer::AnalyzeCaptureAudio(float* const* audio,
-                                                  int sample_rate_hz,
-                                                  size_t num_channels) {
-  RTC_CHECK_EQ(sample_rate_hz_, sample_rate_hz);
-  RTC_CHECK_EQ(num_capture_channels_, num_channels);
-
-  capture_mangler_->ProcessChunk(audio, temp_capture_out_buffer_.channels());
-}
-
-void IntelligibilityEnhancer::DispatchAudio(
-    IntelligibilityEnhancer::AudioSource source,
-    const complex<float>* in_block,
-    complex<float>* out_block) {
-  switch (source) {
-    case kRenderStream:
-      ProcessClearBlock(in_block, out_block);
-      break;
-    case kCaptureStream:
-      ProcessNoiseBlock(in_block, out_block);
-      break;
-  }
-}
-
 void IntelligibilityEnhancer::ProcessClearBlock(const complex<float>* in_block,
                                                 complex<float>* out_block) {
   if (block_count_ < 2) {
@@ -194,9 +200,12 @@ void IntelligibilityEnhancer::ProcessClearBlock(const complex<float>* in_block,
 }
 
 void IntelligibilityEnhancer::AnalyzeClearBlock(float power_target) {
-  FilterVariance(clear_variance_.variance(), filtered_clear_var_.get());
-  FilterVariance(noise_variance_.variance(), filtered_noise_var_.get());
-
+  FilterVariance(clear_variance_.variance(),
+                 render_filter_bank_,
+                 filtered_clear_var_.get());
+  FilterVariance(&noise_power_[0],
+                 capture_filter_bank_,
+                 filtered_noise_var_.get());
   SolveForGainsGivenLambda(kLambdaTop, start_freq_, gains_eq_.get());
   const float power_top =
       DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
@@ -242,16 +251,11 @@ void IntelligibilityEnhancer::UpdateErbGains() {
   for (size_t i = 0; i < freqs_; ++i) {
     gains[i] = 0.0f;
     for (size_t j = 0; j < bank_size_; ++j) {
-      gains[i] = fmaf(filter_bank_[j][i], gains_eq_[j], gains[i]);
+      gains[i] = fmaf(render_filter_bank_[j][i], gains_eq_[j], gains[i]);
     }
   }
 }
 
-void IntelligibilityEnhancer::ProcessNoiseBlock(const complex<float>* in_block,
-                                                complex<float>* /*out_block*/) {
-  noise_variance_.Step(in_block);
-}
-
 size_t IntelligibilityEnhancer::GetBankSize(int sample_rate,
                                             size_t erb_resolution) {
   float freq_limit = sample_rate / 2000.0f;
@@ -260,7 +264,9 @@ size_t IntelligibilityEnhancer::GetBankSize(int sample_rate,
   return erb_scale * erb_resolution;
 }
 
-void IntelligibilityEnhancer::CreateErbBank() {
+std::vector<std::vector<float>> IntelligibilityEnhancer::CreateErbBank(
+    size_t num_freqs) {
+  std::vector<std::vector<float>> filter_bank(bank_size_);
   size_t lf = 1, rf = 4;
 
   for (size_t i = 0; i < bank_size_; ++i) {
@@ -274,58 +280,60 @@
   }
 
   for (size_t i = 0; i < bank_size_; ++i) {
-    filter_bank_[i].resize(freqs_);
+    filter_bank[i].resize(num_freqs);
   }
 
   for (size_t i = 1; i <= bank_size_; ++i) {
     size_t lll, ll, rr, rrr;
     static const size_t kOne = 1;  // Avoids repeated static_cast<>s below.
     lll = static_cast<size_t>(round(
-        center_freqs_[max(kOne, i - lf) - 1] * freqs_ /
+        center_freqs_[max(kOne, i - lf) - 1] * num_freqs /
         (0.5f * sample_rate_hz_)));
     ll = static_cast<size_t>(round(
-        center_freqs_[max(kOne, i) - 1] * freqs_ / (0.5f * sample_rate_hz_)));
-    lll = min(freqs_, max(lll, kOne)) - 1;
-    ll = min(freqs_, max(ll, kOne)) - 1;
+        center_freqs_[max(kOne, i) - 1] * num_freqs /
+        (0.5f * sample_rate_hz_)));
+    lll = min(num_freqs, max(lll, kOne)) - 1;
+    ll = min(num_freqs, max(ll, kOne)) - 1;
 
     rrr = static_cast<size_t>(round(
-        center_freqs_[min(bank_size_, i + rf) - 1] * freqs_ /
+        center_freqs_[min(bank_size_, i + rf) - 1] * num_freqs /
         (0.5f * sample_rate_hz_)));
     rr = static_cast<size_t>(round(
-        center_freqs_[min(bank_size_, i + 1) - 1] * freqs_ /
+        center_freqs_[min(bank_size_, i + 1) - 1] * num_freqs /
         (0.5f * sample_rate_hz_)));
-    rrr = min(freqs_, max(rrr, kOne)) - 1;
-    rr = min(freqs_, max(rr, kOne)) - 1;
+    rrr = min(num_freqs, max(rrr, kOne)) - 1;
+    rr = min(num_freqs, max(rr, kOne)) - 1;
 
     float step, element;
 
     step = 1.0f / (ll - lll);
     element = 0.0f;
     for (size_t j = lll; j <= ll; ++j) {
-      filter_bank_[i - 1][j] = element;
+      filter_bank[i - 1][j] = element;
       element += step;
     }
     step = 1.0f / (rrr - rr);
     element = 1.0f;
     for (size_t j = rr; j <= rrr; ++j) {
-      filter_bank_[i - 1][j] = element;
+      filter_bank[i - 1][j] = element;
      element -= step;
     }
     for (size_t j = ll; j <= rr; ++j) {
-      filter_bank_[i - 1][j] = 1.0f;
+      filter_bank[i - 1][j] = 1.0f;
     }
   }
 
   float sum;
-  for (size_t i = 0; i < freqs_; ++i) {
+  for (size_t i = 0; i < num_freqs; ++i) {
     sum = 0.0f;
     for (size_t j = 0; j < bank_size_; ++j) {
-      sum += filter_bank_[j][i];
+      sum += filter_bank[j][i];
     }
     for (size_t j = 0; j < bank_size_; ++j) {
-      filter_bank_[j][i] /= sum;
+      filter_bank[j][i] /= sum;
     }
   }
+  return filter_bank;
 }
 
 void IntelligibilityEnhancer::SolveForGainsGivenLambda(float lambda,
@@ -356,24 +364,6 @@ void IntelligibilityEnhancer::SolveForGainsGivenLambda(float lambda,
   }
 }
 
-void IntelligibilityEnhancer::FilterVariance(const float* var, float* result) {
-  RTC_DCHECK_GT(freqs_, 0u);
-  for (size_t i = 0; i < bank_size_; ++i) {
-    result[i] = DotProduct(&filter_bank_[i][0], var, freqs_);
-  }
-}
-
-float IntelligibilityEnhancer::DotProduct(const float* a,
-                                          const float* b,
-                                          size_t length) {
-  float ret = 0.0f;
-
-  for (size_t i = 0; i < length; ++i) {
-    ret = fmaf(a[i], b[i], ret);
-  }
-  return ret;
-}
-
 bool IntelligibilityEnhancer::active() const {
   return active_;
 }
diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h
index 1eb22342ad..fade1449cc 100644
--- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h
+++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h
@@ -60,10 +60,8 @@ class IntelligibilityEnhancer {
   explicit IntelligibilityEnhancer(const Config& config);
   IntelligibilityEnhancer();  // Initialize with default config.
 
-  // Reads and processes chunk of noise stream in time domain.
-  void AnalyzeCaptureAudio(float* const* audio,
-                           int sample_rate_hz,
-                           size_t num_channels);
+  // Sets the capture noise magnitude spectrum estimate.
+  void SetCaptureNoiseEstimate(std::vector<float> noise);
 
   // Reads chunk of speech in time domain and updates with modified signal.
   void ProcessRenderAudio(float* const* audio,
@@ -72,15 +70,10 @@ class IntelligibilityEnhancer {
   bool active() const;
 
  private:
-  enum AudioSource {
-    kRenderStream = 0,  // Clear speech stream.
-    kCaptureStream,     // Noise stream.
-  };
-
   // Provides access point to the frequency domain.
   class TransformCallback : public LappedTransform::Callback {
    public:
-    TransformCallback(IntelligibilityEnhancer* parent, AudioSource source);
+    TransformCallback(IntelligibilityEnhancer* parent);
 
     // All in frequency domain, receives input |in_block|, applies
     // intelligibility enhancement, and writes result to |out_block|.
@@ -92,17 +85,11 @@ class IntelligibilityEnhancer {
    private:
     IntelligibilityEnhancer* parent_;
-    AudioSource source_;
   };
   friend class TransformCallback;
   FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest, TestErbCreation);
   FRIEND_TEST_ALL_PREFIXES(IntelligibilityEnhancerTest, TestSolveForGains);
 
-  // Sends streams to ProcessClearBlock or ProcessNoiseBlock based on source.
-  void DispatchAudio(AudioSource source,
-                     const std::complex<float>* in_block,
-                     std::complex<float>* out_block);
-
   // Updates variance computation and analysis with |in_block_|,
   // and writes modified speech to |out_block|.
   void ProcessClearBlock(const std::complex<float>* in_block,
                          std::complex<float>* out_block);
@@ -117,27 +104,16 @@ class IntelligibilityEnhancer {
   // Transforms freq gains to ERB gains.
   void UpdateErbGains();
 
-  // Updates variance calculation for noise input with |in_block|.
-  void ProcessNoiseBlock(const std::complex<float>* in_block,
-                         std::complex<float>* out_block);
-
   // Returns number of ERB filters.
   static size_t GetBankSize(int sample_rate, size_t erb_resolution);
 
   // Initializes ERB filterbank.
-  void CreateErbBank();
+  std::vector<std::vector<float>> CreateErbBank(size_t num_freqs);
 
   // Analytically solves quadratic for optimal gains given |lambda|.
   // Negative gains are set to 0. Stores the results in |sols|.
   void SolveForGainsGivenLambda(float lambda, size_t start_freq, float* sols);
 
-  // Computes variance across ERB filters from freq variance |var|.
-  // Stores in |result|.
-  void FilterVariance(const float* var, float* result);
-
-  // Returns dot product of vectors specified by size |length| arrays |a|,|b|.
-  static float DotProduct(const float* a, const float* b, size_t length);
-
   const size_t freqs_;         // Num frequencies in frequency domain.
   const size_t window_size_;   // Window size in samples; also the block size.
   const size_t chunk_length_;  // Chunk size in samples.
@@ -152,11 +128,12 @@ class IntelligibilityEnhancer {
 
   // TODO(ekm): Add logic for updating |active_|.
   intelligibility::VarianceArray clear_variance_;
-  intelligibility::VarianceArray noise_variance_;
+  std::vector<float> noise_power_;
   rtc::scoped_ptr<float[]> filtered_clear_var_;
   rtc::scoped_ptr<float[]> filtered_noise_var_;
-  std::vector<std::vector<float>> filter_bank_;
   rtc::scoped_ptr<float[]> center_freqs_;
+  std::vector<std::vector<float>> capture_filter_bank_;
+  std::vector<std::vector<float>> render_filter_bank_;
   size_t start_freq_;
   rtc::scoped_ptr<float[]> rho_;  // Production and interpretation SNR.
                                   // for each ERB band.
@@ -166,13 +143,10 @@ class IntelligibilityEnhancer {
   // Destination buffers used to reassemble blocked chunks before overwriting
   // the original input array with modifications.
   ChannelBuffer<float> temp_render_out_buffer_;
-  ChannelBuffer<float> temp_capture_out_buffer_;
 
   rtc::scoped_ptr<float[]> kbd_window_;
   TransformCallback render_callback_;
-  TransformCallback capture_callback_;
   rtc::scoped_ptr<LappedTransform> render_mangler_;
-  rtc::scoped_ptr<LappedTransform> capture_mangler_;
   int block_count_;
   int analysis_step_;
 };
diff --git a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc
index ce146deaf5..436d174775 100644
--- a/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc
+++ b/webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer_unittest.cc
@@ -99,7 +99,6 @@ class IntelligibilityEnhancerTest : public ::testing::Test {
     float* clear_cursor = &clear_data_[0];
     float* noise_cursor = &noise_data_[0];
     for (int i = 0; i < kSamples; i += kFragmentSize) {
-      enh_->AnalyzeCaptureAudio(&noise_cursor, kSampleRate, kNumChannels);
       enh_->ProcessRenderAudio(&clear_cursor, kSampleRate, kNumChannels);
       clear_cursor += kFragmentSize;
       noise_cursor += kFragmentSize;
@@ -154,7 +153,7 @@ TEST_F(IntelligibilityEnhancerTest, TestErbCreation) {
     EXPECT_NEAR(kTestCenterFreqs[i], enh_->center_freqs_[i], kMaxTestError);
     ASSERT_EQ(arraysize(kTestFilterBank[0]), enh_->freqs_);
     for (size_t j = 0; j < enh_->freqs_; ++j) {
-      EXPECT_NEAR(kTestFilterBank[i][j], enh_->filter_bank_[i][j],
+      EXPECT_NEAR(kTestFilterBank[i][j], enh_->render_filter_bank_[i][j],
                   kMaxTestError);
     }
   }
diff --git a/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc b/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc
index 4d2f5f4c5d..e02e64e709 100644
--- a/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc
+++ b/webrtc/modules/audio_processing/intelligibility/test/intelligibility_proc.cc
@@ -23,10 +23,14 @@
 #include "gflags/gflags.h"
 #include "testing/gtest/include/gtest/gtest.h"
 #include "webrtc/base/checks.h"
+#include "webrtc/base/criticalsection.h"
 #include "webrtc/common_audio/real_fourier.h"
 #include "webrtc/common_audio/wav_file.h"
+#include "webrtc/modules/audio_processing/audio_buffer.h"
+#include "webrtc/modules/audio_processing/include/audio_processing.h"
 #include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.h"
 #include "webrtc/modules/audio_processing/intelligibility/intelligibility_utils.h"
+#include "webrtc/modules/audio_processing/noise_suppression_impl.h"
 #include "webrtc/system_wrappers/include/critical_section_wrapper.h"
 #include "webrtc/test/testsupport/fileutils.h"
 
@@ -115,6 +119,17 @@ void void_main(int argc, char* argv[]) {
   config.analysis_rate = FLAGS_ana_rate;
   config.gain_change_limit = FLAGS_gain_limit;
   IntelligibilityEnhancer enh(config);
+  rtc::CriticalSection crit;
+  NoiseSuppressionImpl ns(&crit);
+  ns.Initialize(kNumChannels, FLAGS_sample_rate);
+  ns.Enable(true);
+
+  AudioBuffer capture_audio(fragment_size,
+                            kNumChannels,
+                            fragment_size,
+                            kNumChannels,
+                            fragment_size);
+  StreamConfig stream_config(FLAGS_sample_rate, kNumChannels);
 
   // Slice the input into smaller chunks, as the APM would do, and feed them
   // through the enhancer.
@@ -122,7 +137,10 @@ void void_main(int argc, char* argv[]) {
   float* noise_cursor = &noise_fpcm[0];
 
   for (size_t i = 0; i < samples; i += fragment_size) {
-    enh.AnalyzeCaptureAudio(&noise_cursor, FLAGS_sample_rate, kNumChannels);
+    capture_audio.CopyFrom(&noise_cursor, stream_config);
+    ns.AnalyzeCaptureAudio(&capture_audio);
+    ns.ProcessCaptureAudio(&capture_audio);
+    enh.SetCaptureNoiseEstimate(ns.NoiseEstimate());
     enh.ProcessRenderAudio(&clear_cursor, FLAGS_sample_rate, kNumChannels);
     clear_cursor += fragment_size;
     noise_cursor += fragment_size;