delay estimator: Look for early reverberation

Look for first echo (and not only the strongest one) on the same matched
filter.

This change is bit exact with previous version when `pre_echo` is false.

Author: Jesús de Vicente Peña <devicentepena@webrtc.org>

Bug: webrtc:14205
Change-Id: I6782eaa1d690b0df78d00f6d425a85c951b2ca9d
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/266321
Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
Commit-Queue: Lionel Koenig <lionelk@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#37360}
This commit is contained in:
Lionel Koenig
2022-06-28 15:37:13 +02:00
committed by WebRTC LUCI CQ
parent 7534ebd2bf
commit 8783c678a5
14 changed files with 957 additions and 395 deletions

View File

@ -59,6 +59,7 @@ struct RTC_EXPORT EchoCanceller3Config {
};
AlignmentMixing render_alignment_mixing = {false, true, 10000.f, true};
AlignmentMixing capture_alignment_mixing = {false, true, 10000.f, false};
bool detect_pre_echo = false;
} delay;
struct Filter {

View File

@ -220,6 +220,7 @@ void Aec3ConfigFromJsonString(absl::string_view json_string,
&cfg.delay.render_alignment_mixing);
ReadParam(section, "capture_alignment_mixing",
&cfg.delay.capture_alignment_mixing);
ReadParam(section, "detect_pre_echo", &cfg.delay.detect_pre_echo);
}
if (rtc::GetValueFromJsonObject(aec3_root, "filter", &section)) {
@ -505,7 +506,9 @@ std::string Aec3ConfigToJsonString(const EchoCanceller3Config& config) {
<< (config.delay.capture_alignment_mixing.prefer_first_two_channels
? "true"
: "false");
ost << "}";
ost << "},";
ost << "\"detect_pre_echo\": "
<< (config.delay.detect_pre_echo ? "true" : "false");
ost << "},";
ost << "\"filter\": {";

View File

@ -226,6 +226,7 @@ rtc_source_set("matched_filter") {
"../../../api:array_view",
"../../../rtc_base/system:arch",
]
absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ]
}
rtc_source_set("vector_math") {

View File

@ -377,6 +377,14 @@ EchoCanceller3Config AdjustConfig(const EchoCanceller3Config& config) {
false;
}
if (field_trial::IsEnabled("WebRTC-Aec3DelayEstimatorDetectPreEcho")) {
adjusted_cfg.delay.detect_pre_echo = true;
}
if (field_trial::IsDisabled("WebRTC-Aec3DelayEstimatorDetectPreEcho")) {
adjusted_cfg.delay.detect_pre_echo = false;
}
if (field_trial::IsEnabled("WebRTC-Aec3SensitiveDominantNearendActivation")) {
adjusted_cfg.suppressor.dominant_nearend_detection.enr_threshold = 0.5f;
} else if (field_trial::IsEnabled(

View File

@ -43,10 +43,11 @@ EchoPathDelayEstimator::EchoPathDelayEstimator(
: config.render_levels.poor_excitation_render_limit,
config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold),
config.delay.delay_candidate_detection_threshold,
config.delay.detect_pre_echo),
matched_filter_lag_aggregator_(data_dumper_,
matched_filter_.GetMaxFilterLag(),
config.delay.delay_selection_thresholds) {
config.delay) {
RTC_DCHECK(data_dumper);
RTC_DCHECK(down_sampling_factor_ > 0);
}
@ -75,13 +76,14 @@ absl::optional<DelayEstimate> EchoPathDelayEstimator::EstimateDelay(
absl::optional<DelayEstimate> aggregated_matched_filter_lag =
matched_filter_lag_aggregator_.Aggregate(
matched_filter_.GetLagEstimates());
matched_filter_.GetBestLagEstimate());
// Run clockdrift detection.
if (aggregated_matched_filter_lag &&
(*aggregated_matched_filter_lag).quality ==
DelayEstimate::Quality::kRefined)
clockdrift_detector_.Update((*aggregated_matched_filter_lag).delay);
clockdrift_detector_.Update(
matched_filter_lag_aggregator_.GetDelayAtHighestPeak());
// TODO(peah): Move this logging outside of this class once EchoCanceller3
// development is done.

View File

@ -78,6 +78,7 @@ TEST(EchoPathDelayEstimator, DelayEstimation) {
constexpr size_t kDownSamplingFactors[] = {2, 4, 8};
for (auto down_sampling_factor : kDownSamplingFactors) {
EchoCanceller3Config config;
config.delay.delay_headroom_samples = 0;
config.delay.down_sampling_factor = down_sampling_factor;
config.delay.num_filters = 10;
for (size_t delay_samples : {30, 64, 150, 200, 800, 4000}) {
@ -111,12 +112,13 @@ TEST(EchoPathDelayEstimator, DelayEstimation) {
}
if (estimated_delay_samples) {
// Allow estimated delay to be off by one sample in the down-sampled
// domain.
// Allow estimated delay to be off by a block as internally the delay is
// quantized with an error up to a block.
size_t delay_ds = delay_samples / down_sampling_factor;
size_t estimated_delay_ds =
estimated_delay_samples->delay / down_sampling_factor;
EXPECT_NEAR(delay_ds, estimated_delay_ds, 1);
EXPECT_NEAR(delay_ds, estimated_delay_ds,
kBlockSize / down_sampling_factor);
} else {
ADD_FAILURE();
}

View File

@ -24,16 +24,147 @@
#include <iterator>
#include <numeric>
#include "absl/types/optional.h"
#include "api/array_view.h"
#include "modules/audio_processing/aec3/downsampled_render_buffer.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
namespace {
// Subsample rate used for computing the accumulated error.
// The implementation of some core functions depends on this constant being
// equal to 4.
constexpr int kAccumulatedErrorSubSampleRate = 4;
void UpdateAccumulatedError(
const rtc::ArrayView<const float> instantaneous_accumulated_error,
const rtc::ArrayView<float> accumulated_error,
float one_over_error_sum_anchor) {
for (size_t k = 0; k < instantaneous_accumulated_error.size(); ++k) {
float error_norm =
instantaneous_accumulated_error[k] * one_over_error_sum_anchor;
if (error_norm < accumulated_error[k]) {
accumulated_error[k] = error_norm;
} else {
accumulated_error[k] += 0.01f * (error_norm - accumulated_error[k]);
}
}
}
size_t ComputePreEchoLag(const rtc::ArrayView<float> accumulated_error,
size_t lag,
size_t alignment_shift_winner) {
size_t pre_echo_lag_estimate = lag - alignment_shift_winner;
size_t maximum_pre_echo_lag =
std::min(pre_echo_lag_estimate / kAccumulatedErrorSubSampleRate,
accumulated_error.size());
for (size_t k = 1; k < maximum_pre_echo_lag; ++k) {
if (accumulated_error[k] < 0.5f * accumulated_error[k - 1] &&
accumulated_error[k] < 0.5f) {
pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1;
break;
}
}
return pre_echo_lag_estimate + alignment_shift_winner;
}
} // namespace
namespace webrtc {
namespace aec3 {
#if defined(WEBRTC_HAS_NEON)
inline float SumAllElements(float32x4_t elements) {
float32x2_t sum = vpadd_f32(vget_low_f32(elements), vget_high_f32(elements));
sum = vpadd_f32(sum, sum);
return vget_lane_f32(sum, 0);
}
void MatchedFilterCoreWithAccumulatedError_NEON(
size_t x_start_index,
float x2_sum_threshold,
float smoothing,
rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum,
rtc::ArrayView<float> accumulated_error,
rtc::ArrayView<float> scratch_memory) {
const int h_size = static_cast<int>(h.size());
const int x_size = static_cast<int>(x.size());
RTC_DCHECK_EQ(0, h_size % 4);
std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
// Process for all samples in the sub-block.
for (size_t i = 0; i < y.size(); ++i) {
// Apply the matched filter as filter * x, and compute x * x.
RTC_DCHECK_GT(x_size, x_start_index);
// Compute loop chunk sizes until, and after, the wraparound of the circular
// buffer for x.
const int chunk1 =
std::min(h_size, static_cast<int>(x_size - x_start_index));
if (chunk1 != h_size) {
const int chunk2 = h_size - chunk1;
std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
}
const float* x_p =
chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
const float* h_p = &h[0];
float* accumulated_error_p = &accumulated_error[0];
// Initialize values for the accumulation.
float32x4_t x2_sum_128 = vdupq_n_f32(0);
float x2_sum = 0.f;
float s = 0;
// Perform 128 bit vector operations.
const int limit_by_4 = h_size >> 2;
for (int k = limit_by_4; k > 0;
--k, h_p += 4, x_p += 4, accumulated_error_p++) {
// Load the data into 128 bit vectors.
const float32x4_t x_k = vld1q_f32(x_p);
const float32x4_t h_k = vld1q_f32(h_p);
// Compute and accumulate x * x.
x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
// Compute x * h
float32x4_t hk_xk_128 = vmulq_f32(h_k, x_k);
s += SumAllElements(hk_xk_128);
const float e = s - y[i];
accumulated_error_p[0] += e * e;
}
// Combine the accumulated vector and scalar values.
x2_sum += SumAllElements(x2_sum_128);
// Compute the matched filter error.
float e = y[i] - s;
const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
(*error_sum) += e * e;
// Update the matched filter estimate in an NLMS manner.
if (x2_sum > x2_sum_threshold && !saturation) {
RTC_DCHECK_LT(0.f, x2_sum);
const float alpha = smoothing * e / x2_sum;
const float32x4_t alpha_128 = vmovq_n_f32(alpha);
// filter = filter + smoothing * (y - filter * x) * x / x * x.
float* h_p = &h[0];
x_p = chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
// Perform 128 bit vector operations.
const int limit_by_4 = h_size >> 2;
for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
// Load the data into 128 bit vectors.
float32x4_t h_k = vld1q_f32(h_p);
const float32x4_t x_k = vld1q_f32(x_p);
// Compute h = h + alpha * x.
h_k = vmlaq_f32(h_k, alpha_128, x_k);
// Store the result.
vst1q_f32(h_p, h_k);
}
*filters_updated = true;
}
x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
}
}
void MatchedFilterCore_NEON(size_t x_start_index,
float x2_sum_threshold,
float smoothing,
@ -41,11 +172,20 @@ void MatchedFilterCore_NEON(size_t x_start_index,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum) {
float* error_sum,
bool compute_accumulated_error,
rtc::ArrayView<float> accumulated_error,
rtc::ArrayView<float> scratch_memory) {
const int h_size = static_cast<int>(h.size());
const int x_size = static_cast<int>(x.size());
RTC_DCHECK_EQ(0, h_size % 4);
if (compute_accumulated_error) {
return MatchedFilterCoreWithAccumulatedError_NEON(
x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
error_sum, accumulated_error, scratch_memory);
}
// Process for all samples in the sub-block.
for (size_t i = 0; i < y.size(); ++i) {
// Apply the matched filter as filter * x, and compute x * x.
@ -90,10 +230,8 @@ void MatchedFilterCore_NEON(size_t x_start_index,
}
// Combine the accumulated vector and scalar values.
float* v = reinterpret_cast<float*>(&x2_sum_128);
x2_sum += v[0] + v[1] + v[2] + v[3];
v = reinterpret_cast<float*>(&s_128);
s += v[0] + v[1] + v[2] + v[3];
s += SumAllElements(s_128);
x2_sum += SumAllElements(x2_sum_128);
// Compute the matched filter error.
float e = y[i] - s;
@ -144,6 +282,103 @@ void MatchedFilterCore_NEON(size_t x_start_index,
#if defined(WEBRTC_ARCH_X86_FAMILY)
void MatchedFilterCore_AccumulatedError_SSE2(
size_t x_start_index,
float x2_sum_threshold,
float smoothing,
rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum,
rtc::ArrayView<float> accumulated_error,
rtc::ArrayView<float> scratch_memory) {
const int h_size = static_cast<int>(h.size());
const int x_size = static_cast<int>(x.size());
RTC_DCHECK_EQ(0, h_size % 8);
std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
// Process for all samples in the sub-block.
for (size_t i = 0; i < y.size(); ++i) {
// Apply the matched filter as filter * x, and compute x * x.
RTC_DCHECK_GT(x_size, x_start_index);
const int chunk1 =
std::min(h_size, static_cast<int>(x_size - x_start_index));
if (chunk1 != h_size) {
const int chunk2 = h_size - chunk1;
std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
}
const float* x_p =
chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
const float* h_p = &h[0];
float* a_p = &accumulated_error[0];
__m128 s_inst_128;
__m128 s_inst_128_4;
__m128 x2_sum_128 = _mm_set1_ps(0);
__m128 x2_sum_128_4 = _mm_set1_ps(0);
__m128 e_128;
float* const s_p = reinterpret_cast<float*>(&s_inst_128);
float* const s_4_p = reinterpret_cast<float*>(&s_inst_128_4);
float* const e_p = reinterpret_cast<float*>(&e_128);
float x2_sum = 0.0f;
float s_acum = 0;
// Perform 128 bit vector operations.
const int limit_by_8 = h_size >> 3;
for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8, a_p += 2) {
// Load the data into 128 bit vectors.
const __m128 x_k = _mm_loadu_ps(x_p);
const __m128 h_k = _mm_loadu_ps(h_p);
const __m128 x_k_4 = _mm_loadu_ps(x_p + 4);
const __m128 h_k_4 = _mm_loadu_ps(h_p + 4);
const __m128 xx = _mm_mul_ps(x_k, x_k);
const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4);
// Compute and accumulate x * x and h * x.
x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4);
s_inst_128 = _mm_mul_ps(h_k, x_k);
s_inst_128_4 = _mm_mul_ps(h_k_4, x_k_4);
s_acum += s_p[0] + s_p[1] + s_p[2] + s_p[3];
e_p[0] = s_acum - y[i];
s_acum += s_4_p[0] + s_4_p[1] + s_4_p[2] + s_4_p[3];
e_p[1] = s_acum - y[i];
a_p[0] += e_p[0] * e_p[0];
a_p[1] += e_p[1] * e_p[1];
}
// Combine the accumulated vector and scalar values.
x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4);
float* v = reinterpret_cast<float*>(&x2_sum_128);
x2_sum += v[0] + v[1] + v[2] + v[3];
// Compute the matched filter error.
float e = y[i] - s_acum;
const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
(*error_sum) += e * e;
// Update the matched filter estimate in an NLMS manner.
if (x2_sum > x2_sum_threshold && !saturation) {
RTC_DCHECK_LT(0.f, x2_sum);
const float alpha = smoothing * e / x2_sum;
const __m128 alpha_128 = _mm_set1_ps(alpha);
// filter = filter + smoothing * (y - filter * x) * x / x * x.
float* h_p = &h[0];
const float* x_p =
chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
// Perform 128 bit vector operations.
const int limit_by_4 = h_size >> 2;
for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
// Load the data into 128 bit vectors.
__m128 h_k = _mm_loadu_ps(h_p);
const __m128 x_k = _mm_loadu_ps(x_p);
// Compute h = h + alpha * x.
const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
h_k = _mm_add_ps(h_k, alpha_x);
// Store the result.
_mm_storeu_ps(h_p, h_k);
}
*filters_updated = true;
}
x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
}
}
void MatchedFilterCore_SSE2(size_t x_start_index,
float x2_sum_threshold,
float smoothing,
@ -151,19 +386,24 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum) {
float* error_sum,
bool compute_accumulated_error,
rtc::ArrayView<float> accumulated_error,
rtc::ArrayView<float> scratch_memory) {
if (compute_accumulated_error) {
return MatchedFilterCore_AccumulatedError_SSE2(
x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
error_sum, accumulated_error, scratch_memory);
}
const int h_size = static_cast<int>(h.size());
const int x_size = static_cast<int>(x.size());
RTC_DCHECK_EQ(0, h_size % 4);
// Process for all samples in the sub-block.
for (size_t i = 0; i < y.size(); ++i) {
// Apply the matched filter as filter * x, and compute x * x.
RTC_DCHECK_GT(x_size, x_start_index);
const float* x_p = &x[x_start_index];
const float* h_p = &h[0];
// Initialize values for the accumulation.
__m128 s_128 = _mm_set1_ps(0);
__m128 s_128_4 = _mm_set1_ps(0);
@ -171,12 +411,10 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
__m128 x2_sum_128_4 = _mm_set1_ps(0);
float x2_sum = 0.f;
float s = 0;
// Compute loop chunk sizes until, and after, the wraparound of the circular
// buffer for x.
const int chunk1 =
std::min(h_size, static_cast<int>(x_size - x_start_index));
// Perform the loop in two chunks.
const int chunk2 = h_size - chunk1;
for (int limit : {chunk1, chunk2}) {
@ -198,17 +436,14 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
s_128 = _mm_add_ps(s_128, hx);
s_128_4 = _mm_add_ps(s_128_4, hx_4);
}
// Perform non-vector operations for any remaining items.
for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) {
const float x_k = *x_p;
x2_sum += x_k * x_k;
s += *h_p * x_k;
}
x_p = &x[0];
}
// Combine the accumulated vector and scalar values.
x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4);
float* v = reinterpret_cast<float*>(&x2_sum_128);
@ -216,22 +451,18 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
s_128 = _mm_add_ps(s_128, s_128_4);
v = reinterpret_cast<float*>(&s_128);
s += v[0] + v[1] + v[2] + v[3];
// Compute the matched filter error.
float e = y[i] - s;
const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
(*error_sum) += e * e;
// Update the matched filter estimate in an NLMS manner.
if (x2_sum > x2_sum_threshold && !saturation) {
RTC_DCHECK_LT(0.f, x2_sum);
const float alpha = smoothing * e / x2_sum;
const __m128 alpha_128 = _mm_set1_ps(alpha);
// filter = filter + smoothing * (y - filter * x) * x / x * x.
float* h_p = &h[0];
x_p = &x[x_start_index];
// Perform the loop in two chunks.
for (int limit : {chunk1, chunk2}) {
// Perform 128 bit vector operations.
@ -244,22 +475,17 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
// Compute h = h + alpha * x.
const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
h_k = _mm_add_ps(h_k, alpha_x);
// Store the result.
_mm_storeu_ps(h_p, h_k);
}
// Perform non-vector operations for any remaining items.
for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
*h_p += alpha * *x_p;
}
x_p = &x[0];
}
*filters_updated = true;
}
x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
}
}
@ -272,17 +498,35 @@ void MatchedFilterCore(size_t x_start_index,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum) {
float* error_sum,
bool compute_accumulated_error,
rtc::ArrayView<float> accumulated_error) {
if (compute_accumulated_error) {
std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
}
// Process for all samples in the sub-block.
for (size_t i = 0; i < y.size(); ++i) {
// Apply the matched filter as filter * x, and compute x * x.
float x2_sum = 0.f;
float s = 0;
size_t x_index = x_start_index;
if (compute_accumulated_error) {
for (size_t k = 0; k < h.size(); ++k) {
x2_sum += x[x_index] * x[x_index];
s += h[k] * x[x_index];
x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
if ((k + 1 & 0b11) == 0) {
int idx = k >> 2;
accumulated_error[idx] += (y[i] - s) * (y[i] - s);
}
}
} else {
for (size_t k = 0; k < h.size(); ++k) {
x2_sum += x[x_index] * x[x_index];
s += h[k] * x[x_index];
x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
}
}
// Compute the matched filter error.
@ -354,7 +598,8 @@ MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
float excitation_limit,
float smoothing_fast,
float smoothing_slow,
float matching_filter_threshold)
float matching_filter_threshold,
bool detect_pre_echo)
: data_dumper_(data_dumper),
optimization_(optimization),
sub_block_size_(sub_block_size),
@ -362,16 +607,31 @@ MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
filters_(
num_matched_filters,
std::vector<float>(window_size_sub_blocks * sub_block_size_, 0.f)),
lag_estimates_(num_matched_filters),
filters_offsets_(num_matched_filters, 0),
excitation_limit_(excitation_limit),
smoothing_fast_(smoothing_fast),
smoothing_slow_(smoothing_slow),
matching_filter_threshold_(matching_filter_threshold) {
matching_filter_threshold_(matching_filter_threshold),
detect_pre_echo_(detect_pre_echo) {
RTC_DCHECK(data_dumper);
RTC_DCHECK_LT(0, window_size_sub_blocks);
RTC_DCHECK((kBlockSize % sub_block_size) == 0);
RTC_DCHECK((sub_block_size % 4) == 0);
static_assert(kAccumulatedErrorSubSampleRate == 4);
if (detect_pre_echo_) {
accumulated_error_ = std::vector<std::vector<float>>(
num_matched_filters,
std::vector<float>(window_size_sub_blocks * sub_block_size_ /
kAccumulatedErrorSubSampleRate,
1.0f));
instantaneous_accumulated_error_ =
std::vector<float>(window_size_sub_blocks * sub_block_size_ /
kAccumulatedErrorSubSampleRate,
0.0f);
scratch_memory_ =
std::vector<float>(window_size_sub_blocks * sub_block_size_);
}
}
MatchedFilter::~MatchedFilter() = default;
@ -381,9 +641,12 @@ void MatchedFilter::Reset() {
std::fill(f.begin(), f.end(), 0.f);
}
for (auto& l : lag_estimates_) {
l = MatchedFilter::LagEstimate();
for (auto& e : accumulated_error_) {
std::fill(e.begin(), e.end(), 1.0f);
}
winner_lag_ = absl::nullopt;
reported_lag_estimate_ = absl::nullopt;
}
void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
@ -398,11 +661,25 @@ void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
const float x2_sum_threshold =
filters_[0].size() * excitation_limit_ * excitation_limit_;
// Compute anchor for the matched filter error.
float error_sum_anchor = 0.0f;
for (size_t k = 0; k < y.size(); ++k) {
error_sum_anchor += y[k] * y[k];
}
// Apply all matched filters.
float winner_error_sum = error_sum_anchor;
winner_lag_ = absl::nullopt;
reported_lag_estimate_ = absl::nullopt;
size_t alignment_shift = 0;
for (size_t n = 0; n < filters_.size(); ++n) {
absl::optional<size_t> previous_lag_estimate;
const int num_filters = static_cast<int>(filters_.size());
int winner_index = -1;
for (int n = 0; n < num_filters; ++n) {
float error_sum = 0.f;
bool filters_updated = false;
const bool compute_pre_echo =
detect_pre_echo_ && n == last_detected_best_lag_filter_;
size_t x_start_index =
(render_buffer.read + alignment_shift + sub_block_size_ - 1) %
@ -411,84 +688,78 @@ void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
switch (optimization_) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
case Aec3Optimization::kSse2:
aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold, smoothing,
render_buffer.buffer, y, filters_[n],
&filters_updated, &error_sum);
aec3::MatchedFilterCore_SSE2(
x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
filters_[n], &filters_updated, &error_sum, compute_pre_echo,
instantaneous_accumulated_error_, scratch_memory_);
break;
case Aec3Optimization::kAvx2:
aec3::MatchedFilterCore_AVX2(x_start_index, x2_sum_threshold, smoothing,
render_buffer.buffer, y, filters_[n],
&filters_updated, &error_sum);
aec3::MatchedFilterCore_AVX2(
x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
filters_[n], &filters_updated, &error_sum, compute_pre_echo,
instantaneous_accumulated_error_, scratch_memory_);
break;
#endif
#if defined(WEBRTC_HAS_NEON)
case Aec3Optimization::kNeon:
aec3::MatchedFilterCore_NEON(x_start_index, x2_sum_threshold, smoothing,
render_buffer.buffer, y, filters_[n],
&filters_updated, &error_sum);
aec3::MatchedFilterCore_NEON(
x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
filters_[n], &filters_updated, &error_sum, compute_pre_echo,
instantaneous_accumulated_error_, scratch_memory_);
break;
#endif
default:
aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing,
render_buffer.buffer, y, filters_[n],
&filters_updated, &error_sum);
}
// Compute anchor for the matched filter error.
float error_sum_anchor = 0.0f;
for (size_t k = 0; k < y.size(); ++k) {
error_sum_anchor += y[k] * y[k];
&filters_updated, &error_sum, compute_pre_echo,
instantaneous_accumulated_error_);
}
// Estimate the lag in the matched filter as the distance to the portion in
// the filter that contributes the most to the matched filter output. This
// is detected as the peak of the matched filter.
const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]);
const bool reliable =
lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) &&
error_sum < matching_filter_threshold_ * error_sum_anchor;
// Update the lag estimates for the matched filter.
lag_estimates_[n] = LagEstimate(
error_sum_anchor - error_sum,
(lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) &&
error_sum < matching_filter_threshold_ * error_sum_anchor),
lag_estimate + alignment_shift, filters_updated);
RTC_DCHECK_GE(10, filters_.size());
switch (n) {
case 0:
data_dumper_->DumpRaw("aec3_correlator_0_h", filters_[0]);
break;
case 1:
data_dumper_->DumpRaw("aec3_correlator_1_h", filters_[1]);
break;
case 2:
data_dumper_->DumpRaw("aec3_correlator_2_h", filters_[2]);
break;
case 3:
data_dumper_->DumpRaw("aec3_correlator_3_h", filters_[3]);
break;
case 4:
data_dumper_->DumpRaw("aec3_correlator_4_h", filters_[4]);
break;
case 5:
data_dumper_->DumpRaw("aec3_correlator_5_h", filters_[5]);
break;
case 6:
data_dumper_->DumpRaw("aec3_correlator_6_h", filters_[6]);
break;
case 7:
data_dumper_->DumpRaw("aec3_correlator_7_h", filters_[7]);
break;
case 8:
data_dumper_->DumpRaw("aec3_correlator_8_h", filters_[8]);
break;
case 9:
data_dumper_->DumpRaw("aec3_correlator_9_h", filters_[9]);
break;
default:
RTC_DCHECK_NOTREACHED();
// Find the best estimate
const size_t lag = lag_estimate + alignment_shift;
if (filters_updated && reliable && error_sum < winner_error_sum) {
winner_error_sum = error_sum;
winner_index = n;
// In case that 2 matched filters return the same winner candidate
// (overlap region), the one with the smaller index is chosen in order
// to search for pre-echoes.
if (previous_lag_estimate && previous_lag_estimate == lag) {
winner_lag_ = previous_lag_estimate;
winner_index = n - 1;
} else {
winner_lag_ = lag;
}
}
previous_lag_estimate = lag;
alignment_shift += filter_intra_lag_shift_;
}
alignment_shift += filter_intra_lag_shift_;
if (winner_index != -1) {
RTC_DCHECK(winner_lag_.has_value());
reported_lag_estimate_ =
LagEstimate(winner_lag_.value(), /*pre_echo_lag=*/winner_lag_.value());
if (detect_pre_echo_ && last_detected_best_lag_filter_ == winner_index) {
if (error_sum_anchor > 30.0f * 30.0f * y.size()) {
UpdateAccumulatedError(instantaneous_accumulated_error_,
accumulated_error_[winner_index],
1.0f / error_sum_anchor);
}
reported_lag_estimate_->pre_echo_lag = ComputePreEchoLag(
accumulated_error_[winner_index], winner_lag_.value(),
winner_index * filter_intra_lag_shift_ /*alignment_shift_winner*/);
}
last_detected_best_lag_filter_ = winner_index;
}
if (ApmDataDumper::IsAvailable()) {
Dump();
}
}
@ -510,4 +781,27 @@ void MatchedFilter::LogFilterProperties(int sample_rate_hz,
}
}
void MatchedFilter::Dump() {
for (size_t n = 0; n < filters_.size(); ++n) {
const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]);
std::string dumper_filter = "aec3_correlator_" + std::to_string(n) + "_h";
data_dumper_->DumpRaw(dumper_filter.c_str(), filters_[n]);
std::string dumper_lag = "aec3_correlator_lag_" + std::to_string(n);
data_dumper_->DumpRaw(dumper_lag.c_str(),
lag_estimate + n * filter_intra_lag_shift_);
if (detect_pre_echo_) {
std::string dumper_error =
"aec3_correlator_error_" + std::to_string(n) + "_h";
data_dumper_->DumpRaw(dumper_error.c_str(), accumulated_error_[n]);
size_t pre_echo_lag = ComputePreEchoLag(
accumulated_error_[n], lag_estimate + n * filter_intra_lag_shift_,
n * filter_intra_lag_shift_);
std::string dumper_pre_lag =
"aec3_correlator_pre_echo_lag_" + std::to_string(n);
data_dumper_->DumpRaw(dumper_pre_lag.c_str(), pre_echo_lag);
}
}
}
} // namespace webrtc

