AEC3: Reducing the complexity and heap usage of the adaptive filter

This CL reduces the complexity and heap usage of the adaptive filter
in AEC3 by avoiding to compute these for the shadow
filter. In particular it
-Moves to compute the ERL, frequency response and impulse response
 on an on-demand basis.
-Stores the ERL, frequency response and impulse response outside
 of the adaptive filter.

All the changes have been tested for bitexactness on a sizeable
amount of recordings.

Bug: webrtc:10913
Change-Id: If83c236a6e3f2e489be129b9ebf6143a72f521d1
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/151138
Commit-Queue: Per Åhgren <peah@webrtc.org>
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#29081}
This commit is contained in:
Per Åhgren
2019-09-05 15:55:58 +02:00
committed by Commit Bot
parent f294d2629f
commit d4e6904d40
12 changed files with 421 additions and 233 deletions

View File

@ -14,6 +14,8 @@ rtc_static_library("aec3") {
sources = [
"adaptive_fir_filter.cc",
"adaptive_fir_filter.h",
"adaptive_fir_filter_erl.cc",
"adaptive_fir_filter_erl.h",
"aec3_common.cc",
"aec3_common.h",
"aec3_fft.cc",
@ -185,6 +187,7 @@ if (rtc_include_tests) {
if (rtc_enable_protobuf) {
sources += [
"adaptive_fir_filter_erl_unittest.cc",
"adaptive_fir_filter_unittest.cc",
"aec3_fft_unittest.cc",
"aec_state_unittest.cc",

View File

@ -82,55 +82,6 @@ void UpdateFrequencyResponse_SSE2(
}
#endif
// Computes and stores the echo return loss estimate of the filter, which is the
// sum of the partition frequency responses.
void UpdateErlEstimator(
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
std::array<float, kFftLengthBy2Plus1>* erl) {
erl->fill(0.f);
for (auto& H2_j : H2) {
std::transform(H2_j.begin(), H2_j.end(), erl->begin(), erl->begin(),
std::plus<float>());
}
}
#if defined(WEBRTC_HAS_NEON)
// Computes and stores the echo return loss estimate of the filter, which is the
// sum of the partition frequency responses.
void UpdateErlEstimator_NEON(
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
std::array<float, kFftLengthBy2Plus1>* erl) {
erl->fill(0.f);
for (auto& H2_j : H2) {
for (size_t k = 0; k < kFftLengthBy2; k += 4) {
const float32x4_t H2_j_k = vld1q_f32(&H2_j[k]);
float32x4_t erl_k = vld1q_f32(&(*erl)[k]);
erl_k = vaddq_f32(erl_k, H2_j_k);
vst1q_f32(&(*erl)[k], erl_k);
}
(*erl)[kFftLengthBy2] += H2_j[kFftLengthBy2];
}
}
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
// Computes and stores the echo return loss estimate of the filter, which is the
// sum of the partition frequency responses.
void UpdateErlEstimator_SSE2(
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
std::array<float, kFftLengthBy2Plus1>* erl) {
erl->fill(0.f);
for (auto& H2_j : H2) {
for (size_t k = 0; k < kFftLengthBy2; k += 4) {
const __m128 H2_j_k = _mm_loadu_ps(&H2_j[k]);
__m128 erl_k = _mm_loadu_ps(&(*erl)[k]);
erl_k = _mm_add_ps(erl_k, H2_j_k);
_mm_storeu_ps(&(*erl)[k], erl_k);
}
(*erl)[kFftLengthBy2] += H2_j[kFftLengthBy2];
}
}
#endif
// Adapts the filter partitions as H(t+1)=H(t)+G(t)*conj(X(t)).
void AdaptPartitions(const RenderBuffer& render_buffer,
@ -442,9 +393,7 @@ AdaptiveFirFilter::AdaptiveFirFilter(size_t max_size_partitions,
current_size_partitions_(initial_size_partitions),
target_size_partitions_(initial_size_partitions),
old_target_size_partitions_(initial_size_partitions),
H_(max_size_partitions_),
H2_(max_size_partitions_, std::array<float, kFftLengthBy2Plus1>()),
h_(GetTimeDomainLength(max_size_partitions_), 0.f) {
H_(max_size_partitions_) {
RTC_DCHECK(data_dumper_);
RTC_DCHECK_GE(max_size_partitions, initial_size_partitions);
@ -454,41 +403,23 @@ AdaptiveFirFilter::AdaptiveFirFilter(size_t max_size_partitions,
for (auto& H_j : H_) {
H_j.Clear();
}
for (auto& H2_k : H2_) {
H2_k.fill(0.f);
}
erl_.fill(0.f);
SetSizePartitions(current_size_partitions_, true);
}
AdaptiveFirFilter::~AdaptiveFirFilter() = default;
void AdaptiveFirFilter::HandleEchoPathChange() {
size_t current_h_size = h_.size();
h_.resize(GetTimeDomainLength(max_size_partitions_));
std::fill(h_.begin() + current_h_size, h_.end(), 0.f);
h_.resize(current_h_size);
size_t current_size_partitions = H_.size();
H_.resize(max_size_partitions_);
H2_.resize(max_size_partitions_);
for (size_t k = current_size_partitions; k < max_size_partitions_; ++k) {
H_[k].Clear();
H2_[k].fill(0.f);
}
H_.resize(current_size_partitions);
H2_.resize(current_size_partitions);
erl_.fill(0.f);
}
void AdaptiveFirFilter::SetSizePartitions(size_t size, bool immediate_effect) {
RTC_DCHECK_EQ(max_size_partitions_, H_.capacity());
RTC_DCHECK_EQ(max_size_partitions_, H2_.capacity());
RTC_DCHECK_EQ(GetTimeDomainLength(max_size_partitions_), h_.capacity());
RTC_DCHECK_EQ(H_.size(), H2_.size());
RTC_DCHECK_EQ(h_.size(), GetTimeDomainLength(H_.size()));
RTC_DCHECK_LE(size, max_size_partitions_);
target_size_partitions_ = std::min(max_size_partitions_, size);
@ -503,18 +434,7 @@ void AdaptiveFirFilter::SetSizePartitions(size_t size, bool immediate_effect) {
}
void AdaptiveFirFilter::ResetFilterBuffersToCurrentSize() {
if (current_size_partitions_ < H_.size()) {
for (size_t k = current_size_partitions_; k < H_.size(); ++k) {
H_[k].Clear();
H2_[k].fill(0.f);
}
std::fill(h_.begin() + GetTimeDomainLength(current_size_partitions_),
h_.end(), 0.f);
}
H_.resize(current_size_partitions_);
H2_.resize(current_size_partitions_);
h_.resize(GetTimeDomainLength(current_size_partitions_));
RTC_DCHECK_LT(0, current_size_partitions_);
partition_to_constrain_ =
std::min(partition_to_constrain_, current_size_partitions_ - 1);
@ -564,6 +484,52 @@ void AdaptiveFirFilter::Filter(const RenderBuffer& render_buffer,
void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer,
const FftData& G) {
// Adapt the filter and update the filter size.
AdaptAndUpdateSize(render_buffer, G);
// Constrain the filter partitions in a cyclic manner.
Constrain();
}
void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer,
const FftData& G,
std::vector<float>* impulse_response) {
// Adapt the filter and update the filter size.
AdaptAndUpdateSize(render_buffer, G);
// Constrain the filter partitions in a cyclic manner.
ConstrainAndUpdateImpulseResponse(impulse_response);
}
void AdaptiveFirFilter::ComputeFrequencyResponse(
std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) const {
RTC_DCHECK_EQ(max_size_partitions_, H2->capacity());
if (H2->size() > H_.size()) {
for (size_t k = H_.size(); k < H2->size(); ++k) {
(*H2)[k].fill(0.f);
}
}
H2->resize(H_.size());
switch (optimization_) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
case Aec3Optimization::kSse2:
aec3::UpdateFrequencyResponse_SSE2(H_, H2);
break;
#endif
#if defined(WEBRTC_HAS_NEON)
case Aec3Optimization::kNeon:
aec3::UpdateFrequencyResponse_NEON(H_, H2);
break;
#endif
default:
aec3::UpdateFrequencyResponse(H_, H2);
}
}
void AdaptiveFirFilter::AdaptAndUpdateSize(const RenderBuffer& render_buffer,
const FftData& G) {
// Update the filter size if needed.
UpdateSize();
@ -582,28 +548,34 @@ void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer,
default:
aec3::AdaptPartitions(render_buffer, G, H_);
}
}
// Constrain the filter partitions in a cyclic manner.
Constrain();
// Constrains the partition of the frequency domain filter to be limited in
// time via setting the relevant time-domain coefficients to zero and updates
// the corresponding values in an externally stored impulse response estimate.
void AdaptiveFirFilter::ConstrainAndUpdateImpulseResponse(
std::vector<float>* impulse_response) {
RTC_DCHECK_EQ(GetTimeDomainLength(max_size_partitions_),
impulse_response->capacity());
// Update the frequency response and echo return loss for the filter.
switch (optimization_) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
case Aec3Optimization::kSse2:
aec3::UpdateFrequencyResponse_SSE2(H_, &H2_);
aec3::UpdateErlEstimator_SSE2(H2_, &erl_);
break;
#endif
#if defined(WEBRTC_HAS_NEON)
case Aec3Optimization::kNeon:
aec3::UpdateFrequencyResponse_NEON(H_, &H2_);
aec3::UpdateErlEstimator_NEON(H2_, &erl_);
break;
#endif
default:
aec3::UpdateFrequencyResponse(H_, &H2_);
aec3::UpdateErlEstimator(H2_, &erl_);
}
impulse_response->resize(GetTimeDomainLength(current_size_partitions_));
std::array<float, kFftLength> h;
fft_.Ifft(H_[partition_to_constrain_], &h);
static constexpr float kScale = 1.0f / kFftLengthBy2;
std::for_each(h.begin(), h.begin() + kFftLengthBy2,
[](float& a) { a *= kScale; });
std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f);
std::copy(
h.begin(), h.begin() + kFftLengthBy2,
impulse_response->begin() + partition_to_constrain_ * kFftLengthBy2);
fft_.Fft(&h, &H_[partition_to_constrain_]);
partition_to_constrain_ = partition_to_constrain_ < (H_.size() - 1)
? partition_to_constrain_ + 1
: 0;
}
// Constrains the a partiton of the frequency domain filter to be limited in
@ -617,9 +589,6 @@ void AdaptiveFirFilter::Constrain() {
[](float& a) { a *= kScale; });
std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f);
std::copy(h.begin(), h.begin() + kFftLengthBy2,
h_.begin() + partition_to_constrain_ * kFftLengthBy2);
fft_.Fft(&h, &H_[partition_to_constrain_]);
partition_to_constrain_ = partition_to_constrain_ < (H_.size() - 1)
@ -636,9 +605,6 @@ void AdaptiveFirFilter::ScaleFilter(float factor) {
im *= factor;
}
}
for (auto& h : h_) {
h *= factor;
}
}
// Set the filter coefficients.

View File

@ -22,7 +22,6 @@
#include "modules/audio_processing/aec3/fft_data.h"
#include "modules/audio_processing/aec3/render_buffer.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/constructor_magic.h"
#include "rtc_base/system/arch.h"
namespace webrtc {
@ -42,22 +41,6 @@ void UpdateFrequencyResponse_SSE2(
std::vector<std::array<float, kFftLengthBy2Plus1>>* H2);
#endif
// Computes and stores the echo return loss estimate of the filter, which is the
// sum of the partition frequency responses.
void UpdateErlEstimator(
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
std::array<float, kFftLengthBy2Plus1>* erl);
#if defined(WEBRTC_HAS_NEON)
void UpdateErlEstimator_NEON(
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
std::array<float, kFftLengthBy2Plus1>* erl);
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
void UpdateErlEstimator_SSE2(
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
std::array<float, kFftLengthBy2Plus1>* erl);
#endif
// Adapts the filter partitions.
void AdaptPartitions(const RenderBuffer& render_buffer,
const FftData& G,
@ -103,9 +86,18 @@ class AdaptiveFirFilter {
~AdaptiveFirFilter();
AdaptiveFirFilter(const AdaptiveFirFilter&) = delete;
AdaptiveFirFilter& operator=(const AdaptiveFirFilter&) = delete;
// Produces the output of the filter.
void Filter(const RenderBuffer& render_buffer, FftData* S) const;
// Adapts the filter and updates an externally stored impulse response
// estimate.
void Adapt(const RenderBuffer& render_buffer,
const FftData& G,
std::vector<float>* impulse_response);
// Adapts the filter.
void Adapt(const RenderBuffer& render_buffer, const FftData& G);
@ -119,20 +111,14 @@ class AdaptiveFirFilter {
// Sets the filter size.
void SetSizePartitions(size_t size, bool immediate_effect);
// Returns the filter based echo return loss.
const std::array<float, kFftLengthBy2Plus1>& Erl() const { return erl_; }
// Computes the frequency responses for the filter partitions.
void ComputeFrequencyResponse(
std::vector<std::array<float, kFftLengthBy2Plus1>>* H2) const;
// Returns the frequency responses for the filter partitions.
const std::vector<std::array<float, kFftLengthBy2Plus1>>&
FilterFrequencyResponse() const {
return H2_;
}
// Returns the maximum number of partitions for the filter.
size_t max_filter_size_partitions() const { return max_size_partitions_; }
// Returns the estimate of the impulse response.
const std::vector<float>& FilterImpulseResponse() const { return h_; }
void DumpFilter(const char* name_frequency_domain,
const char* name_time_domain) {
void DumpFilter(const char* name_frequency_domain) {
size_t current_size = H_.size();
H_.resize(max_size_partitions_);
for (auto& H : H_) {
@ -140,11 +126,6 @@ class AdaptiveFirFilter {
data_dumper_->DumpRaw(name_frequency_domain, H.im);
}
H_.resize(current_size);
current_size = h_.size();
h_.resize(GetTimeDomainLength(max_size_partitions_));
data_dumper_->DumpRaw(name_time_domain, h_);
h_.resize(current_size);
}
// Scale the filter impulse response and spectrum by a factor.
@ -157,8 +138,14 @@ class AdaptiveFirFilter {
const std::vector<FftData>& GetFilter() const { return H_; }
private:
// Adapts the filter and updates the filter size.
void AdaptAndUpdateSize(const RenderBuffer& render_buffer, const FftData& G);
// Constrain the filter partitions in a cyclic manner.
void Constrain();
// Constrains the filter in a cyclic manner and updates the corresponding
// values in the supplied impulse response.
void ConstrainAndUpdateImpulseResponse(std::vector<float>* impulse_response);
// Resets the filter buffers to use the current size.
void ResetFilterBuffersToCurrentSize();
@ -177,12 +164,7 @@ class AdaptiveFirFilter {
size_t old_target_size_partitions_;
int size_change_counter_ = 0;
std::vector<FftData> H_;
std::vector<std::array<float, kFftLengthBy2Plus1>> H2_;
std::vector<float> h_;
std::array<float, kFftLengthBy2Plus1> erl_;
size_t partition_to_constrain_ = 0;
RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(AdaptiveFirFilter);
};
} // namespace webrtc

View File

@ -0,0 +1,100 @@
/*
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h"
#include <algorithm>
#include <functional>
#if defined(WEBRTC_HAS_NEON)
#include <arm_neon.h>
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
#include <emmintrin.h>
#endif
namespace webrtc {
namespace aec3 {
// Computes and stores the echo return loss estimate of the filter, which is the
// sum of the partition frequency responses.
void ErlComputer(const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
rtc::ArrayView<float> erl) {
std::fill(erl.begin(), erl.end(), 0.f);
for (auto& H2_j : H2) {
std::transform(H2_j.begin(), H2_j.end(), erl.begin(), erl.begin(),
std::plus<float>());
}
}
#if defined(WEBRTC_HAS_NEON)
// Computes and stores the echo return loss estimate of the filter, which is the
// sum of the partition frequency responses.
void ErlComputer_NEON(
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
rtc::ArrayView<float> erl) {
std::fill(erl.begin(), erl.end(), 0.f);
for (auto& H2_j : H2) {
for (size_t k = 0; k < kFftLengthBy2; k += 4) {
const float32x4_t H2_j_k = vld1q_f32(&H2_j[k]);
float32x4_t erl_k = vld1q_f32(&erl[k]);
erl_k = vaddq_f32(erl_k, H2_j_k);
vst1q_f32(&erl[k], erl_k);
}
erl[kFftLengthBy2] += H2_j[kFftLengthBy2];
}
}
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
// Computes and stores the echo return loss estimate of the filter, which is the
// sum of the partition frequency responses.
void ErlComputer_SSE2(
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
rtc::ArrayView<float> erl) {
std::fill(erl.begin(), erl.end(), 0.f);
for (auto& H2_j : H2) {
for (size_t k = 0; k < kFftLengthBy2; k += 4) {
const __m128 H2_j_k = _mm_loadu_ps(&H2_j[k]);
__m128 erl_k = _mm_loadu_ps(&erl[k]);
erl_k = _mm_add_ps(erl_k, H2_j_k);
_mm_storeu_ps(&erl[k], erl_k);
}
erl[kFftLengthBy2] += H2_j[kFftLengthBy2];
}
}
#endif
} // namespace aec3
void ComputeErl(const Aec3Optimization& optimization,
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
rtc::ArrayView<float> erl) {
RTC_DCHECK_EQ(kFftLengthBy2Plus1, erl.size());
// Update the frequency response and echo return loss for the filter.
switch (optimization) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
case Aec3Optimization::kSse2:
aec3::ErlComputer_SSE2(H2, erl);
break;
#endif
#if defined(WEBRTC_HAS_NEON)
case Aec3Optimization::kNeon:
aec3::ErlComputer_NEON(H2, erl);
break;
#endif
default:
aec3::ErlComputer(H2, erl);
}
}
} // namespace webrtc

View File

@ -0,0 +1,50 @@
/*
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AEC3_ADAPTIVE_FIR_FILTER_ERL_H_
#define MODULES_AUDIO_PROCESSING_AEC3_ADAPTIVE_FIR_FILTER_ERL_H_
#include <stddef.h>
#include <array>
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/aec3/aec3_common.h"
#include "rtc_base/system/arch.h"
namespace webrtc {
namespace aec3 {
// Computes and stores the echo return loss estimate of the filter, which is the
// sum of the partition frequency responses.
void ErlComputer(const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
rtc::ArrayView<float> erl);
#if defined(WEBRTC_HAS_NEON)
void ErlComputer_NEON(
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
rtc::ArrayView<float> erl);
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
void ErlComputer_SSE2(
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
rtc::ArrayView<float> erl);
#endif
} // namespace aec3
// Computes the echo return loss based on a frequency response.
void ComputeErl(const Aec3Optimization& optimization,
const std::vector<std::array<float, kFftLengthBy2Plus1>>& H2,
rtc::ArrayView<float> erl);
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AEC3_ADAPTIVE_FIR_FILTER_ERL_H_

View File

@ -0,0 +1,81 @@
/*
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h"
#include <array>
#include <vector>
#include "rtc_base/system/arch.h"
#if defined(WEBRTC_ARCH_X86_FAMILY)
#include <emmintrin.h>
#endif
#include "system_wrappers/include/cpu_features_wrapper.h"
#include "test/gtest.h"
namespace webrtc {
namespace aec3 {
#if defined(WEBRTC_HAS_NEON)
// Verifies that the optimized method for echo return loss computation is
// bitexact to the reference counterpart.
TEST(AdaptiveFirFilter, UpdateErlNeonOptimization) {
const size_t kNumPartitions = 12;
std::vector<std::array<float, kFftLengthBy2Plus1>> H2(kNumPartitions);
std::array<float, kFftLengthBy2Plus1> erl;
std::array<float, kFftLengthBy2Plus1> erl_NEON;
for (size_t j = 0; j < H2.size(); ++j) {
for (size_t k = 0; k < H2[j].size(); ++k) {
H2[j][k] = k + j / 3.f;
}
}
ErlComputer(H2, erl);
ErlComputer_NEON(H2, erl_NEON);
for (size_t j = 0; j < erl.size(); ++j) {
EXPECT_FLOAT_EQ(erl[j], erl_NEON[j]);
}
}
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
// Verifies that the optimized method for echo return loss computation is
// bitexact to the reference counterpart.
TEST(AdaptiveFirFilter, UpdateErlSse2Optimization) {
bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0);
if (use_sse2) {
const size_t kNumPartitions = 12;
std::vector<std::array<float, kFftLengthBy2Plus1>> H2(kNumPartitions);
std::array<float, kFftLengthBy2Plus1> erl;
std::array<float, kFftLengthBy2Plus1> erl_SSE2;
for (size_t j = 0; j < H2.size(); ++j) {
for (size_t k = 0; k < H2[j].size(); ++k) {
H2[j][k] = k + j / 3.f;
}
}
ErlComputer(H2, erl);
ErlComputer_SSE2(H2, erl_SSE2);
for (size_t j = 0; j < erl.size(); ++j) {
EXPECT_FLOAT_EQ(erl[j], erl_SSE2[j]);
}
}
}
#endif
} // namespace aec3
} // namespace webrtc

View File

@ -22,6 +22,7 @@
#include <emmintrin.h>
#endif
#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h"
#include "modules/audio_processing/aec3/aec3_fft.h"
#include "modules/audio_processing/aec3/aec_state.h"
#include "modules/audio_processing/aec3/render_delay_buffer.h"
@ -145,28 +146,6 @@ TEST(AdaptiveFirFilter, UpdateFrequencyResponseNeonOptimization) {
}
}
// Verifies that the optimized method for echo return loss computation is
// bitexact to the reference counterpart.
TEST(AdaptiveFirFilter, UpdateErlNeonOptimization) {
const size_t kNumPartitions = 12;
std::vector<std::array<float, kFftLengthBy2Plus1>> H2(kNumPartitions);
std::array<float, kFftLengthBy2Plus1> erl;
std::array<float, kFftLengthBy2Plus1> erl_NEON;
for (size_t j = 0; j < H2.size(); ++j) {
for (size_t k = 0; k < H2[j].size(); ++k) {
H2[j][k] = k + j / 3.f;
}
}
UpdateErlEstimator(H2, &erl);
UpdateErlEstimator_NEON(H2, &erl_NEON);
for (size_t j = 0; j < erl.size(); ++j) {
EXPECT_FLOAT_EQ(erl[j], erl_NEON[j]);
}
}
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
@ -266,31 +245,6 @@ TEST(AdaptiveFirFilter, UpdateFrequencyResponseSse2Optimization) {
}
}
// Verifies that the optimized method for echo return loss computation is
// bitexact to the reference counterpart.
TEST(AdaptiveFirFilter, UpdateErlSse2Optimization) {
bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0);
if (use_sse2) {
const size_t kNumPartitions = 12;
std::vector<std::array<float, kFftLengthBy2Plus1>> H2(kNumPartitions);
std::array<float, kFftLengthBy2Plus1> erl;
std::array<float, kFftLengthBy2Plus1> erl_SSE2;
for (size_t j = 0; j < H2.size(); ++j) {
for (size_t k = 0; k < H2[j].size(); ++k) {
H2[j][k] = k + j / 3.f;
}
}
UpdateErlEstimator(H2, &erl);
UpdateErlEstimator_SSE2(H2, &erl_SSE2);
for (size_t j = 0; j < erl.size(); ++j) {
EXPECT_FLOAT_EQ(erl[j], erl_SSE2[j]);
}
}
}
#endif
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
@ -316,9 +270,18 @@ TEST(AdaptiveFirFilter, NullFilterOutput) {
// are turned on.
TEST(AdaptiveFirFilter, FilterStatisticsAccess) {
ApmDataDumper data_dumper(42);
AdaptiveFirFilter filter(9, 9, 250, 1, 1, DetectOptimization(), &data_dumper);
filter.Erl();
filter.FilterFrequencyResponse();
Aec3Optimization optimization = DetectOptimization();
AdaptiveFirFilter filter(9, 9, 250, 1, 1, optimization, &data_dumper);
std::vector<std::array<float, kFftLengthBy2Plus1>> H2(
filter.max_filter_size_partitions(),
std::array<float, kFftLengthBy2Plus1>());
for (auto& H2_k : H2) {
H2_k.fill(0.f);
}
std::array<float, kFftLengthBy2Plus1> erl;
ComputeErl(optimization, H2, erl);
filter.ComputeFrequencyResponse(&H2);
}
// Verifies that the filter size if correctly repported.
@ -345,6 +308,11 @@ TEST(AdaptiveFirFilter, FilterAndAdapt) {
config.filter.main.length_blocks,
config.filter.config_change_duration_blocks, 1, 1,
DetectOptimization(), &data_dumper);
std::vector<std::array<float, kFftLengthBy2Plus1>> H2(
filter.max_filter_size_partitions(),
std::array<float, kFftLengthBy2Plus1>());
std::vector<float> h(GetTimeDomainLength(filter.max_filter_size_partitions()),
0.f);
Aec3Fft fft;
config.delay.default_delay = 1;
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
@ -424,13 +392,13 @@ TEST(AdaptiveFirFilter, FilterAndAdapt) {
render_buffer->SpectralSum(filter.SizePartitions(), &render_power);
gain.Compute(render_power, render_signal_analyzer, E,
filter.SizePartitions(), false, &G);
filter.Adapt(*render_buffer, G);
filter.Adapt(*render_buffer, G, &h);
aec_state.HandleEchoPathChange(EchoPathVariability(
false, EchoPathVariability::DelayAdjustment::kNone, false));
aec_state.Update(delay_estimate, filter.FilterFrequencyResponse(),
filter.FilterImpulseResponse(), *render_buffer, E2_main,
Y2, output, y);
filter.ComputeFrequencyResponse(&H2);
aec_state.Update(delay_estimate, H2, h, *render_buffer, E2_main, Y2,
output, y);
}
// Verify that the filter is able to perform well.
EXPECT_LT(1000 * std::inner_product(e.begin(), e.end(), e.begin(), 0.f),

View File

@ -70,7 +70,8 @@ void MainFilterUpdateGain::Compute(
const std::array<float, kFftLengthBy2Plus1>& render_power,
const RenderSignalAnalyzer& render_signal_analyzer,
const SubtractorOutput& subtractor_output,
const AdaptiveFirFilter& filter,
rtc::ArrayView<const float> erl,
size_t size_partitions,
bool saturated_capture_signal,
FftData* gain_fft) {
RTC_DCHECK(gain_fft);
@ -79,9 +80,8 @@ void MainFilterUpdateGain::Compute(
const auto& E2_main = subtractor_output.E2_main;
const auto& E2_shadow = subtractor_output.E2_shadow;
FftData* G = gain_fft;
const size_t size_partitions = filter.SizePartitions();
auto X2 = render_power;
const auto& erl = filter.Erl();
++call_counter_;
UpdateCurrentConfig();

View File

@ -16,9 +16,9 @@
#include <array>
#include <memory>
#include "api/array_view.h"
#include "api/audio/echo_canceller3_config.h"
#include "modules/audio_processing/aec3/aec3_common.h"
#include "rtc_base/constructor_magic.h"
namespace webrtc {
@ -32,11 +32,14 @@ struct SubtractorOutput;
// Provides functionality for computing the adaptive gain for the main filter.
class MainFilterUpdateGain {
public:
explicit MainFilterUpdateGain(
MainFilterUpdateGain(
const EchoCanceller3Config::Filter::MainConfiguration& config,
size_t config_change_duration_blocks);
~MainFilterUpdateGain();
MainFilterUpdateGain(const MainFilterUpdateGain&) = delete;
MainFilterUpdateGain& operator=(const MainFilterUpdateGain&) = delete;
// Takes action in the case of a known echo path change.
void HandleEchoPathChange(const EchoPathVariability& echo_path_variability);
@ -44,7 +47,8 @@ class MainFilterUpdateGain {
void Compute(const std::array<float, kFftLengthBy2Plus1>& render_power,
const RenderSignalAnalyzer& render_signal_analyzer,
const SubtractorOutput& subtractor_output,
const AdaptiveFirFilter& filter,
rtc::ArrayView<const float> erl,
size_t size_partitions,
bool saturated_capture_signal,
FftData* gain_fft);
@ -76,8 +80,6 @@ class MainFilterUpdateGain {
// Updates the current config towards the target config.
void UpdateCurrentConfig();
RTC_DISALLOW_COPY_AND_ASSIGN(MainFilterUpdateGain);
};
} // namespace webrtc

View File

@ -15,6 +15,7 @@
#include <string>
#include "modules/audio_processing/aec3/adaptive_fir_filter.h"
#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h"
#include "modules/audio_processing/aec3/aec_state.h"
#include "modules/audio_processing/aec3/render_delay_buffer.h"
#include "modules/audio_processing/aec3/render_signal_analyzer.h"
@ -42,6 +43,7 @@ void RunFilterUpdateTest(int num_blocks_to_process,
std::array<float, kBlockSize>* y_last_block,
FftData* G_last_block) {
ApmDataDumper data_dumper(42);
Aec3Optimization optimization = DetectOptimization();
constexpr size_t kNumChannels = 1;
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
@ -52,11 +54,20 @@ void RunFilterUpdateTest(int num_blocks_to_process,
AdaptiveFirFilter main_filter(config.filter.main.length_blocks,
config.filter.main.length_blocks,
config.filter.config_change_duration_blocks, 1,
1, DetectOptimization(), &data_dumper);
1, optimization, &data_dumper);
AdaptiveFirFilter shadow_filter(config.filter.shadow.length_blocks,
config.filter.shadow.length_blocks,
config.filter.config_change_duration_blocks,
1, 1, DetectOptimization(), &data_dumper);
1, 1, optimization, &data_dumper);
std::vector<std::array<float, kFftLengthBy2Plus1>> H2(
main_filter.max_filter_size_partitions(),
std::array<float, kFftLengthBy2Plus1>());
for (auto& H2_k : H2) {
H2_k.fill(0.f);
}
std::vector<float> h(
GetTimeDomainLength(main_filter.max_filter_size_partitions()), 0.f);
Aec3Fft fft;
std::array<float, kBlockSize> x_old;
x_old.fill(0.f);
@ -168,15 +179,18 @@ void RunFilterUpdateTest(int num_blocks_to_process,
// Adapt the main filter
render_delay_buffer->GetRenderBuffer()->SpectralSum(
main_filter.SizePartitions(), &render_power);
main_gain.Compute(render_power, render_signal_analyzer, output, main_filter,
saturation, &G);
main_filter.Adapt(*render_delay_buffer->GetRenderBuffer(), G);
std::array<float, kFftLengthBy2Plus1> erl;
ComputeErl(optimization, H2, erl);
main_gain.Compute(render_power, render_signal_analyzer, output, erl,
main_filter.SizePartitions(), saturation, &G);
main_filter.Adapt(*render_delay_buffer->GetRenderBuffer(), G, &h);
// Update the delay.
aec_state.HandleEchoPathChange(EchoPathVariability(
false, EchoPathVariability::DelayAdjustment::kNone, false));
aec_state.Update(delay_estimate, main_filter.FilterFrequencyResponse(),
main_filter.FilterImpulseResponse(),
main_filter.ComputeFrequencyResponse(&H2);
aec_state.Update(delay_estimate, H2, h,
*render_delay_buffer->GetRenderBuffer(), E2_main, Y2,
output, y);
}
@ -208,18 +222,17 @@ std::string ProduceDebugText(size_t delay, int filter_length_blocks) {
TEST(MainFilterUpdateGain, NullDataOutputGain) {
ApmDataDumper data_dumper(42);
EchoCanceller3Config config;
AdaptiveFirFilter filter(config.filter.main.length_blocks,
config.filter.main.length_blocks,
config.filter.config_change_duration_blocks, 1, 1,
DetectOptimization(), &data_dumper);
RenderSignalAnalyzer analyzer(EchoCanceller3Config{});
RenderSignalAnalyzer analyzer(config);
SubtractorOutput output;
MainFilterUpdateGain gain(config.filter.main,
config.filter.config_change_duration_blocks);
std::array<float, kFftLengthBy2Plus1> render_power;
render_power.fill(0.f);
EXPECT_DEATH(
gain.Compute(render_power, analyzer, output, filter, false, nullptr), "");
std::array<float, kFftLengthBy2Plus1> erl;
erl.fill(0.f);
EXPECT_DEATH(gain.Compute(render_power, analyzer, output, erl,
config.filter.main.length_blocks, false, nullptr),
"");
}
#endif

View File

@ -14,6 +14,7 @@
#include <utility>
#include "api/array_view.h"
#include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h"
#include "modules/audio_processing/aec3/fft_data.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
@ -81,8 +82,16 @@ Subtractor::Subtractor(const EchoCanceller3Config& config,
G_main_(config_.filter.main_initial,
config_.filter.config_change_duration_blocks),
G_shadow_(config_.filter.shadow_initial,
config.filter.config_change_duration_blocks) {
config.filter.config_change_duration_blocks),
main_frequency_response_(main_filter_.max_filter_size_partitions(),
std::array<float, kFftLengthBy2Plus1>()),
main_impulse_response_(
GetTimeDomainLength(main_filter_.max_filter_size_partitions()),
0.f) {
RTC_DCHECK(data_dumper_);
for (auto& H2_k : main_frequency_response_) {
H2_k.fill(0.f);
}
}
Subtractor::~Subtractor() = default;
@ -150,6 +159,9 @@ void Subtractor::Process(const RenderBuffer& render_buffer,
if (filter_misadjustment_estimator_.IsAdjustmentNeeded()) {
float scale = filter_misadjustment_estimator_.GetMisadjustment();
main_filter_.ScaleFilter(scale);
for (auto& h_k : main_impulse_response_) {
h_k *= scale;
}
ScaleFilterOutput(y, scale, e_main, output->s_main);
filter_misadjustment_estimator_.Reset();
main_filter_adjusted = true;
@ -184,13 +196,18 @@ void Subtractor::Process(const RenderBuffer& render_buffer,
// Update the main filter.
if (!main_filter_adjusted) {
G_main_.Compute(X2_main, render_signal_analyzer, *output, main_filter_,
aec_state.SaturatedCapture(), &G);
std::array<float, kFftLengthBy2Plus1> erl;
ComputeErl(optimization_, main_frequency_response_, erl);
G_main_.Compute(X2_main, render_signal_analyzer, *output, erl,
main_filter_.SizePartitions(), aec_state.SaturatedCapture(),
&G);
} else {
G.re.fill(0.f);
G.im.fill(0.f);
}
main_filter_.Adapt(render_buffer, G);
main_filter_.Adapt(render_buffer, G, &main_impulse_response_);
main_filter_.ComputeFrequencyResponse(&main_frequency_response_);
data_dumper_->DumpRaw("aec3_subtractor_G_main", G.re);
data_dumper_->DumpRaw("aec3_subtractor_G_main", G.im);

View File

@ -31,7 +31,6 @@
#include "modules/audio_processing/aec3/subtractor_output.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/constructor_magic.h"
namespace webrtc {
@ -44,6 +43,8 @@ class Subtractor {
ApmDataDumper* data_dumper,
Aec3Optimization optimization);
~Subtractor();
Subtractor(const Subtractor&) = delete;
Subtractor& operator=(const Subtractor&) = delete;
// Performs the echo subtraction.
void Process(const RenderBuffer& render_buffer,
@ -60,18 +61,22 @@ class Subtractor {
// Returns the block-wise frequency response for the main adaptive filter.
const std::vector<std::array<float, kFftLengthBy2Plus1>>&
FilterFrequencyResponse() const {
return main_filter_.FilterFrequencyResponse();
return main_frequency_response_;
}
// Returns the estimate of the impulse response for the main adaptive filter.
const std::vector<float>& FilterImpulseResponse() const {
return main_filter_.FilterImpulseResponse();
return main_impulse_response_;
}
void DumpFilters() {
main_filter_.DumpFilter("aec3_subtractor_H_main", "aec3_subtractor_h_main");
shadow_filter_.DumpFilter("aec3_subtractor_H_shadow",
"aec3_subtractor_h_shadow");
size_t current_size = main_impulse_response_.size();
main_impulse_response_.resize(main_impulse_response_.capacity());
data_dumper_->DumpRaw("aec3_subtractor_h_main", main_impulse_response_);
main_impulse_response_.resize(current_size);
main_filter_.DumpFilter("aec3_subtractor_H_main");
shadow_filter_.DumpFilter("aec3_subtractor_H_shadow");
}
private:
@ -117,7 +122,8 @@ class Subtractor {
ShadowFilterUpdateGain G_shadow_;
FilterMisadjustmentEstimator filter_misadjustment_estimator_;
size_t poor_shadow_filter_counter_ = 0;
RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(Subtractor);
std::vector<std::array<float, kFftLengthBy2Plus1>> main_frequency_response_;
std::vector<float> main_impulse_response_;
};
} // namespace webrtc