diff --git a/modules/audio_processing/aec3/BUILD.gn b/modules/audio_processing/aec3/BUILD.gn index 4b8142945c..6baf3a1ddb 100644 --- a/modules/audio_processing/aec3/BUILD.gn +++ b/modules/audio_processing/aec3/BUILD.gn @@ -14,6 +14,8 @@ rtc_static_library("aec3") { sources = [ "adaptive_fir_filter.cc", "adaptive_fir_filter.h", + "adaptive_fir_filter_erl.cc", + "adaptive_fir_filter_erl.h", "aec3_common.cc", "aec3_common.h", "aec3_fft.cc", @@ -185,6 +187,7 @@ if (rtc_include_tests) { if (rtc_enable_protobuf) { sources += [ + "adaptive_fir_filter_erl_unittest.cc", "adaptive_fir_filter_unittest.cc", "aec3_fft_unittest.cc", "aec_state_unittest.cc", diff --git a/modules/audio_processing/aec3/adaptive_fir_filter.cc b/modules/audio_processing/aec3/adaptive_fir_filter.cc index 024b605527..00fa884aeb 100644 --- a/modules/audio_processing/aec3/adaptive_fir_filter.cc +++ b/modules/audio_processing/aec3/adaptive_fir_filter.cc @@ -82,55 +82,6 @@ void UpdateFrequencyResponse_SSE2( } #endif -// Computes and stores the echo return loss estimate of the filter, which is the -// sum of the partition frequency responses. -void UpdateErlEstimator( - const std::vector>& H2, - std::array* erl) { - erl->fill(0.f); - for (auto& H2_j : H2) { - std::transform(H2_j.begin(), H2_j.end(), erl->begin(), erl->begin(), - std::plus()); - } -} - -#if defined(WEBRTC_HAS_NEON) -// Computes and stores the echo return loss estimate of the filter, which is the -// sum of the partition frequency responses. -void UpdateErlEstimator_NEON( - const std::vector>& H2, - std::array* erl) { - erl->fill(0.f); - for (auto& H2_j : H2) { - for (size_t k = 0; k < kFftLengthBy2; k += 4) { - const float32x4_t H2_j_k = vld1q_f32(&H2_j[k]); - float32x4_t erl_k = vld1q_f32(&(*erl)[k]); - erl_k = vaddq_f32(erl_k, H2_j_k); - vst1q_f32(&(*erl)[k], erl_k); - } - (*erl)[kFftLengthBy2] += H2_j[kFftLengthBy2]; - } -} -#endif - -#if defined(WEBRTC_ARCH_X86_FAMILY) -// Computes and stores the echo return loss estimate of the filter, which is the -// sum of the partition frequency responses. -void UpdateErlEstimator_SSE2( - const std::vector>& H2, - std::array* erl) { - erl->fill(0.f); - for (auto& H2_j : H2) { - for (size_t k = 0; k < kFftLengthBy2; k += 4) { - const __m128 H2_j_k = _mm_loadu_ps(&H2_j[k]); - __m128 erl_k = _mm_loadu_ps(&(*erl)[k]); - erl_k = _mm_add_ps(erl_k, H2_j_k); - _mm_storeu_ps(&(*erl)[k], erl_k); - } - (*erl)[kFftLengthBy2] += H2_j[kFftLengthBy2]; - } -} -#endif // Adapts the filter partitions as H(t+1)=H(t)+G(t)*conj(X(t)). void AdaptPartitions(const RenderBuffer& render_buffer, @@ -442,9 +393,7 @@ AdaptiveFirFilter::AdaptiveFirFilter(size_t max_size_partitions, current_size_partitions_(initial_size_partitions), target_size_partitions_(initial_size_partitions), old_target_size_partitions_(initial_size_partitions), - H_(max_size_partitions_), - H2_(max_size_partitions_, std::array()), - h_(GetTimeDomainLength(max_size_partitions_), 0.f) { + H_(max_size_partitions_) { RTC_DCHECK(data_dumper_); RTC_DCHECK_GE(max_size_partitions, initial_size_partitions); @@ -454,41 +403,23 @@ AdaptiveFirFilter::AdaptiveFirFilter(size_t max_size_partitions, for (auto& H_j : H_) { H_j.Clear(); } - for (auto& H2_k : H2_) { - H2_k.fill(0.f); - } - erl_.fill(0.f); SetSizePartitions(current_size_partitions_, true); } AdaptiveFirFilter::~AdaptiveFirFilter() = default; void AdaptiveFirFilter::HandleEchoPathChange() { - size_t current_h_size = h_.size(); - h_.resize(GetTimeDomainLength(max_size_partitions_)); - std::fill(h_.begin() + current_h_size, h_.end(), 0.f); - h_.resize(current_h_size); - size_t current_size_partitions = H_.size(); H_.resize(max_size_partitions_); - H2_.resize(max_size_partitions_); for (size_t k = current_size_partitions; k < max_size_partitions_; ++k) { H_[k].Clear(); - H2_[k].fill(0.f); } H_.resize(current_size_partitions); - H2_.resize(current_size_partitions); - - erl_.fill(0.f); } void AdaptiveFirFilter::SetSizePartitions(size_t size, bool immediate_effect) { RTC_DCHECK_EQ(max_size_partitions_, H_.capacity()); - RTC_DCHECK_EQ(max_size_partitions_, H2_.capacity()); - RTC_DCHECK_EQ(GetTimeDomainLength(max_size_partitions_), h_.capacity()); - RTC_DCHECK_EQ(H_.size(), H2_.size()); - RTC_DCHECK_EQ(h_.size(), GetTimeDomainLength(H_.size())); RTC_DCHECK_LE(size, max_size_partitions_); target_size_partitions_ = std::min(max_size_partitions_, size); @@ -503,18 +434,7 @@ void AdaptiveFirFilter::SetSizePartitions(size_t size, bool immediate_effect) { } void AdaptiveFirFilter::ResetFilterBuffersToCurrentSize() { - if (current_size_partitions_ < H_.size()) { - for (size_t k = current_size_partitions_; k < H_.size(); ++k) { - H_[k].Clear(); - H2_[k].fill(0.f); - } - std::fill(h_.begin() + GetTimeDomainLength(current_size_partitions_), - h_.end(), 0.f); - } - H_.resize(current_size_partitions_); - H2_.resize(current_size_partitions_); - h_.resize(GetTimeDomainLength(current_size_partitions_)); RTC_DCHECK_LT(0, current_size_partitions_); partition_to_constrain_ = std::min(partition_to_constrain_, current_size_partitions_ - 1); @@ -564,6 +484,52 @@ void AdaptiveFirFilter::Filter(const RenderBuffer& render_buffer, void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer, const FftData& G) { + // Adapt the filter and update the filter size. + AdaptAndUpdateSize(render_buffer, G); + + // Constrain the filter partitions in a cyclic manner. + Constrain(); +} + +void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer, + const FftData& G, + std::vector* impulse_response) { + // Adapt the filter and update the filter size. + AdaptAndUpdateSize(render_buffer, G); + + // Constrain the filter partitions in a cyclic manner. + ConstrainAndUpdateImpulseResponse(impulse_response); +} + +void AdaptiveFirFilter::ComputeFrequencyResponse( + std::vector>* H2) const { + RTC_DCHECK_EQ(max_size_partitions_, H2->capacity()); + + if (H2->size() > H_.size()) { + for (size_t k = H_.size(); k < H2->size(); ++k) { + (*H2)[k].fill(0.f); + } + } + H2->resize(H_.size()); + + switch (optimization_) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: + aec3::UpdateFrequencyResponse_SSE2(H_, H2); + break; +#endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: + aec3::UpdateFrequencyResponse_NEON(H_, H2); + break; +#endif + default: + aec3::UpdateFrequencyResponse(H_, H2); + } +} + +void AdaptiveFirFilter::AdaptAndUpdateSize(const RenderBuffer& render_buffer, + const FftData& G) { // Update the filter size if needed. UpdateSize(); @@ -582,28 +548,34 @@ void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer, default: aec3::AdaptPartitions(render_buffer, G, H_); } +} - // Constrain the filter partitions in a cyclic manner. - Constrain(); +// Constrains the partition of the frequency domain filter to be limited in +// time via setting the relevant time-domain coefficients to zero and updates +// the corresponding values in an externally stored impulse response estimate. +void AdaptiveFirFilter::ConstrainAndUpdateImpulseResponse( + std::vector* impulse_response) { + RTC_DCHECK_EQ(GetTimeDomainLength(max_size_partitions_), + impulse_response->capacity()); - // Update the frequency response and echo return loss for the filter. - switch (optimization_) { -#if defined(WEBRTC_ARCH_X86_FAMILY) - case Aec3Optimization::kSse2: - aec3::UpdateFrequencyResponse_SSE2(H_, &H2_); - aec3::UpdateErlEstimator_SSE2(H2_, &erl_); - break; -#endif -#if defined(WEBRTC_HAS_NEON) - case Aec3Optimization::kNeon: - aec3::UpdateFrequencyResponse_NEON(H_, &H2_); - aec3::UpdateErlEstimator_NEON(H2_, &erl_); - break; -#endif - default: - aec3::UpdateFrequencyResponse(H_, &H2_); - aec3::UpdateErlEstimator(H2_, &erl_); - } + impulse_response->resize(GetTimeDomainLength(current_size_partitions_)); + std::array h; + fft_.Ifft(H_[partition_to_constrain_], &h); + + static constexpr float kScale = 1.0f / kFftLengthBy2; + std::for_each(h.begin(), h.begin() + kFftLengthBy2, + [](float& a) { a *= kScale; }); + std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f); + + std::copy( + h.begin(), h.begin() + kFftLengthBy2, + impulse_response->begin() + partition_to_constrain_ * kFftLengthBy2); + + fft_.Fft(&h, &H_[partition_to_constrain_]); + + partition_to_constrain_ = partition_to_constrain_ < (H_.size() - 1) + ? partition_to_constrain_ + 1 + : 0; } // Constrains the a partiton of the frequency domain filter to be limited in @@ -617,9 +589,6 @@ void AdaptiveFirFilter::Constrain() { [](float& a) { a *= kScale; }); std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f); - std::copy(h.begin(), h.begin() + kFftLengthBy2, - h_.begin() + partition_to_constrain_ * kFftLengthBy2); - fft_.Fft(&h, &H_[partition_to_constrain_]); partition_to_constrain_ = partition_to_constrain_ < (H_.size() - 1) @@ -636,9 +605,6 @@ void AdaptiveFirFilter::ScaleFilter(float factor) { im *= factor; } } - for (auto& h : h_) { - h *= factor; - } } // Set the filter coefficients. diff --git a/modules/audio_processing/aec3/adaptive_fir_filter.h b/modules/audio_processing/aec3/adaptive_fir_filter.h index 12716bbb5a..aec83aabd4 100644 --- a/modules/audio_processing/aec3/adaptive_fir_filter.h +++ b/modules/audio_processing/aec3/adaptive_fir_filter.h @@ -22,7 +22,6 @@ #include "modules/audio_processing/aec3/fft_data.h" #include "modules/audio_processing/aec3/render_buffer.h" #include "modules/audio_processing/logging/apm_data_dumper.h" -#include "rtc_base/constructor_magic.h" #include "rtc_base/system/arch.h" namespace webrtc { @@ -42,22 +41,6 @@ void UpdateFrequencyResponse_SSE2( std::vector>* H2); #endif -// Computes and stores the echo return loss estimate of the filter, which is the -// sum of the partition frequency responses. -void UpdateErlEstimator( - const std::vector>& H2, - std::array* erl); -#if defined(WEBRTC_HAS_NEON) -void UpdateErlEstimator_NEON( - const std::vector>& H2, - std::array* erl); -#endif -#if defined(WEBRTC_ARCH_X86_FAMILY) -void UpdateErlEstimator_SSE2( - const std::vector>& H2, - std::array* erl); -#endif - // Adapts the filter partitions. void AdaptPartitions(const RenderBuffer& render_buffer, const FftData& G, @@ -103,9 +86,18 @@ class AdaptiveFirFilter { ~AdaptiveFirFilter(); + AdaptiveFirFilter(const AdaptiveFirFilter&) = delete; + AdaptiveFirFilter& operator=(const AdaptiveFirFilter&) = delete; + // Produces the output of the filter. void Filter(const RenderBuffer& render_buffer, FftData* S) const; + // Adapts the filter and updates an externally stored impulse response + // estimate. + void Adapt(const RenderBuffer& render_buffer, + const FftData& G, + std::vector* impulse_response); + // Adapts the filter. void Adapt(const RenderBuffer& render_buffer, const FftData& G); @@ -119,20 +111,14 @@ class AdaptiveFirFilter { // Sets the filter size. void SetSizePartitions(size_t size, bool immediate_effect); - // Returns the filter based echo return loss. - const std::array& Erl() const { return erl_; } + // Computes the frequency responses for the filter partitions. + void ComputeFrequencyResponse( + std::vector>* H2) const; - // Returns the frequency responses for the filter partitions. - const std::vector>& - FilterFrequencyResponse() const { - return H2_; - } + // Returns the maximum number of partitions for the filter. + size_t max_filter_size_partitions() const { return max_size_partitions_; } - // Returns the estimate of the impulse response. - const std::vector& FilterImpulseResponse() const { return h_; } - - void DumpFilter(const char* name_frequency_domain, - const char* name_time_domain) { + void DumpFilter(const char* name_frequency_domain) { size_t current_size = H_.size(); H_.resize(max_size_partitions_); for (auto& H : H_) { @@ -140,11 +126,6 @@ class AdaptiveFirFilter { data_dumper_->DumpRaw(name_frequency_domain, H.im); } H_.resize(current_size); - - current_size = h_.size(); - h_.resize(GetTimeDomainLength(max_size_partitions_)); - data_dumper_->DumpRaw(name_time_domain, h_); - h_.resize(current_size); } // Scale the filter impulse response and spectrum by a factor. @@ -157,8 +138,14 @@ class AdaptiveFirFilter { const std::vector& GetFilter() const { return H_; } private: + // Adapts the filter and updates the filter size. + void AdaptAndUpdateSize(const RenderBuffer& render_buffer, const FftData& G); + // Constrain the filter partitions in a cyclic manner. void Constrain(); + // Constrains the filter in a cyclic manner and updates the corresponding + // values in the supplied impulse response. + void ConstrainAndUpdateImpulseResponse(std::vector* impulse_response); // Resets the filter buffers to use the current size. void ResetFilterBuffersToCurrentSize(); @@ -177,12 +164,7 @@ class AdaptiveFirFilter { size_t old_target_size_partitions_; int size_change_counter_ = 0; std::vector H_; - std::vector> H2_; - std::vector h_; - std::array erl_; size_t partition_to_constrain_ = 0; - - RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(AdaptiveFirFilter); }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_erl.cc b/modules/audio_processing/aec3/adaptive_fir_filter_erl.cc new file mode 100644 index 0000000000..80378eb3cf --- /dev/null +++ b/modules/audio_processing/aec3/adaptive_fir_filter_erl.cc @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h" + +#include +#include + +#if defined(WEBRTC_HAS_NEON) +#include +#endif +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include +#endif + +namespace webrtc { + +namespace aec3 { + +// Computes and stores the echo return loss estimate of the filter, which is the +// sum of the partition frequency responses. +void ErlComputer(const std::vector>& H2, + rtc::ArrayView erl) { + std::fill(erl.begin(), erl.end(), 0.f); + for (auto& H2_j : H2) { + std::transform(H2_j.begin(), H2_j.end(), erl.begin(), erl.begin(), + std::plus()); + } +} + +#if defined(WEBRTC_HAS_NEON) +// Computes and stores the echo return loss estimate of the filter, which is the +// sum of the partition frequency responses. +void ErlComputer_NEON( + const std::vector>& H2, + rtc::ArrayView erl) { + std::fill(erl.begin(), erl.end(), 0.f); + for (auto& H2_j : H2) { + for (size_t k = 0; k < kFftLengthBy2; k += 4) { + const float32x4_t H2_j_k = vld1q_f32(&H2_j[k]); + float32x4_t erl_k = vld1q_f32(&erl[k]); + erl_k = vaddq_f32(erl_k, H2_j_k); + vst1q_f32(&erl[k], erl_k); + } + erl[kFftLengthBy2] += H2_j[kFftLengthBy2]; + } +} +#endif + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Computes and stores the echo return loss estimate of the filter, which is the +// sum of the partition frequency responses. +void ErlComputer_SSE2( + const std::vector>& H2, + rtc::ArrayView erl) { + std::fill(erl.begin(), erl.end(), 0.f); + for (auto& H2_j : H2) { + for (size_t k = 0; k < kFftLengthBy2; k += 4) { + const __m128 H2_j_k = _mm_loadu_ps(&H2_j[k]); + __m128 erl_k = _mm_loadu_ps(&erl[k]); + erl_k = _mm_add_ps(erl_k, H2_j_k); + _mm_storeu_ps(&erl[k], erl_k); + } + erl[kFftLengthBy2] += H2_j[kFftLengthBy2]; + } +} +#endif + +} // namespace aec3 + +void ComputeErl(const Aec3Optimization& optimization, + const std::vector>& H2, + rtc::ArrayView erl) { + RTC_DCHECK_EQ(kFftLengthBy2Plus1, erl.size()); + // Update the frequency response and echo return loss for the filter. + switch (optimization) { +#if defined(WEBRTC_ARCH_X86_FAMILY) + case Aec3Optimization::kSse2: + aec3::ErlComputer_SSE2(H2, erl); + break; +#endif +#if defined(WEBRTC_HAS_NEON) + case Aec3Optimization::kNeon: + + aec3::ErlComputer_NEON(H2, erl); + break; +#endif + default: + aec3::ErlComputer(H2, erl); + } +} + +} // namespace webrtc diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_erl.h b/modules/audio_processing/aec3/adaptive_fir_filter_erl.h new file mode 100644 index 0000000000..108d9f8e44 --- /dev/null +++ b/modules/audio_processing/aec3/adaptive_fir_filter_erl.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AEC3_ADAPTIVE_FIR_FILTER_ERL_H_ +#define MODULES_AUDIO_PROCESSING_AEC3_ADAPTIVE_FIR_FILTER_ERL_H_ + +#include + +#include +#include + +#include "api/array_view.h" +#include "modules/audio_processing/aec3/aec3_common.h" +#include "rtc_base/system/arch.h" + +namespace webrtc { +namespace aec3 { + +// Computes and stores the echo return loss estimate of the filter, which is the +// sum of the partition frequency responses. +void ErlComputer(const std::vector>& H2, + rtc::ArrayView erl); +#if defined(WEBRTC_HAS_NEON) +void ErlComputer_NEON( + const std::vector>& H2, + rtc::ArrayView erl); +#endif +#if defined(WEBRTC_ARCH_X86_FAMILY) +void ErlComputer_SSE2( + const std::vector>& H2, + rtc::ArrayView erl); +#endif + +} // namespace aec3 + +// Computes the echo return loss based on a frequency response. +void ComputeErl(const Aec3Optimization& optimization, + const std::vector>& H2, + rtc::ArrayView erl); + +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AEC3_ADAPTIVE_FIR_FILTER_ERL_H_ diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_erl_unittest.cc b/modules/audio_processing/aec3/adaptive_fir_filter_erl_unittest.cc new file mode 100644 index 0000000000..069fc9fa5b --- /dev/null +++ b/modules/audio_processing/aec3/adaptive_fir_filter_erl_unittest.cc @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h" + +#include +#include + +#include "rtc_base/system/arch.h" +#if defined(WEBRTC_ARCH_X86_FAMILY) +#include +#endif + +#include "system_wrappers/include/cpu_features_wrapper.h" +#include "test/gtest.h" + +namespace webrtc { +namespace aec3 { + +#if defined(WEBRTC_HAS_NEON) +// Verifies that the optimized method for echo return loss computation is +// bitexact to the reference counterpart. +TEST(AdaptiveFirFilter, UpdateErlNeonOptimization) { + const size_t kNumPartitions = 12; + std::vector> H2(kNumPartitions); + std::array erl; + std::array erl_NEON; + + for (size_t j = 0; j < H2.size(); ++j) { + for (size_t k = 0; k < H2[j].size(); ++k) { + H2[j][k] = k + j / 3.f; + } + } + + ErlComputer(H2, erl); + ErlComputer_NEON(H2, erl_NEON); + + for (size_t j = 0; j < erl.size(); ++j) { + EXPECT_FLOAT_EQ(erl[j], erl_NEON[j]); + } +} + +#endif + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Verifies that the optimized method for echo return loss computation is +// bitexact to the reference counterpart. +TEST(AdaptiveFirFilter, UpdateErlSse2Optimization) { + bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0); + if (use_sse2) { + const size_t kNumPartitions = 12; + std::vector> H2(kNumPartitions); + std::array erl; + std::array erl_SSE2; + + for (size_t j = 0; j < H2.size(); ++j) { + for (size_t k = 0; k < H2[j].size(); ++k) { + H2[j][k] = k + j / 3.f; + } + } + + ErlComputer(H2, erl); + ErlComputer_SSE2(H2, erl_SSE2); + + for (size_t j = 0; j < erl.size(); ++j) { + EXPECT_FLOAT_EQ(erl[j], erl_SSE2[j]); + } + } +} + +#endif + +} // namespace aec3 +} // namespace webrtc diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc index e7c9c85eca..9318c21ce9 100644 --- a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc +++ b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc @@ -22,6 +22,7 @@ #include #endif +#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h" #include "modules/audio_processing/aec3/aec3_fft.h" #include "modules/audio_processing/aec3/aec_state.h" #include "modules/audio_processing/aec3/render_delay_buffer.h" @@ -145,28 +146,6 @@ TEST(AdaptiveFirFilter, UpdateFrequencyResponseNeonOptimization) { } } -// Verifies that the optimized method for echo return loss computation is -// bitexact to the reference counterpart. -TEST(AdaptiveFirFilter, UpdateErlNeonOptimization) { - const size_t kNumPartitions = 12; - std::vector> H2(kNumPartitions); - std::array erl; - std::array erl_NEON; - - for (size_t j = 0; j < H2.size(); ++j) { - for (size_t k = 0; k < H2[j].size(); ++k) { - H2[j][k] = k + j / 3.f; - } - } - - UpdateErlEstimator(H2, &erl); - UpdateErlEstimator_NEON(H2, &erl_NEON); - - for (size_t j = 0; j < erl.size(); ++j) { - EXPECT_FLOAT_EQ(erl[j], erl_NEON[j]); - } -} - #endif #if defined(WEBRTC_ARCH_X86_FAMILY) @@ -266,31 +245,6 @@ TEST(AdaptiveFirFilter, UpdateFrequencyResponseSse2Optimization) { } } -// Verifies that the optimized method for echo return loss computation is -// bitexact to the reference counterpart. -TEST(AdaptiveFirFilter, UpdateErlSse2Optimization) { - bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0); - if (use_sse2) { - const size_t kNumPartitions = 12; - std::vector> H2(kNumPartitions); - std::array erl; - std::array erl_SSE2; - - for (size_t j = 0; j < H2.size(); ++j) { - for (size_t k = 0; k < H2[j].size(); ++k) { - H2[j][k] = k + j / 3.f; - } - } - - UpdateErlEstimator(H2, &erl); - UpdateErlEstimator_SSE2(H2, &erl_SSE2); - - for (size_t j = 0; j < erl.size(); ++j) { - EXPECT_FLOAT_EQ(erl[j], erl_SSE2[j]); - } - } -} - #endif #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) @@ -316,9 +270,18 @@ TEST(AdaptiveFirFilter, NullFilterOutput) { // are turned on. TEST(AdaptiveFirFilter, FilterStatisticsAccess) { ApmDataDumper data_dumper(42); - AdaptiveFirFilter filter(9, 9, 250, 1, 1, DetectOptimization(), &data_dumper); - filter.Erl(); - filter.FilterFrequencyResponse(); + Aec3Optimization optimization = DetectOptimization(); + AdaptiveFirFilter filter(9, 9, 250, 1, 1, optimization, &data_dumper); + std::vector> H2( + filter.max_filter_size_partitions(), + std::array()); + for (auto& H2_k : H2) { + H2_k.fill(0.f); + } + + std::array erl; + ComputeErl(optimization, H2, erl); + filter.ComputeFrequencyResponse(&H2); } // Verifies that the filter size if correctly repported. @@ -345,6 +308,11 @@ TEST(AdaptiveFirFilter, FilterAndAdapt) { config.filter.main.length_blocks, config.filter.config_change_duration_blocks, 1, 1, DetectOptimization(), &data_dumper); + std::vector> H2( + filter.max_filter_size_partitions(), + std::array()); + std::vector h(GetTimeDomainLength(filter.max_filter_size_partitions()), + 0.f); Aec3Fft fft; config.delay.default_delay = 1; std::unique_ptr render_delay_buffer( @@ -424,13 +392,13 @@ TEST(AdaptiveFirFilter, FilterAndAdapt) { render_buffer->SpectralSum(filter.SizePartitions(), &render_power); gain.Compute(render_power, render_signal_analyzer, E, filter.SizePartitions(), false, &G); - filter.Adapt(*render_buffer, G); + filter.Adapt(*render_buffer, G, &h); aec_state.HandleEchoPathChange(EchoPathVariability( false, EchoPathVariability::DelayAdjustment::kNone, false)); - aec_state.Update(delay_estimate, filter.FilterFrequencyResponse(), - filter.FilterImpulseResponse(), *render_buffer, E2_main, - Y2, output, y); + filter.ComputeFrequencyResponse(&H2); + aec_state.Update(delay_estimate, H2, h, *render_buffer, E2_main, Y2, + output, y); } // Verify that the filter is able to perform well. EXPECT_LT(1000 * std::inner_product(e.begin(), e.end(), e.begin(), 0.f), diff --git a/modules/audio_processing/aec3/main_filter_update_gain.cc b/modules/audio_processing/aec3/main_filter_update_gain.cc index 11a97e2781..c2cfd2c447 100644 --- a/modules/audio_processing/aec3/main_filter_update_gain.cc +++ b/modules/audio_processing/aec3/main_filter_update_gain.cc @@ -70,7 +70,8 @@ void MainFilterUpdateGain::Compute( const std::array& render_power, const RenderSignalAnalyzer& render_signal_analyzer, const SubtractorOutput& subtractor_output, - const AdaptiveFirFilter& filter, + rtc::ArrayView erl, + size_t size_partitions, bool saturated_capture_signal, FftData* gain_fft) { RTC_DCHECK(gain_fft); @@ -79,9 +80,8 @@ void MainFilterUpdateGain::Compute( const auto& E2_main = subtractor_output.E2_main; const auto& E2_shadow = subtractor_output.E2_shadow; FftData* G = gain_fft; - const size_t size_partitions = filter.SizePartitions(); auto X2 = render_power; - const auto& erl = filter.Erl(); + ++call_counter_; UpdateCurrentConfig(); diff --git a/modules/audio_processing/aec3/main_filter_update_gain.h b/modules/audio_processing/aec3/main_filter_update_gain.h index dca0ff8713..1955d2a402 100644 --- a/modules/audio_processing/aec3/main_filter_update_gain.h +++ b/modules/audio_processing/aec3/main_filter_update_gain.h @@ -16,9 +16,9 @@ #include #include +#include "api/array_view.h" #include "api/audio/echo_canceller3_config.h" #include "modules/audio_processing/aec3/aec3_common.h" -#include "rtc_base/constructor_magic.h" namespace webrtc { @@ -32,11 +32,14 @@ struct SubtractorOutput; // Provides functionality for computing the adaptive gain for the main filter. class MainFilterUpdateGain { public: - explicit MainFilterUpdateGain( + MainFilterUpdateGain( const EchoCanceller3Config::Filter::MainConfiguration& config, size_t config_change_duration_blocks); ~MainFilterUpdateGain(); + MainFilterUpdateGain(const MainFilterUpdateGain&) = delete; + MainFilterUpdateGain& operator=(const MainFilterUpdateGain&) = delete; + // Takes action in the case of a known echo path change. void HandleEchoPathChange(const EchoPathVariability& echo_path_variability); @@ -44,7 +47,8 @@ class MainFilterUpdateGain { void Compute(const std::array& render_power, const RenderSignalAnalyzer& render_signal_analyzer, const SubtractorOutput& subtractor_output, - const AdaptiveFirFilter& filter, + rtc::ArrayView erl, + size_t size_partitions, bool saturated_capture_signal, FftData* gain_fft); @@ -76,8 +80,6 @@ class MainFilterUpdateGain { // Updates the current config towards the target config. void UpdateCurrentConfig(); - - RTC_DISALLOW_COPY_AND_ASSIGN(MainFilterUpdateGain); }; } // namespace webrtc diff --git a/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc b/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc index 29d8ea901e..e78f1cdb61 100644 --- a/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc +++ b/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc @@ -15,6 +15,7 @@ #include #include "modules/audio_processing/aec3/adaptive_fir_filter.h" +#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h" #include "modules/audio_processing/aec3/aec_state.h" #include "modules/audio_processing/aec3/render_delay_buffer.h" #include "modules/audio_processing/aec3/render_signal_analyzer.h" @@ -42,6 +43,7 @@ void RunFilterUpdateTest(int num_blocks_to_process, std::array* y_last_block, FftData* G_last_block) { ApmDataDumper data_dumper(42); + Aec3Optimization optimization = DetectOptimization(); constexpr size_t kNumChannels = 1; constexpr int kSampleRateHz = 48000; constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); @@ -52,11 +54,20 @@ void RunFilterUpdateTest(int num_blocks_to_process, AdaptiveFirFilter main_filter(config.filter.main.length_blocks, config.filter.main.length_blocks, config.filter.config_change_duration_blocks, 1, - 1, DetectOptimization(), &data_dumper); + 1, optimization, &data_dumper); AdaptiveFirFilter shadow_filter(config.filter.shadow.length_blocks, config.filter.shadow.length_blocks, config.filter.config_change_duration_blocks, - 1, 1, DetectOptimization(), &data_dumper); + 1, 1, optimization, &data_dumper); + std::vector> H2( + main_filter.max_filter_size_partitions(), + std::array()); + for (auto& H2_k : H2) { + H2_k.fill(0.f); + } + std::vector h( + GetTimeDomainLength(main_filter.max_filter_size_partitions()), 0.f); + Aec3Fft fft; std::array x_old; x_old.fill(0.f); @@ -168,15 +179,18 @@ void RunFilterUpdateTest(int num_blocks_to_process, // Adapt the main filter render_delay_buffer->GetRenderBuffer()->SpectralSum( main_filter.SizePartitions(), &render_power); - main_gain.Compute(render_power, render_signal_analyzer, output, main_filter, - saturation, &G); - main_filter.Adapt(*render_delay_buffer->GetRenderBuffer(), G); + + std::array erl; + ComputeErl(optimization, H2, erl); + main_gain.Compute(render_power, render_signal_analyzer, output, erl, + main_filter.SizePartitions(), saturation, &G); + main_filter.Adapt(*render_delay_buffer->GetRenderBuffer(), G, &h); // Update the delay. aec_state.HandleEchoPathChange(EchoPathVariability( false, EchoPathVariability::DelayAdjustment::kNone, false)); - aec_state.Update(delay_estimate, main_filter.FilterFrequencyResponse(), - main_filter.FilterImpulseResponse(), + main_filter.ComputeFrequencyResponse(&H2); + aec_state.Update(delay_estimate, H2, h, *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, output, y); } @@ -208,18 +222,17 @@ std::string ProduceDebugText(size_t delay, int filter_length_blocks) { TEST(MainFilterUpdateGain, NullDataOutputGain) { ApmDataDumper data_dumper(42); EchoCanceller3Config config; - AdaptiveFirFilter filter(config.filter.main.length_blocks, - config.filter.main.length_blocks, - config.filter.config_change_duration_blocks, 1, 1, - DetectOptimization(), &data_dumper); - RenderSignalAnalyzer analyzer(EchoCanceller3Config{}); + RenderSignalAnalyzer analyzer(config); SubtractorOutput output; MainFilterUpdateGain gain(config.filter.main, config.filter.config_change_duration_blocks); std::array render_power; render_power.fill(0.f); - EXPECT_DEATH( - gain.Compute(render_power, analyzer, output, filter, false, nullptr), ""); + std::array erl; + erl.fill(0.f); + EXPECT_DEATH(gain.Compute(render_power, analyzer, output, erl, + config.filter.main.length_blocks, false, nullptr), + ""); } #endif diff --git a/modules/audio_processing/aec3/subtractor.cc b/modules/audio_processing/aec3/subtractor.cc index efb79d42dc..4d86358781 100644 --- a/modules/audio_processing/aec3/subtractor.cc +++ b/modules/audio_processing/aec3/subtractor.cc @@ -14,6 +14,7 @@ #include #include "api/array_view.h" +#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h" #include "modules/audio_processing/aec3/fft_data.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/checks.h" @@ -81,8 +82,16 @@ Subtractor::Subtractor(const EchoCanceller3Config& config, G_main_(config_.filter.main_initial, config_.filter.config_change_duration_blocks), G_shadow_(config_.filter.shadow_initial, - config.filter.config_change_duration_blocks) { + config.filter.config_change_duration_blocks), + main_frequency_response_(main_filter_.max_filter_size_partitions(), + std::array()), + main_impulse_response_( + GetTimeDomainLength(main_filter_.max_filter_size_partitions()), + 0.f) { RTC_DCHECK(data_dumper_); + for (auto& H2_k : main_frequency_response_) { + H2_k.fill(0.f); + } } Subtractor::~Subtractor() = default; @@ -150,6 +159,9 @@ void Subtractor::Process(const RenderBuffer& render_buffer, if (filter_misadjustment_estimator_.IsAdjustmentNeeded()) { float scale = filter_misadjustment_estimator_.GetMisadjustment(); main_filter_.ScaleFilter(scale); + for (auto& h_k : main_impulse_response_) { + h_k *= scale; + } ScaleFilterOutput(y, scale, e_main, output->s_main); filter_misadjustment_estimator_.Reset(); main_filter_adjusted = true; @@ -184,13 +196,18 @@ void Subtractor::Process(const RenderBuffer& render_buffer, // Update the main filter. if (!main_filter_adjusted) { - G_main_.Compute(X2_main, render_signal_analyzer, *output, main_filter_, - aec_state.SaturatedCapture(), &G); + std::array erl; + ComputeErl(optimization_, main_frequency_response_, erl); + G_main_.Compute(X2_main, render_signal_analyzer, *output, erl, + main_filter_.SizePartitions(), aec_state.SaturatedCapture(), + &G); } else { G.re.fill(0.f); G.im.fill(0.f); } - main_filter_.Adapt(render_buffer, G); + main_filter_.Adapt(render_buffer, G, &main_impulse_response_); + main_filter_.ComputeFrequencyResponse(&main_frequency_response_); + data_dumper_->DumpRaw("aec3_subtractor_G_main", G.re); data_dumper_->DumpRaw("aec3_subtractor_G_main", G.im); diff --git a/modules/audio_processing/aec3/subtractor.h b/modules/audio_processing/aec3/subtractor.h index a23eaaf707..7c3c5e0930 100644 --- a/modules/audio_processing/aec3/subtractor.h +++ b/modules/audio_processing/aec3/subtractor.h @@ -31,7 +31,6 @@ #include "modules/audio_processing/aec3/subtractor_output.h" #include "modules/audio_processing/logging/apm_data_dumper.h" #include "rtc_base/checks.h" -#include "rtc_base/constructor_magic.h" namespace webrtc { @@ -44,6 +43,8 @@ class Subtractor { ApmDataDumper* data_dumper, Aec3Optimization optimization); ~Subtractor(); + Subtractor(const Subtractor&) = delete; + Subtractor& operator=(const Subtractor&) = delete; // Performs the echo subtraction. void Process(const RenderBuffer& render_buffer, @@ -60,18 +61,22 @@ class Subtractor { // Returns the block-wise frequency response for the main adaptive filter. const std::vector>& FilterFrequencyResponse() const { - return main_filter_.FilterFrequencyResponse(); + return main_frequency_response_; } // Returns the estimate of the impulse response for the main adaptive filter. const std::vector& FilterImpulseResponse() const { - return main_filter_.FilterImpulseResponse(); + return main_impulse_response_; } void DumpFilters() { - main_filter_.DumpFilter("aec3_subtractor_H_main", "aec3_subtractor_h_main"); - shadow_filter_.DumpFilter("aec3_subtractor_H_shadow", - "aec3_subtractor_h_shadow"); + size_t current_size = main_impulse_response_.size(); + main_impulse_response_.resize(main_impulse_response_.capacity()); + data_dumper_->DumpRaw("aec3_subtractor_h_main", main_impulse_response_); + main_impulse_response_.resize(current_size); + + main_filter_.DumpFilter("aec3_subtractor_H_main"); + shadow_filter_.DumpFilter("aec3_subtractor_H_shadow"); } private: @@ -117,7 +122,8 @@ class Subtractor { ShadowFilterUpdateGain G_shadow_; FilterMisadjustmentEstimator filter_misadjustment_estimator_; size_t poor_shadow_filter_counter_ = 0; - RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(Subtractor); + std::vector> main_frequency_response_; + std::vector main_impulse_response_; }; } // namespace webrtc