View File

@ -15,6 +15,7 @@
#include <vector>
#include "absl/types/optional.h"
#include "api/array_view.h"
#include "modules/audio_processing/aec3/aec3_common.h"
#include "rtc_base/system/arch.h"
@ -36,7 +37,10 @@ void MatchedFilterCore_NEON(size_t x_start_index,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum);
float* error_sum,
bool compute_accumulation_error,
rtc::ArrayView<float> accumulated_error,
rtc::ArrayView<float> scratch_memory);
#endif
@ -50,7 +54,10 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum);
float* error_sum,
bool compute_accumulated_error,
rtc::ArrayView<float> accumulated_error,
rtc::ArrayView<float> scratch_memory);
// Filter core for the matched filter that is optimized for AVX2.
void MatchedFilterCore_AVX2(size_t x_start_index,
@ -60,7 +67,10 @@ void MatchedFilterCore_AVX2(size_t x_start_index,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum);
float* error_sum,
bool compute_accumulated_error,
rtc::ArrayView<float> accumulated_error,
rtc::ArrayView<float> scratch_memory);
#endif
@ -72,7 +82,9 @@ void MatchedFilterCore(size_t x_start_index,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum);
float* error_sum,
bool compute_accumulation_error,
rtc::ArrayView<float> accumulated_error);
// Find largest peak of squared values in array.
size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h);
@ -87,13 +99,10 @@ class MatchedFilter {
// shift.
struct LagEstimate {
LagEstimate() = default;
LagEstimate(float accuracy, bool reliable, size_t lag, bool updated)
: accuracy(accuracy), reliable(reliable), lag(lag), updated(updated) {}
float accuracy = 0.f;
bool reliable = false;
LagEstimate(size_t lag, size_t pre_echo_lag)
: lag(lag), pre_echo_lag(pre_echo_lag) {}
size_t lag = 0;
bool updated = false;
size_t pre_echo_lag = 0;
};
MatchedFilter(ApmDataDumper* data_dumper,
@ -105,7 +114,8 @@ class MatchedFilter {
float excitation_limit,
float smoothing_fast,
float smoothing_slow,
float matching_filter_threshold);
float matching_filter_threshold,
bool detect_pre_echo);
MatchedFilter() = delete;
MatchedFilter(const MatchedFilter&) = delete;
@ -122,8 +132,8 @@ class MatchedFilter {
void Reset();
// Returns the current lag estimates.
rtc::ArrayView<const MatchedFilter::LagEstimate> GetLagEstimates() const {
return lag_estimates_;
absl::optional<const MatchedFilter::LagEstimate> GetBestLagEstimate() const {
return reported_lag_estimate_;
}
// Returns the maximum filter lag.
@ -137,17 +147,25 @@ class MatchedFilter {
size_t downsampling_factor) const;
private:
void Dump();
ApmDataDumper* const data_dumper_;
const Aec3Optimization optimization_;
const size_t sub_block_size_;
const size_t filter_intra_lag_shift_;
std::vector<std::vector<float>> filters_;
std::vector<LagEstimate> lag_estimates_;
std::vector<std::vector<float>> accumulated_error_;
std::vector<float> instantaneous_accumulated_error_;
std::vector<float> scratch_memory_;
absl::optional<MatchedFilter::LagEstimate> reported_lag_estimate_;
absl::optional<size_t> winner_lag_;
int last_detected_best_lag_filter_ = -1;
std::vector<size_t> filters_offsets_;
const float excitation_limit_;
const float smoothing_fast_;
const float smoothing_slow_;
const float matching_filter_threshold_;
const bool detect_pre_echo_;
};
} // namespace webrtc

View File

@ -8,15 +8,134 @@
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/aec3/matched_filter.h"
#include <immintrin.h>
#include "modules/audio_processing/aec3/matched_filter.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace aec3 {
// Let ha denote the horizontal of a, and hb the horizontal sum of b
// returns [ha, hb, ha, hb]
inline __m128 hsum_ab(__m256 a, __m256 b) {
__m256 s_256 = _mm256_hadd_ps(a, b);
const __m256i mask = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
s_256 = _mm256_permutevar8x32_ps(s_256, mask);
__m128 s = _mm_hadd_ps(_mm256_extractf128_ps(s_256, 0),
_mm256_extractf128_ps(s_256, 1));
s = _mm_hadd_ps(s, s);
return s;
}
void MatchedFilterCore_AccumulatedError_AVX2(
size_t x_start_index,
float x2_sum_threshold,
float smoothing,
rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum,
rtc::ArrayView<float> accumulated_error,
rtc::ArrayView<float> scratch_memory) {
const int h_size = static_cast<int>(h.size());
const int x_size = static_cast<int>(x.size());
RTC_DCHECK_EQ(0, h_size % 16);
std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
// Process for all samples in the sub-block.
for (size_t i = 0; i < y.size(); ++i) {
// Apply the matched filter as filter * x, and compute x * x.
RTC_DCHECK_GT(x_size, x_start_index);
const int chunk1 =
std::min(h_size, static_cast<int>(x_size - x_start_index));
if (chunk1 != h_size) {
const int chunk2 = h_size - chunk1;
std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
}
const float* x_p =
chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
const float* h_p = &h[0];
float* a_p = &accumulated_error[0];
__m256 s_inst_hadd_256;
__m256 s_inst_256;
__m256 s_inst_256_8;
__m256 x2_sum_256 = _mm256_set1_ps(0);
__m256 x2_sum_256_8 = _mm256_set1_ps(0);
__m128 e_128;
float x2_sum = 0.0f;
float s_acum = 0;
const int limit_by_16 = h_size >> 4;
for (int k = limit_by_16; k > 0; --k, h_p += 16, x_p += 16, a_p += 4) {
// Load the data into 256 bit vectors.
__m256 x_k = _mm256_loadu_ps(x_p);
__m256 h_k = _mm256_loadu_ps(h_p);
__m256 x_k_8 = _mm256_loadu_ps(x_p + 8);
__m256 h_k_8 = _mm256_loadu_ps(h_p + 8);
// Compute and accumulate x * x and h * x.
x2_sum_256 = _mm256_fmadd_ps(x_k, x_k, x2_sum_256);
x2_sum_256_8 = _mm256_fmadd_ps(x_k_8, x_k_8, x2_sum_256_8);
s_inst_256 = _mm256_mul_ps(h_k, x_k);
s_inst_256_8 = _mm256_mul_ps(h_k_8, x_k_8);
s_inst_hadd_256 = _mm256_hadd_ps(s_inst_256, s_inst_256_8);
s_inst_hadd_256 = _mm256_hadd_ps(s_inst_hadd_256, s_inst_hadd_256);
s_acum += s_inst_hadd_256[0];
e_128[0] = s_acum - y[i];
s_acum += s_inst_hadd_256[4];
e_128[1] = s_acum - y[i];
s_acum += s_inst_hadd_256[1];
e_128[2] = s_acum - y[i];
s_acum += s_inst_hadd_256[5];
e_128[3] = s_acum - y[i];
__m128 accumulated_error = _mm_load_ps(a_p);
accumulated_error = _mm_fmadd_ps(e_128, e_128, accumulated_error);
_mm_storeu_ps(a_p, accumulated_error);
}
// Sum components together.
x2_sum_256 = _mm256_add_ps(x2_sum_256, x2_sum_256_8);
__m128 x2_sum_128 = _mm_add_ps(_mm256_extractf128_ps(x2_sum_256, 0),
_mm256_extractf128_ps(x2_sum_256, 1));
// Combine the accumulated vector and scalar values.
float* v = reinterpret_cast<float*>(&x2_sum_128);
x2_sum += v[0] + v[1] + v[2] + v[3];
// Compute the matched filter error.
float e = y[i] - s_acum;
const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
(*error_sum) += e * e;
// Update the matched filter estimate in an NLMS manner.
if (x2_sum > x2_sum_threshold && !saturation) {
RTC_DCHECK_LT(0.f, x2_sum);
const float alpha = smoothing * e / x2_sum;
const __m256 alpha_256 = _mm256_set1_ps(alpha);
// filter = filter + smoothing * (y - filter * x) * x / x * x.
float* h_p = &h[0];
const float* x_p =
chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
// Perform 256 bit vector operations.
const int limit_by_8 = h_size >> 3;
for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) {
// Load the data into 256 bit vectors.
__m256 h_k = _mm256_loadu_ps(h_p);
__m256 x_k = _mm256_loadu_ps(x_p);
// Compute h = h + alpha * x.
h_k = _mm256_fmadd_ps(x_k, alpha_256, h_k);
// Store the result.
_mm256_storeu_ps(h_p, h_k);
}
*filters_updated = true;
}
x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
}
}
void MatchedFilterCore_AVX2(size_t x_start_index,
float x2_sum_threshold,
float smoothing,
@ -24,7 +143,15 @@ void MatchedFilterCore_AVX2(size_t x_start_index,
rtc::ArrayView<const float> y,
rtc::ArrayView<float> h,
bool* filters_updated,
float* error_sum) {
float* error_sum,
bool compute_accumulated_error,
rtc::ArrayView<float> accumulated_error,
rtc::ArrayView<float> scratch_memory) {
if (compute_accumulated_error) {
return MatchedFilterCore_AccumulatedError_AVX2(
x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
error_sum, accumulated_error, scratch_memory);
}
const int h_size = static_cast<int>(h.size());
const int x_size = static_cast<int>(x.size());
RTC_DCHECK_EQ(0, h_size % 8);
@ -81,15 +208,9 @@ void MatchedFilterCore_AVX2(size_t x_start_index,
// Sum components together.
x2_sum_256 = _mm256_add_ps(x2_sum_256, x2_sum_256_8);
s_256 = _mm256_add_ps(s_256, s_256_8);
__m128 x2_sum_128 = _mm_add_ps(_mm256_extractf128_ps(x2_sum_256, 0),
_mm256_extractf128_ps(x2_sum_256, 1));
__m128 s_128 = _mm_add_ps(_mm256_extractf128_ps(s_256, 0),
_mm256_extractf128_ps(s_256, 1));
// Combine the accumulated vector and scalar values.
float* v = reinterpret_cast<float*>(&x2_sum_128);
x2_sum += v[0] + v[1] + v[2] + v[3];
v = reinterpret_cast<float*>(&s_128);
s += v[0] + v[1] + v[2] + v[3];
__m128 sum = hsum_ab(x2_sum_256, s_256);
x2_sum += sum[0];
s += sum[1];
// Compute the matched filter error.
float e = y[i] - s;

View File

@ -14,84 +14,148 @@
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
namespace {
int GetDownSamplingBlockSizeLog2(int down_sampling_factor) {
int down_sampling_factor_log2 = 0;
down_sampling_factor >>= 1;
while (down_sampling_factor > 0) {
down_sampling_factor_log2++;
down_sampling_factor >>= 1;
}
return static_cast<int>(kBlockSizeLog2) > down_sampling_factor_log2
? static_cast<int>(kBlockSizeLog2) - down_sampling_factor_log2
: 0;
}
} // namespace
MatchedFilterLagAggregator::MatchedFilterLagAggregator(
ApmDataDumper* data_dumper,
size_t max_filter_lag,
const EchoCanceller3Config::Delay::DelaySelectionThresholds& thresholds)
const EchoCanceller3Config::Delay& delay_config)
: data_dumper_(data_dumper),
histogram_(max_filter_lag + 1, 0),
thresholds_(thresholds) {
thresholds_(delay_config.delay_selection_thresholds),
headroom_(static_cast<int>(delay_config.delay_headroom_samples /
delay_config.down_sampling_factor)),
highest_peak_aggregator_(max_filter_lag) {
if (delay_config.detect_pre_echo) {
pre_echo_lag_aggregator_ = std::make_unique<PreEchoLagAggregator>(
max_filter_lag, delay_config.down_sampling_factor);
}
RTC_DCHECK(data_dumper);
RTC_DCHECK_LE(thresholds_.initial, thresholds_.converged);
histogram_data_.fill(0);
}
MatchedFilterLagAggregator::~MatchedFilterLagAggregator() = default;
void MatchedFilterLagAggregator::Reset(bool hard_reset) {
std::fill(histogram_.begin(), histogram_.end(), 0);
histogram_data_.fill(0);
histogram_data_index_ = 0;
highest_peak_aggregator_.Reset();
if (pre_echo_lag_aggregator_ != nullptr) {
pre_echo_lag_aggregator_->Reset();
}
if (hard_reset) {
significant_candidate_found_ = false;
}
}
absl::optional<DelayEstimate> MatchedFilterLagAggregator::Aggregate(
rtc::ArrayView<const MatchedFilter::LagEstimate> lag_estimates) {
// Choose the strongest lag estimate as the best one.
float best_accuracy = 0.f;
int best_lag_estimate_index = -1;
for (size_t k = 0; k < lag_estimates.size(); ++k) {
if (lag_estimates[k].updated && lag_estimates[k].reliable) {
if (lag_estimates[k].accuracy > best_accuracy) {
best_accuracy = lag_estimates[k].accuracy;
best_lag_estimate_index = static_cast<int>(k);
}
}
const absl::optional<const MatchedFilter::LagEstimate>& lag_estimate) {
if (lag_estimate && pre_echo_lag_aggregator_) {
pre_echo_lag_aggregator_->Dump(data_dumper_);
pre_echo_lag_aggregator_->Aggregate(
std::max(0, static_cast<int>(lag_estimate->pre_echo_lag) - headroom_));
}
// TODO(peah): Remove this logging once all development is done.
data_dumper_->DumpRaw("aec3_echo_path_delay_estimator_best_index",
best_lag_estimate_index);
data_dumper_->DumpRaw("aec3_echo_path_delay_estimator_histogram", histogram_);
if (best_lag_estimate_index != -1) {
RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]);
RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]);
--histogram_[histogram_data_[histogram_data_index_]];
histogram_data_[histogram_data_index_] =
lag_estimates[best_lag_estimate_index].lag;
RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]);
RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]);
++histogram_[histogram_data_[histogram_data_index_]];
histogram_data_index_ =
(histogram_data_index_ + 1) % histogram_data_.size();
const int candidate =
std::distance(histogram_.begin(),
std::max_element(histogram_.begin(), histogram_.end()));
significant_candidate_found_ =
significant_candidate_found_ ||
histogram_[candidate] > thresholds_.converged;
if (histogram_[candidate] > thresholds_.converged ||
(histogram_[candidate] > thresholds_.initial &&
if (lag_estimate) {
highest_peak_aggregator_.Aggregate(
std::max(0, static_cast<int>(lag_estimate->lag) - headroom_));
rtc::ArrayView<const int> histogram = highest_peak_aggregator_.histogram();
int candidate = highest_peak_aggregator_.candidate();
significant_candidate_found_ = significant_candidate_found_ ||
histogram[candidate] > thresholds_.converged;
if (histogram[candidate] > thresholds_.converged ||
(histogram[candidate] > thresholds_.initial &&
!significant_candidate_found_)) {
DelayEstimate::Quality quality = significant_candidate_found_
? DelayEstimate::Quality::kRefined
: DelayEstimate::Quality::kCoarse;
return DelayEstimate(quality, candidate);
int reported_delay = pre_echo_lag_aggregator_ != nullptr
? pre_echo_lag_aggregator_->pre_echo_candidate()
: candidate;
return DelayEstimate(quality, reported_delay);
}
}
return absl::nullopt;
}
MatchedFilterLagAggregator::HighestPeakAggregator::HighestPeakAggregator(
size_t max_filter_lag)
: histogram_(max_filter_lag + 1, 0) {
histogram_data_.fill(0);
}
void MatchedFilterLagAggregator::HighestPeakAggregator::Reset() {
std::fill(histogram_.begin(), histogram_.end(), 0);
histogram_data_.fill(0);
histogram_data_index_ = 0;
}
void MatchedFilterLagAggregator::HighestPeakAggregator::Aggregate(int lag) {
RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]);
RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]);
--histogram_[histogram_data_[histogram_data_index_]];
histogram_data_[histogram_data_index_] = lag;
RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]);
RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]);
++histogram_[histogram_data_[histogram_data_index_]];
histogram_data_index_ = (histogram_data_index_ + 1) % histogram_data_.size();
candidate_ =
std::distance(histogram_.begin(),
std::max_element(histogram_.begin(), histogram_.end()));
}
MatchedFilterLagAggregator::PreEchoLagAggregator::PreEchoLagAggregator(
size_t max_filter_lag,
size_t down_sampling_factor)
: block_size_log2_(GetDownSamplingBlockSizeLog2(down_sampling_factor)),
histogram_(
((max_filter_lag + 1) * down_sampling_factor) >> kBlockSizeLog2,
0) {
Reset();
}
void MatchedFilterLagAggregator::PreEchoLagAggregator::Reset() {
std::fill(histogram_.begin(), histogram_.end(), 0);
histogram_data_.fill(0);
histogram_data_index_ = 0;
pre_echo_candidate_ = 0;
}
void MatchedFilterLagAggregator::PreEchoLagAggregator::Aggregate(
int pre_echo_lag) {
int pre_echo_block_size = pre_echo_lag >> block_size_log2_;
RTC_DCHECK(pre_echo_block_size >= 0 &&
pre_echo_block_size < static_cast<int>(histogram_.size()));
pre_echo_block_size =
rtc::SafeClamp(pre_echo_block_size, 0, histogram_.size() - 1);
if (histogram_[histogram_data_[histogram_data_index_]] > 0) {
--histogram_[histogram_data_[histogram_data_index_]];
}
histogram_data_[histogram_data_index_] = pre_echo_block_size;
++histogram_[histogram_data_[histogram_data_index_]];
histogram_data_index_ = (histogram_data_index_ + 1) % histogram_data_.size();
int pre_echo_candidate_block_size =
std::distance(histogram_.begin(),
std::max_element(histogram_.begin(), histogram_.end()));
pre_echo_candidate_ = (pre_echo_candidate_block_size << block_size_log2_);
}
void MatchedFilterLagAggregator::PreEchoLagAggregator::Dump(
ApmDataDumper* const data_dumper) {
data_dumper->DumpRaw("aec3_pre_echo_delay_candidate", pre_echo_candidate_);
}
} // namespace webrtc

View File

@ -26,10 +26,9 @@ class ApmDataDumper;
// reliable combined lag estimate.
class MatchedFilterLagAggregator {
public:
MatchedFilterLagAggregator(
ApmDataDumper* data_dumper,
MatchedFilterLagAggregator(ApmDataDumper* data_dumper,
size_t max_filter_lag,
const EchoCanceller3Config::Delay::DelaySelectionThresholds& thresholds);
const EchoCanceller3Config::Delay& delay_config);
MatchedFilterLagAggregator() = delete;
MatchedFilterLagAggregator(const MatchedFilterLagAggregator&) = delete;
@ -43,18 +42,55 @@ class MatchedFilterLagAggregator {
// Aggregates the provided lag estimates.
absl::optional<DelayEstimate> Aggregate(
rtc::ArrayView<const MatchedFilter::LagEstimate> lag_estimates);
const absl::optional<const MatchedFilter::LagEstimate>& lag_estimate);
// Returns whether a reliable delay estimate has been found.
bool ReliableDelayFound() const { return significant_candidate_found_; }
// Returns the delay candidate that is computed by looking at the highest peak
// on the matched filters.
int GetDelayAtHighestPeak() const {
return highest_peak_aggregator_.candidate();
}
private:
class PreEchoLagAggregator {
public:
PreEchoLagAggregator(size_t max_filter_lag, size_t down_sampling_factor);
void Reset();
void Aggregate(int pre_echo_lag);
int pre_echo_candidate() const { return pre_echo_candidate_; }
void Dump(ApmDataDumper* const data_dumper);
private:
const int block_size_log2_;
std::array<int, 250> histogram_data_;
std::vector<int> histogram_;
int histogram_data_index_ = 0;
int pre_echo_candidate_ = 0;
};
class HighestPeakAggregator {
public:
explicit HighestPeakAggregator(size_t max_filter_lag);
void Reset();
void Aggregate(int lag);
int candidate() const { return candidate_; }
rtc::ArrayView<const int> histogram() const { return histogram_; }
private:
ApmDataDumper* const data_dumper_;
std::vector<int> histogram_;
std::array<int, 250> histogram_data_;
int histogram_data_index_ = 0;
int candidate_ = -1;
};
ApmDataDumper* const data_dumper_;
bool significant_candidate_found_ = false;
const EchoCanceller3Config::Delay::DelaySelectionThresholds thresholds_;
const int headroom_;
HighestPeakAggregator highest_peak_aggregator_;
std::unique_ptr<PreEchoLagAggregator> pre_echo_lag_aggregator_;
};
} // namespace webrtc

View File

@ -27,69 +27,31 @@ constexpr size_t kNumLagsBeforeDetection = 26;
} // namespace
// Verifies that the most accurate lag estimate is chosen.
TEST(MatchedFilterLagAggregator, MostAccurateLagChosen) {
constexpr size_t kLag1 = 5;
constexpr size_t kLag2 = 10;
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
std::vector<MatchedFilter::LagEstimate> lag_estimates(2);
MatchedFilterLagAggregator aggregator(
&data_dumper, std::max(kLag1, kLag2),
config.delay.delay_selection_thresholds);
lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag1, true);
lag_estimates[1] = MatchedFilter::LagEstimate(0.5f, true, kLag2, true);
for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
aggregator.Aggregate(lag_estimates);
}
absl::optional<DelayEstimate> aggregated_lag =
aggregator.Aggregate(lag_estimates);
EXPECT_TRUE(aggregated_lag);
EXPECT_EQ(kLag1, aggregated_lag->delay);
lag_estimates[0] = MatchedFilter::LagEstimate(0.5f, true, kLag1, true);
lag_estimates[1] = MatchedFilter::LagEstimate(1.f, true, kLag2, true);
for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
aggregated_lag = aggregator.Aggregate(lag_estimates);
EXPECT_TRUE(aggregated_lag);
EXPECT_EQ(kLag1, aggregated_lag->delay);
}
aggregated_lag = aggregator.Aggregate(lag_estimates);
aggregated_lag = aggregator.Aggregate(lag_estimates);
EXPECT_TRUE(aggregated_lag);
EXPECT_EQ(kLag2, aggregated_lag->delay);
}
// Verifies that varying lag estimates causes lag estimates to not be deemed
// reliable.
TEST(MatchedFilterLagAggregator,
LagEstimateInvarianceRequiredForAggregatedLag) {
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
MatchedFilterLagAggregator aggregator(
&data_dumper, 100, config.delay.delay_selection_thresholds);
MatchedFilterLagAggregator aggregator(&data_dumper, /*max_filter_lag=*/100,
config.delay);
absl::optional<DelayEstimate> aggregated_lag;
for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, 10, true);
aggregated_lag = aggregator.Aggregate(lag_estimates);
aggregated_lag = aggregator.Aggregate(
MatchedFilter::LagEstimate(/*lag=*/10, /*pre_echo_lag=*/10));
}
EXPECT_TRUE(aggregated_lag);
for (size_t k = 0; k < kNumLagsBeforeDetection * 100; ++k) {
lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, k % 100, true);
aggregated_lag = aggregator.Aggregate(lag_estimates);
aggregated_lag = aggregator.Aggregate(
MatchedFilter::LagEstimate(/*lag=*/k % 100, /*pre_echo_lag=*/k % 100));
}
EXPECT_FALSE(aggregated_lag);
for (size_t k = 0; k < kNumLagsBeforeDetection * 100; ++k) {
lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, k % 100, true);
aggregated_lag = aggregator.Aggregate(lag_estimates);
aggregated_lag = aggregator.Aggregate(
MatchedFilter::LagEstimate(/*lag=*/k % 100, /*pre_echo_lag=*/k % 100));
EXPECT_FALSE(aggregated_lag);
}
}
@ -101,13 +63,11 @@ TEST(MatchedFilterLagAggregator,
constexpr size_t kLag = 5;
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
MatchedFilterLagAggregator aggregator(
&data_dumper, kLag, config.delay.delay_selection_thresholds);
MatchedFilterLagAggregator aggregator(&data_dumper, /*max_filter_lag=*/kLag,
config.delay);
for (size_t k = 0; k < kNumLagsBeforeDetection * 10; ++k) {
lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag, false);
absl::optional<DelayEstimate> aggregated_lag =
aggregator.Aggregate(lag_estimates);
absl::optional<DelayEstimate> aggregated_lag = aggregator.Aggregate(
MatchedFilter::LagEstimate(/*lag=*/kLag, /*pre_echo_lag=*/kLag));
EXPECT_FALSE(aggregated_lag);
EXPECT_EQ(kLag, aggregated_lag->delay);
}
@ -122,20 +82,19 @@ TEST(MatchedFilterLagAggregator, DISABLED_PersistentAggregatedLag) {
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
MatchedFilterLagAggregator aggregator(
&data_dumper, std::max(kLag1, kLag2),
config.delay.delay_selection_thresholds);
MatchedFilterLagAggregator aggregator(&data_dumper, std::max(kLag1, kLag2),
config.delay);
absl::optional<DelayEstimate> aggregated_lag;
for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag1, true);
aggregated_lag = aggregator.Aggregate(lag_estimates);
aggregated_lag = aggregator.Aggregate(
MatchedFilter::LagEstimate(/*lag=*/kLag1, /*pre_echo_lag=*/kLag1));
}
EXPECT_TRUE(aggregated_lag);
EXPECT_EQ(kLag1, aggregated_lag->delay);
for (size_t k = 0; k < kNumLagsBeforeDetection * 40; ++k) {
lag_estimates[0] = MatchedFilter::LagEstimate(1.f, false, kLag2, true);
aggregated_lag = aggregator.Aggregate(lag_estimates);
aggregated_lag = aggregator.Aggregate(
MatchedFilter::LagEstimate(/*lag=*/kLag2, /*pre_echo_lag=*/kLag2));
EXPECT_TRUE(aggregated_lag);
EXPECT_EQ(kLag1, aggregated_lag->delay);
}
@ -146,9 +105,7 @@ TEST(MatchedFilterLagAggregator, DISABLED_PersistentAggregatedLag) {
// Verifies the check for non-null data dumper.
TEST(MatchedFilterLagAggregatorDeathTest, NullDataDumper) {
EchoCanceller3Config config;
EXPECT_DEATH(MatchedFilterLagAggregator(
nullptr, 10, config.delay.delay_selection_thresholds),
"");
EXPECT_DEATH(MatchedFilterLagAggregator(nullptr, 10, config.delay), "");
}
#endif

View File

@ -47,12 +47,15 @@ constexpr size_t kAlignmentShiftSubBlocks = kWindowSizeSubBlocks * 3 / 4;
} // namespace
class MatchedFilterTest : public ::testing::TestWithParam<bool> {};
#if defined(WEBRTC_HAS_NEON)
// Verifies that the optimized methods for NEON are similar to their reference
// counterparts.
TEST(MatchedFilter, TestNeonOptimizations) {
TEST_P(MatchedFilterTest, TestNeonOptimizations) {
Random random_generator(42U);
constexpr float kSmoothing = 0.7f;
const bool kComputeAccumulatederror = GetParam();
for (auto down_sampling_factor : kDownSamplingFactors) {
const size_t sub_block_size = kBlockSize / down_sampling_factor;
@ -61,6 +64,10 @@ TEST(MatchedFilter, TestNeonOptimizations) {
std::vector<float> y(sub_block_size);
std::vector<float> h_NEON(512);
std::vector<float> h(512);
std::vector<float> accumulated_error(512);
std::vector<float> accumulated_error_NEON(512);
std::vector<float> scratch_memory(512);
int x_index = 0;
for (int k = 0; k < 1000; ++k) {
RandomizeSampleVector(&random_generator, y);
@ -71,10 +78,13 @@ TEST(MatchedFilter, TestNeonOptimizations) {
float error_sum_NEON = 0.f;
MatchedFilterCore_NEON(x_index, h.size() * 150.f * 150.f, kSmoothing, x,
y, h_NEON, &filters_updated_NEON, &error_sum_NEON);
y, h_NEON, &filters_updated_NEON, &error_sum_NEON,
kComputeAccumulatederror, accumulated_error_NEON,
scratch_memory);
MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y, h,
&filters_updated, &error_sum);
&filters_updated, &error_sum, kComputeAccumulatederror,
accumulated_error);
EXPECT_EQ(filters_updated, filters_updated_NEON);
EXPECT_NEAR(error_sum, error_sum_NEON, error_sum / 100000.f);
@ -83,6 +93,17 @@ TEST(MatchedFilter, TestNeonOptimizations) {
EXPECT_NEAR(h[j], h_NEON[j], 0.00001f);
}
if (kComputeAccumulatederror) {
for (size_t j = 0; j < accumulated_error.size(); ++j) {
float difference =
std::abs(accumulated_error[j] - accumulated_error_NEON[j]);
float relative_difference = accumulated_error[j] > 0
? difference / accumulated_error[j]
: difference;
EXPECT_NEAR(relative_difference, 0.0f, 0.02f);
}
}
x_index = (x_index + sub_block_size) % x.size();
}
}
@ -92,7 +113,8 @@ TEST(MatchedFilter, TestNeonOptimizations) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
// Verifies that the optimized methods for SSE2 are bitexact to their reference
// counterparts.
TEST(MatchedFilter, TestSse2Optimizations) {
TEST_P(MatchedFilterTest, TestSse2Optimizations) {
const bool kComputeAccumulatederror = GetParam();
bool use_sse2 = (GetCPUInfo(kSSE2) != 0);
if (use_sse2) {
Random random_generator(42U);
@ -104,6 +126,9 @@ TEST(MatchedFilter, TestSse2Optimizations) {
std::vector<float> y(sub_block_size);
std::vector<float> h_SSE2(512);
std::vector<float> h(512);
std::vector<float> accumulated_error(512 / 4);
std::vector<float> accumulated_error_SSE2(512 / 4);
std::vector<float> scratch_memory(512);
int x_index = 0;
for (int k = 0; k < 1000; ++k) {
RandomizeSampleVector(&random_generator, y);
@ -115,10 +140,12 @@ TEST(MatchedFilter, TestSse2Optimizations) {
MatchedFilterCore_SSE2(x_index, h.size() * 150.f * 150.f, kSmoothing, x,
y, h_SSE2, &filters_updated_SSE2,
&error_sum_SSE2);
&error_sum_SSE2, kComputeAccumulatederror,
accumulated_error_SSE2, scratch_memory);
MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y,
h, &filters_updated, &error_sum);
h, &filters_updated, &error_sum,
kComputeAccumulatederror, accumulated_error);
EXPECT_EQ(filters_updated, filters_updated_SSE2);
EXPECT_NEAR(error_sum, error_sum_SSE2, error_sum / 100000.f);
@ -127,14 +154,24 @@ TEST(MatchedFilter, TestSse2Optimizations) {
EXPECT_NEAR(h[j], h_SSE2[j], 0.00001f);
}
for (size_t j = 0; j < accumulated_error.size(); ++j) {
float difference =
std::abs(accumulated_error[j] - accumulated_error_SSE2[j]);
float relative_difference = accumulated_error[j] > 0
? difference / accumulated_error[j]
: difference;
EXPECT_NEAR(relative_difference, 0.0f, 0.00001f);
}
x_index = (x_index + sub_block_size) % x.size();
}
}
}
}
TEST(MatchedFilter, TestAvx2Optimizations) {
TEST_P(MatchedFilterTest, TestAvx2Optimizations) {
bool use_avx2 = (GetCPUInfo(kAVX2) != 0);
const bool kComputeAccumulatederror = GetParam();
if (use_avx2) {
Random random_generator(42U);
constexpr float kSmoothing = 0.7f;
@ -145,29 +182,36 @@ TEST(MatchedFilter, TestAvx2Optimizations) {
std::vector<float> y(sub_block_size);
std::vector<float> h_AVX2(512);
std::vector<float> h(512);
std::vector<float> accumulated_error(512 / 4);
std::vector<float> accumulated_error_AVX2(512 / 4);
std::vector<float> scratch_memory(512);
int x_index = 0;
for (int k = 0; k < 1000; ++k) {
RandomizeSampleVector(&random_generator, y);
bool filters_updated = false;
float error_sum = 0.f;
bool filters_updated_AVX2 = false;
float error_sum_AVX2 = 0.f;
MatchedFilterCore_AVX2(x_index, h.size() * 150.f * 150.f, kSmoothing, x,
y, h_AVX2, &filters_updated_AVX2,
&error_sum_AVX2);
&error_sum_AVX2, kComputeAccumulatederror,
accumulated_error_AVX2, scratch_memory);
MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y,
h, &filters_updated, &error_sum);
h, &filters_updated, &error_sum,
kComputeAccumulatederror, accumulated_error);
EXPECT_EQ(filters_updated, filters_updated_AVX2);
EXPECT_NEAR(error_sum, error_sum_AVX2, error_sum / 100000.f);
for (size_t j = 0; j < h.size(); ++j) {
EXPECT_NEAR(h[j], h_AVX2[j], 0.00001f);
}
for (size_t j = 0; j < accumulated_error.size(); j += 4) {
float difference =
std::abs(accumulated_error[j] - accumulated_error_AVX2[j]);
float relative_difference = accumulated_error[j] > 0
? difference / accumulated_error[j]
: difference;
EXPECT_NEAR(relative_difference, 0.0f, 0.00001f);
}
x_index = (x_index + sub_block_size) % x.size();
}
}
@ -199,9 +243,9 @@ TEST(MatchedFilter, MaxSquarePeakIndex) {
}
// Verifies that the matched filter produces proper lag estimates for
// artificially
// delayed signals.
TEST(MatchedFilter, LagEstimation) {
// artificially delayed signals.
TEST_P(MatchedFilterTest, LagEstimation) {
const bool kDetectPreEcho = GetParam();
Random random_generator(42U);
constexpr size_t kNumChannels = 1;
constexpr int kSampleRateHz = 48000;
@ -222,12 +266,12 @@ TEST(MatchedFilter, LagEstimation) {
Decimator capture_decimator(down_sampling_factor);
DelayBuffer<float> signal_delay_buffer(down_sampling_factor *
delay_samples);
MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
kWindowSizeSubBlocks, kNumMatchedFilters,
kAlignmentShiftSubBlocks, 150,
config.delay.delay_estimate_smoothing,
MatchedFilter filter(
&data_dumper, DetectOptimization(), sub_block_size,
kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks,
150, config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold);
config.delay.delay_candidate_detection_threshold, kDetectPreEcho);
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
@ -254,62 +298,97 @@ TEST(MatchedFilter, LagEstimation) {
downsampled_capture_data.data(), sub_block_size);
capture_decimator.Decimate(capture[0], downsampled_capture);
filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(),
downsampled_capture, false);
downsampled_capture, /*use_slow_smoothing=*/false);
}
// Obtain the lag estimates.
auto lag_estimates = filter.GetLagEstimates();
// Find which lag estimate should be the most accurate.
absl::optional<size_t> expected_most_accurate_lag_estimate;
size_t alignment_shift_sub_blocks = 0;
for (size_t k = 0; k < config.delay.num_filters; ++k) {
if ((alignment_shift_sub_blocks + 3 * kWindowSizeSubBlocks / 4) *
sub_block_size >
delay_samples) {
expected_most_accurate_lag_estimate = k > 0 ? k - 1 : 0;
break;
}
alignment_shift_sub_blocks += kAlignmentShiftSubBlocks;
}
ASSERT_TRUE(expected_most_accurate_lag_estimate);
// Verify that the expected most accurate lag estimate is the most
// accurate estimate.
for (size_t k = 0; k < kNumMatchedFilters; ++k) {
if (k != *expected_most_accurate_lag_estimate &&
k != (*expected_most_accurate_lag_estimate + 1)) {
EXPECT_TRUE(
lag_estimates[*expected_most_accurate_lag_estimate].accuracy >
lag_estimates[k].accuracy ||
!lag_estimates[k].reliable ||
!lag_estimates[*expected_most_accurate_lag_estimate].reliable);
}
}
// Verify that all lag estimates are updated as expected for signals
// containing strong noise.
for (auto& le : lag_estimates) {
EXPECT_TRUE(le.updated);
}
// Verify that the expected most accurate lag estimate is reliable.
EXPECT_TRUE(
lag_estimates[*expected_most_accurate_lag_estimate].reliable ||
lag_estimates[std::min(*expected_most_accurate_lag_estimate + 1,
lag_estimates.size() - 1)]
.reliable);
auto lag_estimate = filter.GetBestLagEstimate();
EXPECT_TRUE(lag_estimate.has_value());
// Verify that the expected most accurate lag estimate is correct.
if (lag_estimates[*expected_most_accurate_lag_estimate].reliable) {
EXPECT_TRUE(delay_samples ==
lag_estimates[*expected_most_accurate_lag_estimate].lag);
if (lag_estimate.has_value()) {
EXPECT_EQ(delay_samples, lag_estimate->lag);
EXPECT_EQ(delay_samples, lag_estimate->pre_echo_lag);
}
}
}
}
// Test the pre echo estimation.
TEST_P(MatchedFilterTest, PreEchoEstimation) {
const bool kDetectPreEcho = GetParam();
Random random_generator(42U);
constexpr size_t kNumChannels = 1;
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
for (auto down_sampling_factor : kDownSamplingFactors) {
const size_t sub_block_size = kBlockSize / down_sampling_factor;
Block render(kNumBands, kNumChannels);
std::vector<std::vector<float>> capture(
1, std::vector<float>(kBlockSize, 0.f));
std::vector<float> capture_with_pre_echo(kBlockSize, 0.f);
ApmDataDumper data_dumper(0);
// data_dumper.SetActivated(true);
size_t pre_echo_delay_samples = 20e-3 * 16000 / down_sampling_factor;
size_t echo_delay_samples = 50e-3 * 16000 / down_sampling_factor;
EchoCanceller3Config config;
config.delay.down_sampling_factor = down_sampling_factor;
config.delay.num_filters = kNumMatchedFilters;
Decimator capture_decimator(down_sampling_factor);
DelayBuffer<float> signal_echo_delay_buffer(down_sampling_factor *
echo_delay_samples);
DelayBuffer<float> signal_pre_echo_delay_buffer(down_sampling_factor *
pre_echo_delay_samples);
MatchedFilter filter(
&data_dumper, DetectOptimization(), sub_block_size,
kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150,
config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold, kDetectPreEcho);
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
// Analyze the correlation between render and capture.
for (size_t k = 0; k < (600 + echo_delay_samples / sub_block_size); ++k) {
for (size_t band = 0; band < kNumBands; ++band) {
for (size_t channel = 0; channel < kNumChannels; ++channel) {
RandomizeSampleVector(&random_generator, render.View(band, channel));
}
}
signal_echo_delay_buffer.Delay(render.View(0, 0), capture[0]);
signal_pre_echo_delay_buffer.Delay(render.View(0, 0),
capture_with_pre_echo);
for (size_t k = 0; k < capture[0].size(); ++k) {
constexpr float gain_pre_echo = 0.8f;
capture[0][k] += gain_pre_echo * capture_with_pre_echo[k];
}
render_delay_buffer->Insert(render);
if (k == 0) {
render_delay_buffer->Reset();
}
render_delay_buffer->PrepareCaptureProcessing();
std::array<float, kBlockSize> downsampled_capture_data;
rtc::ArrayView<float> downsampled_capture(downsampled_capture_data.data(),
sub_block_size);
capture_decimator.Decimate(capture[0], downsampled_capture);
filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(),
downsampled_capture, /*use_slow_smoothing=*/false);
}
// Obtain the lag estimates.
auto lag_estimate = filter.GetBestLagEstimate();
EXPECT_TRUE(lag_estimate.has_value());
// Verify that the expected most accurate lag estimate is correct.
if (lag_estimate.has_value()) {
EXPECT_EQ(echo_delay_samples, lag_estimate->lag);
if (kDetectPreEcho) {
// The pre echo delay is estimated in a subsampled domain and a larger
// error is allowed.
EXPECT_NEAR(pre_echo_delay_samples, lag_estimate->pre_echo_lag, 4);
} else {
EXPECT_TRUE(
delay_samples ==
lag_estimates[std::min(*expected_most_accurate_lag_estimate + 1,
lag_estimates.size() - 1)]
.lag);
// The pre echo delay fallback to the highest mached filter peak when
// its detection is disabled.
EXPECT_EQ(echo_delay_samples, lag_estimate->pre_echo_lag);
}
}
}
@ -317,7 +396,8 @@ TEST(MatchedFilter, LagEstimation) {
// Verifies that the matched filter does not produce reliable and accurate
// estimates for uncorrelated render and capture signals.
TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) {
TEST_P(MatchedFilterTest, LagNotReliableForUncorrelatedRenderAndCapture) {
const bool kDetectPreEcho = GetParam();
constexpr size_t kNumChannels = 1;
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
@ -335,12 +415,12 @@ TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) {
ApmDataDumper data_dumper(0);
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
kWindowSizeSubBlocks, kNumMatchedFilters,
kAlignmentShiftSubBlocks, 150,
MatchedFilter filter(
&data_dumper, DetectOptimization(), sub_block_size,
kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150,
config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold);
config.delay.delay_candidate_detection_threshold, kDetectPreEcho);
// Analyze the correlation between render and capture.
for (size_t k = 0; k < 100; ++k) {
@ -352,20 +432,17 @@ TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) {
false);
}
// Obtain the lag estimates.
auto lag_estimates = filter.GetLagEstimates();
EXPECT_EQ(kNumMatchedFilters, lag_estimates.size());
// Verify that no lag estimates are reliable.
for (auto& le : lag_estimates) {
EXPECT_FALSE(le.reliable);
}
// Obtain the best lag estimate and Verify that no lag estimates are
// reliable.
auto best_lag_estimates = filter.GetBestLagEstimate();
EXPECT_FALSE(best_lag_estimates.has_value());
}
}
// Verifies that the matched filter does not produce updated lag estimates for
// render signals of low level.
TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) {
TEST_P(MatchedFilterTest, LagNotUpdatedForLowLevelRender) {
const bool kDetectPreEcho = GetParam();
Random random_generator(42U);
constexpr size_t kNumChannels = 1;
constexpr int kSampleRateHz = 48000;
@ -374,19 +451,17 @@ TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) {
for (auto down_sampling_factor : kDownSamplingFactors) {
const size_t sub_block_size = kBlockSize / down_sampling_factor;
std::vector<std::vector<std::vector<float>>> render(
kNumBands, std::vector<std::vector<float>>(
kNumChannels, std::vector<float>(kBlockSize, 0.f)));
Block render(kNumBands, kNumChannels);
std::vector<std::vector<float>> capture(
1, std::vector<float>(kBlockSize, 0.f));
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
kWindowSizeSubBlocks, kNumMatchedFilters,
kAlignmentShiftSubBlocks, 150,
MatchedFilter filter(
&data_dumper, DetectOptimization(), sub_block_size,
kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150,
config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold);
config.delay.delay_candidate_detection_threshold, kDetectPreEcho);
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz,
kNumChannels));
@ -394,11 +469,11 @@ TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) {
// Analyze the correlation between render and capture.
for (size_t k = 0; k < 100; ++k) {
RandomizeSampleVector(&random_generator, render[0][0]);
for (auto& render_k : render[0][0]) {
RandomizeSampleVector(&random_generator, render.View(0, 0));
for (auto& render_k : render.View(0, 0)) {
render_k *= 149.f / 32767.f;
}
std::copy(render[0][0].begin(), render[0][0].end(), capture[0].begin());
std::copy(render.begin(0, 0), render.end(0, 0), capture[0].begin());
std::array<float, kBlockSize> downsampled_capture_data;
rtc::ArrayView<float> downsampled_capture(downsampled_capture_data.data(),
sub_block_size);
@ -407,86 +482,76 @@ TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) {
downsampled_capture, false);
}
// Obtain the lag estimates.
auto lag_estimates = filter.GetLagEstimates();
EXPECT_EQ(kNumMatchedFilters, lag_estimates.size());
// Verify that no lag estimates are updated and that no lag estimates are
// reliable.
for (auto& le : lag_estimates) {
EXPECT_FALSE(le.updated);
EXPECT_FALSE(le.reliable);
}
// Verify that no lag estimate has been produced.
auto lag_estimate = filter.GetBestLagEstimate();
EXPECT_FALSE(lag_estimate.has_value());
}
}
// Verifies that the correct number of lag estimates are produced for a certain
// number of alignment shifts.
TEST(MatchedFilter, NumberOfLagEstimates) {
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
for (auto down_sampling_factor : kDownSamplingFactors) {
const size_t sub_block_size = kBlockSize / down_sampling_factor;
for (size_t num_matched_filters = 0; num_matched_filters < 10;
++num_matched_filters) {
MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
32, num_matched_filters, 1, 150,
config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold);
EXPECT_EQ(num_matched_filters, filter.GetLagEstimates().size());
}
}
}
INSTANTIATE_TEST_SUITE_P(_, MatchedFilterTest, testing::Values(true, false));
#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)
class MatchedFilterDeathTest : public ::testing::TestWithParam<bool> {};
// Verifies the check for non-zero windows size.
TEST(MatchedFilterDeathTest, ZeroWindowSize) {
TEST_P(MatchedFilterDeathTest, ZeroWindowSize) {
const bool kDetectPreEcho = GetParam();
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 16, 0, 1, 1,
150, config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold),
config.delay.delay_candidate_detection_threshold,
kDetectPreEcho),
"");
}
// Verifies the check for non-null data dumper.
TEST(MatchedFilterDeathTest, NullDataDumper) {
TEST_P(MatchedFilterDeathTest, NullDataDumper) {
const bool kDetectPreEcho = GetParam();
EchoCanceller3Config config;
EXPECT_DEATH(MatchedFilter(nullptr, DetectOptimization(), 16, 1, 1, 1, 150,
config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold),
config.delay.delay_candidate_detection_threshold,
kDetectPreEcho),
"");
}
// Verifies the check for that the sub block size is a multiple of 4.
// TODO(peah): Activate the unittest once the required code has been landed.
TEST(MatchedFilterDeathTest, DISABLED_BlockSizeMultipleOf4) {
TEST_P(MatchedFilterDeathTest, DISABLED_BlockSizeMultipleOf4) {
const bool kDetectPreEcho = GetParam();
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 15, 1, 1, 1,
150, config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold),
config.delay.delay_candidate_detection_threshold,
kDetectPreEcho),
"");
}
// Verifies the check for that there is an integer number of sub blocks that add
// up to a block size.
// TODO(peah): Activate the unittest once the required code has been landed.
TEST(MatchedFilterDeathTest, DISABLED_SubBlockSizeAddsUpToBlockSize) {
TEST_P(MatchedFilterDeathTest, DISABLED_SubBlockSizeAddsUpToBlockSize) {
const bool kDetectPreEcho = GetParam();
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 12, 1, 1, 1,
150, config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold),
config.delay.delay_candidate_detection_threshold,
kDetectPreEcho),
"");
}
INSTANTIATE_TEST_SUITE_P(_,
MatchedFilterDeathTest,
testing::Values(true, false));
#endif
} // namespace aec3

View File

@ -53,7 +53,6 @@ class RenderDelayControllerImpl final : public RenderDelayController {
static std::atomic<int> instance_count_;
std::unique_ptr<ApmDataDumper> data_dumper_;
const int hysteresis_limit_blocks_;
const int delay_headroom_samples_;
absl::optional<DelayEstimate> delay_;
EchoPathDelayEstimator delay_estimator_;
RenderDelayControllerMetrics metrics_;
@ -66,15 +65,9 @@ class RenderDelayControllerImpl final : public RenderDelayController {
DelayEstimate ComputeBufferDelay(
const absl::optional<DelayEstimate>& current_delay,
int hysteresis_limit_blocks,
int delay_headroom_samples,
DelayEstimate estimated_delay) {
// Subtract delay headroom.
const int delay_with_headroom_samples = std::max(
static_cast<int>(estimated_delay.delay) - delay_headroom_samples, 0);
// Compute the buffer delay increase required to achieve the desired latency.
size_t new_delay_blocks = delay_with_headroom_samples >> kBlockSizeLog2;
size_t new_delay_blocks = estimated_delay.delay >> kBlockSizeLog2;
// Add hysteresis.
if (current_delay) {
size_t current_delay_blocks = current_delay->delay;
@ -83,7 +76,6 @@ DelayEstimate ComputeBufferDelay(
new_delay_blocks = current_delay_blocks;
}
}
DelayEstimate new_delay = estimated_delay;
new_delay.delay = new_delay_blocks;
return new_delay;
@ -98,7 +90,6 @@ RenderDelayControllerImpl::RenderDelayControllerImpl(
: data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)),
hysteresis_limit_blocks_(
static_cast<int>(config.delay.hysteresis_limit_blocks)),
delay_headroom_samples_(config.delay.delay_headroom_samples),
delay_estimator_(data_dumper_.get(), config, num_capture_channels),
last_delay_estimate_quality_(DelayEstimate::Quality::kCoarse) {
RTC_DCHECK(ValidFullBandRate(sample_rate_hz));
@ -158,9 +149,8 @@ absl::optional<DelayEstimate> RenderDelayControllerImpl::GetDelay(
const bool use_hysteresis =
last_delay_estimate_quality_ == DelayEstimate::Quality::kRefined &&
delay_samples_->quality == DelayEstimate::Quality::kRefined;
delay_ = ComputeBufferDelay(delay_,
use_hysteresis ? hysteresis_limit_blocks_ : 0,
delay_headroom_samples_, *delay_samples_);
delay_ = ComputeBufferDelay(
delay_, use_hysteresis ? hysteresis_limit_blocks_ : 0, *delay_samples_);
last_delay_estimate_quality_ = delay_samples_->quality;
}