delay estimator: Look for early reverberation
Look for the first echo (and not only the strongest one) on the same matched filter. This change is bit-exact with the previous version when `pre_echo` is false.

Author: Jesús de Vicente Peña <devicentepena@webrtc.org>
Bug: webrtc:14205
Change-Id: I6782eaa1d690b0df78d00f6d425a85c951b2ca9d
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/266321
Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
Commit-Queue: Lionel Koenig <lionelk@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#37360}

Commit 8783c678a5 (parent 7534ebd2bf), committed by WebRTC LUCI CQ.
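In short, each matched filter now also tracks a sub-sampled accumulated-error curve, and the reported delay can point at the earliest strong echo (the "pre-echo") instead of the strongest peak. A minimal, self-contained sketch of that selection rule follows; the function and constant names are illustrative, but the halving/0.5 thresholds and the sub-sample rate of 4 mirror ComputePreEchoLag and kAccumulatedErrorSubSampleRate in the diff below.

#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch only (not the WebRTC implementation): given the normalized
// accumulated matched-filter error, sub-sampled by a factor of 4, return the
// lag of the first sharp error drop; fall back to the peak lag if no drop is
// found before it.
constexpr size_t kSubSampleRate = 4;

size_t PreEchoLagSketch(const std::vector<float>& accumulated_error,
                        size_t peak_lag) {
  size_t pre_echo_lag = peak_lag;
  const size_t max_k =
      std::min(peak_lag / kSubSampleRate, accumulated_error.size());
  for (size_t k = 1; k < max_k; ++k) {
    // An early reflection shows up as the error halving between consecutive
    // sub-sampled taps while already being small in absolute terms.
    if (accumulated_error[k] < 0.5f * accumulated_error[k - 1] &&
        accumulated_error[k] < 0.5f) {
      pre_echo_lag = (k + 1) * kSubSampleRate - 1;
      break;
    }
  }
  return pre_echo_lag;
}

When `detect_pre_echo` is false none of this runs, which is why the change is bit-exact in that configuration.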
@@ -59,6 +59,7 @@ struct RTC_EXPORT EchoCanceller3Config {
    };
    AlignmentMixing render_alignment_mixing = {false, true, 10000.f, true};
    AlignmentMixing capture_alignment_mixing = {false, true, 10000.f, false};
    bool detect_pre_echo = false;
  } delay;

  struct Filter {
@@ -220,6 +220,7 @@ void Aec3ConfigFromJsonString(absl::string_view json_string,
              &cfg.delay.render_alignment_mixing);
    ReadParam(section, "capture_alignment_mixing",
              &cfg.delay.capture_alignment_mixing);
    ReadParam(section, "detect_pre_echo", &cfg.delay.detect_pre_echo);
  }

  if (rtc::GetValueFromJsonObject(aec3_root, "filter", &section)) {

@@ -505,7 +506,9 @@ std::string Aec3ConfigToJsonString(const EchoCanceller3Config& config) {
      << (config.delay.capture_alignment_mixing.prefer_first_two_channels
              ? "true"
              : "false");
  ost << "}";
  ost << "},";
  ost << "\"detect_pre_echo\": "
      << (config.delay.detect_pre_echo ? "true" : "false");
  ost << "},";

  ost << "\"filter\": {";
@@ -226,6 +226,7 @@ rtc_source_set("matched_filter") {
    "../../../api:array_view",
    "../../../rtc_base/system:arch",
  ]
  absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ]
}

rtc_source_set("vector_math") {
@@ -377,6 +377,14 @@ EchoCanceller3Config AdjustConfig(const EchoCanceller3Config& config) {
        false;
  }

  if (field_trial::IsEnabled("WebRTC-Aec3DelayEstimatorDetectPreEcho")) {
    adjusted_cfg.delay.detect_pre_echo = true;
  }

  if (field_trial::IsDisabled("WebRTC-Aec3DelayEstimatorDetectPreEcho")) {
    adjusted_cfg.delay.detect_pre_echo = false;
  }

  if (field_trial::IsEnabled("WebRTC-Aec3SensitiveDominantNearendActivation")) {
    adjusted_cfg.suppressor.dominant_nearend_detection.enr_threshold = 0.5f;
  } else if (field_trial::IsEnabled(
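For reference, a hedged sketch of how the override above could be exercised; `InitFieldTrialsFromString` and the trial-string format are the standard WebRTC field-trial mechanism, but the helper function itself is illustrative only.

#include "system_wrappers/include/field_trial.h"

// Illustrative helper: force pre-echo detection on before the echo canceller
// is created (AdjustConfig reads the trial at construction time). The string
// must stay alive for the lifetime of the process, since the field_trial API
// stores the pointer rather than copying it.
static const char kTrials[] =
    "WebRTC-Aec3DelayEstimatorDetectPreEcho/Enabled/";

void ForcePreEchoDetectionOn() {
  webrtc::field_trial::InitFieldTrialsFromString(kTrials);
}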
@@ -43,10 +43,11 @@ EchoPathDelayEstimator::EchoPathDelayEstimator(
              : config.render_levels.poor_excitation_render_limit,
          config.delay.delay_estimate_smoothing,
          config.delay.delay_estimate_smoothing_delay_found,
          config.delay.delay_candidate_detection_threshold),
          config.delay.delay_candidate_detection_threshold,
          config.delay.detect_pre_echo),
      matched_filter_lag_aggregator_(data_dumper_,
                                     matched_filter_.GetMaxFilterLag(),
                                     config.delay.delay_selection_thresholds) {
                                     config.delay) {
  RTC_DCHECK(data_dumper);
  RTC_DCHECK(down_sampling_factor_ > 0);
}

@@ -75,13 +76,14 @@ absl::optional<DelayEstimate> EchoPathDelayEstimator::EstimateDelay(

  absl::optional<DelayEstimate> aggregated_matched_filter_lag =
      matched_filter_lag_aggregator_.Aggregate(
          matched_filter_.GetLagEstimates());
          matched_filter_.GetBestLagEstimate());

  // Run clockdrift detection.
  if (aggregated_matched_filter_lag &&
      (*aggregated_matched_filter_lag).quality ==
          DelayEstimate::Quality::kRefined)
    clockdrift_detector_.Update((*aggregated_matched_filter_lag).delay);
    clockdrift_detector_.Update(
        matched_filter_lag_aggregator_.GetDelayAtHighestPeak());

  // TODO(peah): Move this logging outside of this class once EchoCanceller3
  // development is done.
@@ -78,6 +78,7 @@ TEST(EchoPathDelayEstimator, DelayEstimation) {
  constexpr size_t kDownSamplingFactors[] = {2, 4, 8};
  for (auto down_sampling_factor : kDownSamplingFactors) {
    EchoCanceller3Config config;
    config.delay.delay_headroom_samples = 0;
    config.delay.down_sampling_factor = down_sampling_factor;
    config.delay.num_filters = 10;
    for (size_t delay_samples : {30, 64, 150, 200, 800, 4000}) {

@@ -111,12 +112,13 @@ TEST(EchoPathDelayEstimator, DelayEstimation) {
    }

    if (estimated_delay_samples) {
      // Allow estimated delay to be off by one sample in the down-sampled
      // domain.
      // Allow estimated delay to be off by a block as internally the delay is
      // quantized with an error up to a block.
      size_t delay_ds = delay_samples / down_sampling_factor;
      size_t estimated_delay_ds =
          estimated_delay_samples->delay / down_sampling_factor;
      EXPECT_NEAR(delay_ds, estimated_delay_ds, 1);
      EXPECT_NEAR(delay_ds, estimated_delay_ds,
                  kBlockSize / down_sampling_factor);
    } else {
      ADD_FAILURE();
    }
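As a quick sanity check of the relaxed tolerance (a worked example; AEC3's usual kBlockSize of 64 samples is assumed here, it is not restated in this diff):

\[ \text{tolerance} = \frac{k_{\text{BlockSize}}}{\text{down\_sampling\_factor}} = \frac{64}{\{2,\,4,\,8\}} = \{32,\,16,\,8\}\ \text{down-sampled samples}, \]

i.e. always one full 64-sample block in the fullband domain, matching the new comment that the reported delay is quantized with an error of up to one block.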
@ -24,16 +24,147 @@
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/aec3/downsampled_render_buffer.h"
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/logging.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// Subsample rate used for computing the accumulated error.
|
||||
// The implementation of some core functions depends on this constant being
|
||||
// equal to 4.
|
||||
constexpr int kAccumulatedErrorSubSampleRate = 4;
|
||||
|
||||
void UpdateAccumulatedError(
|
||||
const rtc::ArrayView<const float> instantaneous_accumulated_error,
|
||||
const rtc::ArrayView<float> accumulated_error,
|
||||
float one_over_error_sum_anchor) {
|
||||
for (size_t k = 0; k < instantaneous_accumulated_error.size(); ++k) {
|
||||
float error_norm =
|
||||
instantaneous_accumulated_error[k] * one_over_error_sum_anchor;
|
||||
if (error_norm < accumulated_error[k]) {
|
||||
accumulated_error[k] = error_norm;
|
||||
} else {
|
||||
accumulated_error[k] += 0.01f * (error_norm - accumulated_error[k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t ComputePreEchoLag(const rtc::ArrayView<float> accumulated_error,
|
||||
size_t lag,
|
||||
size_t alignment_shift_winner) {
|
||||
size_t pre_echo_lag_estimate = lag - alignment_shift_winner;
|
||||
size_t maximum_pre_echo_lag =
|
||||
std::min(pre_echo_lag_estimate / kAccumulatedErrorSubSampleRate,
|
||||
accumulated_error.size());
|
||||
for (size_t k = 1; k < maximum_pre_echo_lag; ++k) {
|
||||
if (accumulated_error[k] < 0.5f * accumulated_error[k - 1] &&
|
||||
accumulated_error[k] < 0.5f) {
|
||||
pre_echo_lag_estimate = (k + 1) * kAccumulatedErrorSubSampleRate - 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return pre_echo_lag_estimate + alignment_shift_winner;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace webrtc {
|
||||
namespace aec3 {
|
||||
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
|
||||
inline float SumAllElements(float32x4_t elements) {
|
||||
float32x2_t sum = vpadd_f32(vget_low_f32(elements), vget_high_f32(elements));
|
||||
sum = vpadd_f32(sum, sum);
|
||||
return vget_lane_f32(sum, 0);
|
||||
}
|
||||
|
||||
void MatchedFilterCoreWithAccumulatedError_NEON(
|
||||
size_t x_start_index,
|
||||
float x2_sum_threshold,
|
||||
float smoothing,
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float> h,
|
||||
bool* filters_updated,
|
||||
float* error_sum,
|
||||
rtc::ArrayView<float> accumulated_error,
|
||||
rtc::ArrayView<float> scratch_memory) {
|
||||
const int h_size = static_cast<int>(h.size());
|
||||
const int x_size = static_cast<int>(x.size());
|
||||
RTC_DCHECK_EQ(0, h_size % 4);
|
||||
std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
|
||||
// Process for all samples in the sub-block.
|
||||
for (size_t i = 0; i < y.size(); ++i) {
|
||||
// Apply the matched filter as filter * x, and compute x * x.
|
||||
RTC_DCHECK_GT(x_size, x_start_index);
|
||||
// Compute loop chunk sizes until, and after, the wraparound of the circular
|
||||
// buffer for x.
|
||||
const int chunk1 =
|
||||
std::min(h_size, static_cast<int>(x_size - x_start_index));
|
||||
if (chunk1 != h_size) {
|
||||
const int chunk2 = h_size - chunk1;
|
||||
std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
|
||||
std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
|
||||
}
|
||||
const float* x_p =
|
||||
chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
|
||||
const float* h_p = &h[0];
|
||||
float* accumulated_error_p = &accumulated_error[0];
|
||||
// Initialize values for the accumulation.
|
||||
float32x4_t x2_sum_128 = vdupq_n_f32(0);
|
||||
float x2_sum = 0.f;
|
||||
float s = 0;
|
||||
// Perform 128 bit vector operations.
|
||||
const int limit_by_4 = h_size >> 2;
|
||||
for (int k = limit_by_4; k > 0;
|
||||
--k, h_p += 4, x_p += 4, accumulated_error_p++) {
|
||||
// Load the data into 128 bit vectors.
|
||||
const float32x4_t x_k = vld1q_f32(x_p);
|
||||
const float32x4_t h_k = vld1q_f32(h_p);
|
||||
// Compute and accumulate x * x.
|
||||
x2_sum_128 = vmlaq_f32(x2_sum_128, x_k, x_k);
|
||||
// Compute x * h
|
||||
float32x4_t hk_xk_128 = vmulq_f32(h_k, x_k);
|
||||
s += SumAllElements(hk_xk_128);
|
||||
const float e = s - y[i];
|
||||
accumulated_error_p[0] += e * e;
|
||||
}
|
||||
// Combine the accumulated vector and scalar values.
|
||||
x2_sum += SumAllElements(x2_sum_128);
|
||||
// Compute the matched filter error.
|
||||
float e = y[i] - s;
|
||||
const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
|
||||
(*error_sum) += e * e;
|
||||
// Update the matched filter estimate in an NLMS manner.
|
||||
if (x2_sum > x2_sum_threshold && !saturation) {
|
||||
RTC_DCHECK_LT(0.f, x2_sum);
|
||||
const float alpha = smoothing * e / x2_sum;
|
||||
const float32x4_t alpha_128 = vmovq_n_f32(alpha);
|
||||
// filter = filter + smoothing * (y - filter * x) * x / x * x.
|
||||
float* h_p = &h[0];
|
||||
x_p = chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
|
||||
// Perform 128 bit vector operations.
|
||||
const int limit_by_4 = h_size >> 2;
|
||||
for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
|
||||
// Load the data into 128 bit vectors.
|
||||
float32x4_t h_k = vld1q_f32(h_p);
|
||||
const float32x4_t x_k = vld1q_f32(x_p);
|
||||
// Compute h = h + alpha * x.
|
||||
h_k = vmlaq_f32(h_k, alpha_128, x_k);
|
||||
// Store the result.
|
||||
vst1q_f32(h_p, h_k);
|
||||
}
|
||||
*filters_updated = true;
|
||||
}
|
||||
x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
|
||||
}
|
||||
}
|
||||
|
||||
void MatchedFilterCore_NEON(size_t x_start_index,
|
||||
float x2_sum_threshold,
|
||||
float smoothing,
|
||||
@ -41,11 +172,20 @@ void MatchedFilterCore_NEON(size_t x_start_index,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float> h,
|
||||
bool* filters_updated,
|
||||
float* error_sum) {
|
||||
float* error_sum,
|
||||
bool compute_accumulated_error,
|
||||
rtc::ArrayView<float> accumulated_error,
|
||||
rtc::ArrayView<float> scratch_memory) {
|
||||
const int h_size = static_cast<int>(h.size());
|
||||
const int x_size = static_cast<int>(x.size());
|
||||
RTC_DCHECK_EQ(0, h_size % 4);
|
||||
|
||||
if (compute_accumulated_error) {
|
||||
return MatchedFilterCoreWithAccumulatedError_NEON(
|
||||
x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
|
||||
error_sum, accumulated_error, scratch_memory);
|
||||
}
|
||||
|
||||
// Process for all samples in the sub-block.
|
||||
for (size_t i = 0; i < y.size(); ++i) {
|
||||
// Apply the matched filter as filter * x, and compute x * x.
|
||||
@ -90,10 +230,8 @@ void MatchedFilterCore_NEON(size_t x_start_index,
|
||||
}
|
||||
|
||||
// Combine the accumulated vector and scalar values.
|
||||
float* v = reinterpret_cast<float*>(&x2_sum_128);
|
||||
x2_sum += v[0] + v[1] + v[2] + v[3];
|
||||
v = reinterpret_cast<float*>(&s_128);
|
||||
s += v[0] + v[1] + v[2] + v[3];
|
||||
s += SumAllElements(s_128);
|
||||
x2_sum += SumAllElements(x2_sum_128);
|
||||
|
||||
// Compute the matched filter error.
|
||||
float e = y[i] - s;
|
||||
@ -144,6 +282,103 @@ void MatchedFilterCore_NEON(size_t x_start_index,
|
||||
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
|
||||
void MatchedFilterCore_AccumulatedError_SSE2(
|
||||
size_t x_start_index,
|
||||
float x2_sum_threshold,
|
||||
float smoothing,
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float> h,
|
||||
bool* filters_updated,
|
||||
float* error_sum,
|
||||
rtc::ArrayView<float> accumulated_error,
|
||||
rtc::ArrayView<float> scratch_memory) {
|
||||
const int h_size = static_cast<int>(h.size());
|
||||
const int x_size = static_cast<int>(x.size());
|
||||
RTC_DCHECK_EQ(0, h_size % 8);
|
||||
std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
|
||||
// Process for all samples in the sub-block.
|
||||
for (size_t i = 0; i < y.size(); ++i) {
|
||||
// Apply the matched filter as filter * x, and compute x * x.
|
||||
RTC_DCHECK_GT(x_size, x_start_index);
|
||||
const int chunk1 =
|
||||
std::min(h_size, static_cast<int>(x_size - x_start_index));
|
||||
if (chunk1 != h_size) {
|
||||
const int chunk2 = h_size - chunk1;
|
||||
std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
|
||||
std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
|
||||
}
|
||||
const float* x_p =
|
||||
chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
|
||||
const float* h_p = &h[0];
|
||||
float* a_p = &accumulated_error[0];
|
||||
__m128 s_inst_128;
|
||||
__m128 s_inst_128_4;
|
||||
__m128 x2_sum_128 = _mm_set1_ps(0);
|
||||
__m128 x2_sum_128_4 = _mm_set1_ps(0);
|
||||
__m128 e_128;
|
||||
float* const s_p = reinterpret_cast<float*>(&s_inst_128);
|
||||
float* const s_4_p = reinterpret_cast<float*>(&s_inst_128_4);
|
||||
float* const e_p = reinterpret_cast<float*>(&e_128);
|
||||
float x2_sum = 0.0f;
|
||||
float s_acum = 0;
|
||||
// Perform 128 bit vector operations.
|
||||
const int limit_by_8 = h_size >> 3;
|
||||
for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8, a_p += 2) {
|
||||
// Load the data into 128 bit vectors.
|
||||
const __m128 x_k = _mm_loadu_ps(x_p);
|
||||
const __m128 h_k = _mm_loadu_ps(h_p);
|
||||
const __m128 x_k_4 = _mm_loadu_ps(x_p + 4);
|
||||
const __m128 h_k_4 = _mm_loadu_ps(h_p + 4);
|
||||
const __m128 xx = _mm_mul_ps(x_k, x_k);
|
||||
const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4);
|
||||
// Compute and accumulate x * x and h * x.
|
||||
x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
|
||||
x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4);
|
||||
s_inst_128 = _mm_mul_ps(h_k, x_k);
|
||||
s_inst_128_4 = _mm_mul_ps(h_k_4, x_k_4);
|
||||
s_acum += s_p[0] + s_p[1] + s_p[2] + s_p[3];
|
||||
e_p[0] = s_acum - y[i];
|
||||
s_acum += s_4_p[0] + s_4_p[1] + s_4_p[2] + s_4_p[3];
|
||||
e_p[1] = s_acum - y[i];
|
||||
a_p[0] += e_p[0] * e_p[0];
|
||||
a_p[1] += e_p[1] * e_p[1];
|
||||
}
|
||||
// Combine the accumulated vector and scalar values.
|
||||
x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4);
|
||||
float* v = reinterpret_cast<float*>(&x2_sum_128);
|
||||
x2_sum += v[0] + v[1] + v[2] + v[3];
|
||||
// Compute the matched filter error.
|
||||
float e = y[i] - s_acum;
|
||||
const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
|
||||
(*error_sum) += e * e;
|
||||
// Update the matched filter estimate in an NLMS manner.
|
||||
if (x2_sum > x2_sum_threshold && !saturation) {
|
||||
RTC_DCHECK_LT(0.f, x2_sum);
|
||||
const float alpha = smoothing * e / x2_sum;
|
||||
const __m128 alpha_128 = _mm_set1_ps(alpha);
|
||||
// filter = filter + smoothing * (y - filter * x) * x / x * x.
|
||||
float* h_p = &h[0];
|
||||
const float* x_p =
|
||||
chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
|
||||
// Perform 128 bit vector operations.
|
||||
const int limit_by_4 = h_size >> 2;
|
||||
for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
|
||||
// Load the data into 128 bit vectors.
|
||||
__m128 h_k = _mm_loadu_ps(h_p);
|
||||
const __m128 x_k = _mm_loadu_ps(x_p);
|
||||
// Compute h = h + alpha * x.
|
||||
const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
|
||||
h_k = _mm_add_ps(h_k, alpha_x);
|
||||
// Store the result.
|
||||
_mm_storeu_ps(h_p, h_k);
|
||||
}
|
||||
*filters_updated = true;
|
||||
}
|
||||
x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
|
||||
}
|
||||
}
|
||||
|
||||
void MatchedFilterCore_SSE2(size_t x_start_index,
|
||||
float x2_sum_threshold,
|
||||
float smoothing,
|
||||
@ -151,19 +386,24 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float> h,
|
||||
bool* filters_updated,
|
||||
float* error_sum) {
|
||||
float* error_sum,
|
||||
bool compute_accumulated_error,
|
||||
rtc::ArrayView<float> accumulated_error,
|
||||
rtc::ArrayView<float> scratch_memory) {
|
||||
if (compute_accumulated_error) {
|
||||
return MatchedFilterCore_AccumulatedError_SSE2(
|
||||
x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
|
||||
error_sum, accumulated_error, scratch_memory);
|
||||
}
|
||||
const int h_size = static_cast<int>(h.size());
|
||||
const int x_size = static_cast<int>(x.size());
|
||||
RTC_DCHECK_EQ(0, h_size % 4);
|
||||
|
||||
// Process for all samples in the sub-block.
|
||||
for (size_t i = 0; i < y.size(); ++i) {
|
||||
// Apply the matched filter as filter * x, and compute x * x.
|
||||
|
||||
RTC_DCHECK_GT(x_size, x_start_index);
|
||||
const float* x_p = &x[x_start_index];
|
||||
const float* h_p = &h[0];
|
||||
|
||||
// Initialize values for the accumulation.
|
||||
__m128 s_128 = _mm_set1_ps(0);
|
||||
__m128 s_128_4 = _mm_set1_ps(0);
|
||||
@ -171,12 +411,10 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
|
||||
__m128 x2_sum_128_4 = _mm_set1_ps(0);
|
||||
float x2_sum = 0.f;
|
||||
float s = 0;
|
||||
|
||||
// Compute loop chunk sizes until, and after, the wraparound of the circular
|
||||
// buffer for x.
|
||||
const int chunk1 =
|
||||
std::min(h_size, static_cast<int>(x_size - x_start_index));
|
||||
|
||||
// Perform the loop in two chunks.
|
||||
const int chunk2 = h_size - chunk1;
|
||||
for (int limit : {chunk1, chunk2}) {
|
||||
@ -198,17 +436,14 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
|
||||
s_128 = _mm_add_ps(s_128, hx);
|
||||
s_128_4 = _mm_add_ps(s_128_4, hx_4);
|
||||
}
|
||||
|
||||
// Perform non-vector operations for any remaining items.
|
||||
for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) {
|
||||
const float x_k = *x_p;
|
||||
x2_sum += x_k * x_k;
|
||||
s += *h_p * x_k;
|
||||
}
|
||||
|
||||
x_p = &x[0];
|
||||
}
|
||||
|
||||
// Combine the accumulated vector and scalar values.
|
||||
x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4);
|
||||
float* v = reinterpret_cast<float*>(&x2_sum_128);
|
||||
@ -216,22 +451,18 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
|
||||
s_128 = _mm_add_ps(s_128, s_128_4);
|
||||
v = reinterpret_cast<float*>(&s_128);
|
||||
s += v[0] + v[1] + v[2] + v[3];
|
||||
|
||||
// Compute the matched filter error.
|
||||
float e = y[i] - s;
|
||||
const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
|
||||
(*error_sum) += e * e;
|
||||
|
||||
// Update the matched filter estimate in an NLMS manner.
|
||||
if (x2_sum > x2_sum_threshold && !saturation) {
|
||||
RTC_DCHECK_LT(0.f, x2_sum);
|
||||
const float alpha = smoothing * e / x2_sum;
|
||||
const __m128 alpha_128 = _mm_set1_ps(alpha);
|
||||
|
||||
// filter = filter + smoothing * (y - filter * x) * x / x * x.
|
||||
float* h_p = &h[0];
|
||||
x_p = &x[x_start_index];
|
||||
|
||||
// Perform the loop in two chunks.
|
||||
for (int limit : {chunk1, chunk2}) {
|
||||
// Perform 128 bit vector operations.
|
||||
@ -244,22 +475,17 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
|
||||
// Compute h = h + alpha * x.
|
||||
const __m128 alpha_x = _mm_mul_ps(alpha_128, x_k);
|
||||
h_k = _mm_add_ps(h_k, alpha_x);
|
||||
|
||||
// Store the result.
|
||||
_mm_storeu_ps(h_p, h_k);
|
||||
}
|
||||
|
||||
// Perform non-vector operations for any remaining items.
|
||||
for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
|
||||
*h_p += alpha * *x_p;
|
||||
}
|
||||
|
||||
x_p = &x[0];
|
||||
}
|
||||
|
||||
*filters_updated = true;
|
||||
}
|
||||
|
||||
x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
|
||||
}
|
||||
}
|
||||
@ -272,17 +498,35 @@ void MatchedFilterCore(size_t x_start_index,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float> h,
|
||||
bool* filters_updated,
|
||||
float* error_sum) {
|
||||
float* error_sum,
|
||||
bool compute_accumulated_error,
|
||||
rtc::ArrayView<float> accumulated_error) {
|
||||
if (compute_accumulated_error) {
|
||||
std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
|
||||
}
|
||||
|
||||
// Process for all samples in the sub-block.
|
||||
for (size_t i = 0; i < y.size(); ++i) {
|
||||
// Apply the matched filter as filter * x, and compute x * x.
|
||||
float x2_sum = 0.f;
|
||||
float s = 0;
|
||||
size_t x_index = x_start_index;
|
||||
for (size_t k = 0; k < h.size(); ++k) {
|
||||
x2_sum += x[x_index] * x[x_index];
|
||||
s += h[k] * x[x_index];
|
||||
x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
|
||||
if (compute_accumulated_error) {
|
||||
for (size_t k = 0; k < h.size(); ++k) {
|
||||
x2_sum += x[x_index] * x[x_index];
|
||||
s += h[k] * x[x_index];
|
||||
x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
|
||||
if (((k + 1) & 0b11) == 0) {
|
||||
int idx = k >> 2;
|
||||
accumulated_error[idx] += (y[i] - s) * (y[i] - s);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t k = 0; k < h.size(); ++k) {
|
||||
x2_sum += x[x_index] * x[x_index];
|
||||
s += h[k] * x[x_index];
|
||||
x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute the matched filter error.
|
||||
@ -354,7 +598,8 @@ MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
|
||||
float excitation_limit,
|
||||
float smoothing_fast,
|
||||
float smoothing_slow,
|
||||
float matching_filter_threshold)
|
||||
float matching_filter_threshold,
|
||||
bool detect_pre_echo)
|
||||
: data_dumper_(data_dumper),
|
||||
optimization_(optimization),
|
||||
sub_block_size_(sub_block_size),
|
||||
@ -362,16 +607,31 @@ MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
|
||||
filters_(
|
||||
num_matched_filters,
|
||||
std::vector<float>(window_size_sub_blocks * sub_block_size_, 0.f)),
|
||||
lag_estimates_(num_matched_filters),
|
||||
filters_offsets_(num_matched_filters, 0),
|
||||
excitation_limit_(excitation_limit),
|
||||
smoothing_fast_(smoothing_fast),
|
||||
smoothing_slow_(smoothing_slow),
|
||||
matching_filter_threshold_(matching_filter_threshold) {
|
||||
matching_filter_threshold_(matching_filter_threshold),
|
||||
detect_pre_echo_(detect_pre_echo) {
|
||||
RTC_DCHECK(data_dumper);
|
||||
RTC_DCHECK_LT(0, window_size_sub_blocks);
|
||||
RTC_DCHECK((kBlockSize % sub_block_size) == 0);
|
||||
RTC_DCHECK((sub_block_size % 4) == 0);
|
||||
static_assert(kAccumulatedErrorSubSampleRate == 4);
|
||||
if (detect_pre_echo_) {
|
||||
accumulated_error_ = std::vector<std::vector<float>>(
|
||||
num_matched_filters,
|
||||
std::vector<float>(window_size_sub_blocks * sub_block_size_ /
|
||||
kAccumulatedErrorSubSampleRate,
|
||||
1.0f));
|
||||
|
||||
instantaneous_accumulated_error_ =
|
||||
std::vector<float>(window_size_sub_blocks * sub_block_size_ /
|
||||
kAccumulatedErrorSubSampleRate,
|
||||
0.0f);
|
||||
scratch_memory_ =
|
||||
std::vector<float>(window_size_sub_blocks * sub_block_size_);
|
||||
}
|
||||
}
|
||||
|
||||
MatchedFilter::~MatchedFilter() = default;
|
||||
@ -381,9 +641,12 @@ void MatchedFilter::Reset() {
|
||||
std::fill(f.begin(), f.end(), 0.f);
|
||||
}
|
||||
|
||||
for (auto& l : lag_estimates_) {
|
||||
l = MatchedFilter::LagEstimate();
|
||||
for (auto& e : accumulated_error_) {
|
||||
std::fill(e.begin(), e.end(), 1.0f);
|
||||
}
|
||||
|
||||
winner_lag_ = absl::nullopt;
|
||||
reported_lag_estimate_ = absl::nullopt;
|
||||
}
|
||||
|
||||
void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
|
||||
@ -398,11 +661,25 @@ void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
|
||||
const float x2_sum_threshold =
|
||||
filters_[0].size() * excitation_limit_ * excitation_limit_;
|
||||
|
||||
// Compute anchor for the matched filter error.
|
||||
float error_sum_anchor = 0.0f;
|
||||
for (size_t k = 0; k < y.size(); ++k) {
|
||||
error_sum_anchor += y[k] * y[k];
|
||||
}
|
||||
|
||||
// Apply all matched filters.
|
||||
float winner_error_sum = error_sum_anchor;
|
||||
winner_lag_ = absl::nullopt;
|
||||
reported_lag_estimate_ = absl::nullopt;
|
||||
size_t alignment_shift = 0;
|
||||
for (size_t n = 0; n < filters_.size(); ++n) {
|
||||
absl::optional<size_t> previous_lag_estimate;
|
||||
const int num_filters = static_cast<int>(filters_.size());
|
||||
int winner_index = -1;
|
||||
for (int n = 0; n < num_filters; ++n) {
|
||||
float error_sum = 0.f;
|
||||
bool filters_updated = false;
|
||||
const bool compute_pre_echo =
|
||||
detect_pre_echo_ && n == last_detected_best_lag_filter_;
|
||||
|
||||
size_t x_start_index =
|
||||
(render_buffer.read + alignment_shift + sub_block_size_ - 1) %
|
||||
@ -411,85 +688,79 @@ void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
|
||||
switch (optimization_) {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
case Aec3Optimization::kSse2:
|
||||
aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold, smoothing,
|
||||
render_buffer.buffer, y, filters_[n],
|
||||
&filters_updated, &error_sum);
|
||||
aec3::MatchedFilterCore_SSE2(
|
||||
x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
|
||||
filters_[n], &filters_updated, &error_sum, compute_pre_echo,
|
||||
instantaneous_accumulated_error_, scratch_memory_);
|
||||
break;
|
||||
case Aec3Optimization::kAvx2:
|
||||
aec3::MatchedFilterCore_AVX2(x_start_index, x2_sum_threshold, smoothing,
|
||||
render_buffer.buffer, y, filters_[n],
|
||||
&filters_updated, &error_sum);
|
||||
aec3::MatchedFilterCore_AVX2(
|
||||
x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
|
||||
filters_[n], &filters_updated, &error_sum, compute_pre_echo,
|
||||
instantaneous_accumulated_error_, scratch_memory_);
|
||||
break;
|
||||
#endif
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
case Aec3Optimization::kNeon:
|
||||
aec3::MatchedFilterCore_NEON(x_start_index, x2_sum_threshold, smoothing,
|
||||
render_buffer.buffer, y, filters_[n],
|
||||
&filters_updated, &error_sum);
|
||||
aec3::MatchedFilterCore_NEON(
|
||||
x_start_index, x2_sum_threshold, smoothing, render_buffer.buffer, y,
|
||||
filters_[n], &filters_updated, &error_sum, compute_pre_echo,
|
||||
instantaneous_accumulated_error_, scratch_memory_);
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, smoothing,
|
||||
render_buffer.buffer, y, filters_[n],
|
||||
&filters_updated, &error_sum);
|
||||
}
|
||||
|
||||
// Compute anchor for the matched filter error.
|
||||
float error_sum_anchor = 0.0f;
|
||||
for (size_t k = 0; k < y.size(); ++k) {
|
||||
error_sum_anchor += y[k] * y[k];
|
||||
&filters_updated, &error_sum, compute_pre_echo,
|
||||
instantaneous_accumulated_error_);
|
||||
}
|
||||
|
||||
// Estimate the lag in the matched filter as the distance to the portion in
|
||||
// the filter that contributes the most to the matched filter output. This
|
||||
// is detected as the peak of the matched filter.
|
||||
const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]);
|
||||
const bool reliable =
|
||||
lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) &&
|
||||
error_sum < matching_filter_threshold_ * error_sum_anchor;
|
||||
|
||||
// Update the lag estimates for the matched filter.
|
||||
lag_estimates_[n] = LagEstimate(
|
||||
error_sum_anchor - error_sum,
|
||||
(lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) &&
|
||||
error_sum < matching_filter_threshold_ * error_sum_anchor),
|
||||
lag_estimate + alignment_shift, filters_updated);
|
||||
|
||||
RTC_DCHECK_GE(10, filters_.size());
|
||||
switch (n) {
|
||||
case 0:
|
||||
data_dumper_->DumpRaw("aec3_correlator_0_h", filters_[0]);
|
||||
break;
|
||||
case 1:
|
||||
data_dumper_->DumpRaw("aec3_correlator_1_h", filters_[1]);
|
||||
break;
|
||||
case 2:
|
||||
data_dumper_->DumpRaw("aec3_correlator_2_h", filters_[2]);
|
||||
break;
|
||||
case 3:
|
||||
data_dumper_->DumpRaw("aec3_correlator_3_h", filters_[3]);
|
||||
break;
|
||||
case 4:
|
||||
data_dumper_->DumpRaw("aec3_correlator_4_h", filters_[4]);
|
||||
break;
|
||||
case 5:
|
||||
data_dumper_->DumpRaw("aec3_correlator_5_h", filters_[5]);
|
||||
break;
|
||||
case 6:
|
||||
data_dumper_->DumpRaw("aec3_correlator_6_h", filters_[6]);
|
||||
break;
|
||||
case 7:
|
||||
data_dumper_->DumpRaw("aec3_correlator_7_h", filters_[7]);
|
||||
break;
|
||||
case 8:
|
||||
data_dumper_->DumpRaw("aec3_correlator_8_h", filters_[8]);
|
||||
break;
|
||||
case 9:
|
||||
data_dumper_->DumpRaw("aec3_correlator_9_h", filters_[9]);
|
||||
break;
|
||||
default:
|
||||
RTC_DCHECK_NOTREACHED();
|
||||
// Find the best estimate
|
||||
const size_t lag = lag_estimate + alignment_shift;
|
||||
if (filters_updated && reliable && error_sum < winner_error_sum) {
|
||||
winner_error_sum = error_sum;
|
||||
winner_index = n;
|
||||
// In case that 2 matched filters return the same winner candidate
|
||||
// (overlap region), the one with the smaller index is chosen in order
|
||||
// to search for pre-echoes.
|
||||
if (previous_lag_estimate && previous_lag_estimate == lag) {
|
||||
winner_lag_ = previous_lag_estimate;
|
||||
winner_index = n - 1;
|
||||
} else {
|
||||
winner_lag_ = lag;
|
||||
}
|
||||
}
|
||||
|
||||
previous_lag_estimate = lag;
|
||||
alignment_shift += filter_intra_lag_shift_;
|
||||
}
|
||||
|
||||
if (winner_index != -1) {
|
||||
RTC_DCHECK(winner_lag_.has_value());
|
||||
reported_lag_estimate_ =
|
||||
LagEstimate(winner_lag_.value(), /*pre_echo_lag=*/winner_lag_.value());
|
||||
if (detect_pre_echo_ && last_detected_best_lag_filter_ == winner_index) {
|
||||
if (error_sum_anchor > 30.0f * 30.0f * y.size()) {
|
||||
UpdateAccumulatedError(instantaneous_accumulated_error_,
|
||||
accumulated_error_[winner_index],
|
||||
1.0f / error_sum_anchor);
|
||||
}
|
||||
reported_lag_estimate_->pre_echo_lag = ComputePreEchoLag(
|
||||
accumulated_error_[winner_index], winner_lag_.value(),
|
||||
winner_index * filter_intra_lag_shift_ /*alignment_shift_winner*/);
|
||||
}
|
||||
last_detected_best_lag_filter_ = winner_index;
|
||||
}
|
||||
if (ApmDataDumper::IsAvailable()) {
|
||||
Dump();
|
||||
}
|
||||
}
|
||||
|
||||
void MatchedFilter::LogFilterProperties(int sample_rate_hz,
|
||||
@ -510,4 +781,27 @@ void MatchedFilter::LogFilterProperties(int sample_rate_hz,
|
||||
}
|
||||
}
|
||||
|
||||
void MatchedFilter::Dump() {
|
||||
for (size_t n = 0; n < filters_.size(); ++n) {
|
||||
const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]);
|
||||
std::string dumper_filter = "aec3_correlator_" + std::to_string(n) + "_h";
|
||||
data_dumper_->DumpRaw(dumper_filter.c_str(), filters_[n]);
|
||||
std::string dumper_lag = "aec3_correlator_lag_" + std::to_string(n);
|
||||
data_dumper_->DumpRaw(dumper_lag.c_str(),
|
||||
lag_estimate + n * filter_intra_lag_shift_);
|
||||
if (detect_pre_echo_) {
|
||||
std::string dumper_error =
|
||||
"aec3_correlator_error_" + std::to_string(n) + "_h";
|
||||
data_dumper_->DumpRaw(dumper_error.c_str(), accumulated_error_[n]);
|
||||
|
||||
size_t pre_echo_lag = ComputePreEchoLag(
|
||||
accumulated_error_[n], lag_estimate + n * filter_intra_lag_shift_,
|
||||
n * filter_intra_lag_shift_);
|
||||
std::string dumper_pre_lag =
|
||||
"aec3_correlator_pre_echo_lag_" + std::to_string(n);
|
||||
data_dumper_->DumpRaw(dumper_pre_lag.c_str(), pre_echo_lag);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
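The comment `filter = filter + smoothing * (y - filter * x) * x / x * x`, repeated in every core above, is the usual NLMS step; spelled out with the smoothing factor written as \( \mu \):

\[ e = y[i] - \mathbf{h}^\top \mathbf{x}, \qquad \mathbf{h} \leftarrow \mathbf{h} + \mu \, \frac{e}{\mathbf{x}^\top \mathbf{x}} \, \mathbf{x}, \]

applied only when \( \mathbf{x}^\top \mathbf{x} \) exceeds `x2_sum_threshold` and `y[i]` is not saturated. The accumulated-error variants additionally record \( (s_k - y[i])^2 \) at every fourth tap, where \( s_k \) is the running partial sum of \( \mathbf{h}^\top \mathbf{x} \); that per-tap error curve is what ComputePreEchoLag later scans for an early drop.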
@ -15,6 +15,7 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/aec3/aec3_common.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
@ -36,7 +37,10 @@ void MatchedFilterCore_NEON(size_t x_start_index,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float> h,
|
||||
bool* filters_updated,
|
||||
float* error_sum);
|
||||
float* error_sum,
|
||||
bool compute_accumulation_error,
|
||||
rtc::ArrayView<float> accumulated_error,
|
||||
rtc::ArrayView<float> scratch_memory);
|
||||
|
||||
#endif
|
||||
|
||||
@ -50,7 +54,10 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float> h,
|
||||
bool* filters_updated,
|
||||
float* error_sum);
|
||||
float* error_sum,
|
||||
bool compute_accumulated_error,
|
||||
rtc::ArrayView<float> accumulated_error,
|
||||
rtc::ArrayView<float> scratch_memory);
|
||||
|
||||
// Filter core for the matched filter that is optimized for AVX2.
|
||||
void MatchedFilterCore_AVX2(size_t x_start_index,
|
||||
@ -60,7 +67,10 @@ void MatchedFilterCore_AVX2(size_t x_start_index,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float> h,
|
||||
bool* filters_updated,
|
||||
float* error_sum);
|
||||
float* error_sum,
|
||||
bool compute_accumulated_error,
|
||||
rtc::ArrayView<float> accumulated_error,
|
||||
rtc::ArrayView<float> scratch_memory);
|
||||
|
||||
#endif
|
||||
|
||||
@ -72,7 +82,9 @@ void MatchedFilterCore(size_t x_start_index,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float> h,
|
||||
bool* filters_updated,
|
||||
float* error_sum);
|
||||
float* error_sum,
|
||||
bool compute_accumulation_error,
|
||||
rtc::ArrayView<float> accumulated_error);
|
||||
|
||||
// Find largest peak of squared values in array.
|
||||
size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h);
|
||||
@ -87,13 +99,10 @@ class MatchedFilter {
|
||||
// shift.
|
||||
struct LagEstimate {
|
||||
LagEstimate() = default;
|
||||
LagEstimate(float accuracy, bool reliable, size_t lag, bool updated)
|
||||
: accuracy(accuracy), reliable(reliable), lag(lag), updated(updated) {}
|
||||
|
||||
float accuracy = 0.f;
|
||||
bool reliable = false;
|
||||
LagEstimate(size_t lag, size_t pre_echo_lag)
|
||||
: lag(lag), pre_echo_lag(pre_echo_lag) {}
|
||||
size_t lag = 0;
|
||||
bool updated = false;
|
||||
size_t pre_echo_lag = 0;
|
||||
};
|
||||
|
||||
MatchedFilter(ApmDataDumper* data_dumper,
|
||||
@ -105,7 +114,8 @@ class MatchedFilter {
|
||||
float excitation_limit,
|
||||
float smoothing_fast,
|
||||
float smoothing_slow,
|
||||
float matching_filter_threshold);
|
||||
float matching_filter_threshold,
|
||||
bool detect_pre_echo);
|
||||
|
||||
MatchedFilter() = delete;
|
||||
MatchedFilter(const MatchedFilter&) = delete;
|
||||
@ -122,8 +132,8 @@ class MatchedFilter {
|
||||
void Reset();
|
||||
|
||||
// Returns the current lag estimates.
|
||||
rtc::ArrayView<const MatchedFilter::LagEstimate> GetLagEstimates() const {
|
||||
return lag_estimates_;
|
||||
absl::optional<const MatchedFilter::LagEstimate> GetBestLagEstimate() const {
|
||||
return reported_lag_estimate_;
|
||||
}
|
||||
|
||||
// Returns the maximum filter lag.
|
||||
@ -137,17 +147,25 @@ class MatchedFilter {
|
||||
size_t downsampling_factor) const;
|
||||
|
||||
private:
|
||||
void Dump();
|
||||
|
||||
ApmDataDumper* const data_dumper_;
|
||||
const Aec3Optimization optimization_;
|
||||
const size_t sub_block_size_;
|
||||
const size_t filter_intra_lag_shift_;
|
||||
std::vector<std::vector<float>> filters_;
|
||||
std::vector<LagEstimate> lag_estimates_;
|
||||
std::vector<std::vector<float>> accumulated_error_;
|
||||
std::vector<float> instantaneous_accumulated_error_;
|
||||
std::vector<float> scratch_memory_;
|
||||
absl::optional<MatchedFilter::LagEstimate> reported_lag_estimate_;
|
||||
absl::optional<size_t> winner_lag_;
|
||||
int last_detected_best_lag_filter_ = -1;
|
||||
std::vector<size_t> filters_offsets_;
|
||||
const float excitation_limit_;
|
||||
const float smoothing_fast_;
|
||||
const float smoothing_slow_;
|
||||
const float matching_filter_threshold_;
|
||||
const bool detect_pre_echo_;
|
||||
};
|
||||
|
||||
} // namespace webrtc
|
||||
|
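The header change replaces the per-filter `GetLagEstimates()` list with a single aggregated `GetBestLagEstimate()`. A hedged sketch of the resulting call pattern, mirroring EchoPathDelayEstimator::EstimateDelay in this change (the AEC3 includes are assumed and the object names are illustrative):

// Sketch of the new consumption pattern; types come from the AEC3 headers
// touched by this diff.
void ConsumeBestLagSketch(webrtc::MatchedFilter& matched_filter,
                          webrtc::MatchedFilterLagAggregator& lag_aggregator,
                          webrtc::ClockdriftDetector& clockdrift_detector) {
  absl::optional<webrtc::DelayEstimate> delay =
      lag_aggregator.Aggregate(matched_filter.GetBestLagEstimate());
  if (delay.has_value()) {
    // With detect_pre_echo enabled, delay->delay already points at the
    // earliest strong echo rather than necessarily at the strongest peak.
  }
  // Clockdrift detection now explicitly uses the highest-peak lag.
  clockdrift_detector.Update(lag_aggregator.GetDelayAtHighestPeak());
}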
@ -8,15 +8,134 @@
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/aec3/matched_filter.h"
|
||||
|
||||
#include <immintrin.h>
|
||||
|
||||
#include "modules/audio_processing/aec3/matched_filter.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace aec3 {
|
||||
|
||||
// Let ha denote the horizontal sum of a, and hb the horizontal sum of b
|
||||
// returns [ha, hb, ha, hb]
|
||||
inline __m128 hsum_ab(__m256 a, __m256 b) {
|
||||
__m256 s_256 = _mm256_hadd_ps(a, b);
|
||||
const __m256i mask = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
|
||||
s_256 = _mm256_permutevar8x32_ps(s_256, mask);
|
||||
__m128 s = _mm_hadd_ps(_mm256_extractf128_ps(s_256, 0),
|
||||
_mm256_extractf128_ps(s_256, 1));
|
||||
s = _mm_hadd_ps(s, s);
|
||||
return s;
|
||||
}
|
||||
|
||||
void MatchedFilterCore_AccumulatedError_AVX2(
|
||||
size_t x_start_index,
|
||||
float x2_sum_threshold,
|
||||
float smoothing,
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float> h,
|
||||
bool* filters_updated,
|
||||
float* error_sum,
|
||||
rtc::ArrayView<float> accumulated_error,
|
||||
rtc::ArrayView<float> scratch_memory) {
|
||||
const int h_size = static_cast<int>(h.size());
|
||||
const int x_size = static_cast<int>(x.size());
|
||||
RTC_DCHECK_EQ(0, h_size % 16);
|
||||
std::fill(accumulated_error.begin(), accumulated_error.end(), 0.0f);
|
||||
|
||||
// Process for all samples in the sub-block.
|
||||
for (size_t i = 0; i < y.size(); ++i) {
|
||||
// Apply the matched filter as filter * x, and compute x * x.
|
||||
RTC_DCHECK_GT(x_size, x_start_index);
|
||||
const int chunk1 =
|
||||
std::min(h_size, static_cast<int>(x_size - x_start_index));
|
||||
if (chunk1 != h_size) {
|
||||
const int chunk2 = h_size - chunk1;
|
||||
std::copy(x.begin() + x_start_index, x.end(), scratch_memory.begin());
|
||||
std::copy(x.begin(), x.begin() + chunk2, scratch_memory.begin() + chunk1);
|
||||
}
|
||||
const float* x_p =
|
||||
chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
|
||||
const float* h_p = &h[0];
|
||||
float* a_p = &accumulated_error[0];
|
||||
__m256 s_inst_hadd_256;
|
||||
__m256 s_inst_256;
|
||||
__m256 s_inst_256_8;
|
||||
__m256 x2_sum_256 = _mm256_set1_ps(0);
|
||||
__m256 x2_sum_256_8 = _mm256_set1_ps(0);
|
||||
__m128 e_128;
|
||||
float x2_sum = 0.0f;
|
||||
float s_acum = 0;
|
||||
const int limit_by_16 = h_size >> 4;
|
||||
for (int k = limit_by_16; k > 0; --k, h_p += 16, x_p += 16, a_p += 4) {
|
||||
// Load the data into 256 bit vectors.
|
||||
__m256 x_k = _mm256_loadu_ps(x_p);
|
||||
__m256 h_k = _mm256_loadu_ps(h_p);
|
||||
__m256 x_k_8 = _mm256_loadu_ps(x_p + 8);
|
||||
__m256 h_k_8 = _mm256_loadu_ps(h_p + 8);
|
||||
// Compute and accumulate x * x and h * x.
|
||||
x2_sum_256 = _mm256_fmadd_ps(x_k, x_k, x2_sum_256);
|
||||
x2_sum_256_8 = _mm256_fmadd_ps(x_k_8, x_k_8, x2_sum_256_8);
|
||||
s_inst_256 = _mm256_mul_ps(h_k, x_k);
|
||||
s_inst_256_8 = _mm256_mul_ps(h_k_8, x_k_8);
|
||||
s_inst_hadd_256 = _mm256_hadd_ps(s_inst_256, s_inst_256_8);
|
||||
s_inst_hadd_256 = _mm256_hadd_ps(s_inst_hadd_256, s_inst_hadd_256);
|
||||
s_acum += s_inst_hadd_256[0];
|
||||
e_128[0] = s_acum - y[i];
|
||||
s_acum += s_inst_hadd_256[4];
|
||||
e_128[1] = s_acum - y[i];
|
||||
s_acum += s_inst_hadd_256[1];
|
||||
e_128[2] = s_acum - y[i];
|
||||
s_acum += s_inst_hadd_256[5];
|
||||
e_128[3] = s_acum - y[i];
|
||||
|
||||
__m128 accumulated_error = _mm_load_ps(a_p);
|
||||
accumulated_error = _mm_fmadd_ps(e_128, e_128, accumulated_error);
|
||||
_mm_storeu_ps(a_p, accumulated_error);
|
||||
}
|
||||
// Sum components together.
|
||||
x2_sum_256 = _mm256_add_ps(x2_sum_256, x2_sum_256_8);
|
||||
__m128 x2_sum_128 = _mm_add_ps(_mm256_extractf128_ps(x2_sum_256, 0),
|
||||
_mm256_extractf128_ps(x2_sum_256, 1));
|
||||
// Combine the accumulated vector and scalar values.
|
||||
float* v = reinterpret_cast<float*>(&x2_sum_128);
|
||||
x2_sum += v[0] + v[1] + v[2] + v[3];
|
||||
|
||||
// Compute the matched filter error.
|
||||
float e = y[i] - s_acum;
|
||||
const bool saturation = y[i] >= 32000.f || y[i] <= -32000.f;
|
||||
(*error_sum) += e * e;
|
||||
|
||||
// Update the matched filter estimate in an NLMS manner.
|
||||
if (x2_sum > x2_sum_threshold && !saturation) {
|
||||
RTC_DCHECK_LT(0.f, x2_sum);
|
||||
const float alpha = smoothing * e / x2_sum;
|
||||
const __m256 alpha_256 = _mm256_set1_ps(alpha);
|
||||
|
||||
// filter = filter + smoothing * (y - filter * x) * x / x * x.
|
||||
float* h_p = &h[0];
|
||||
const float* x_p =
|
||||
chunk1 != h_size ? scratch_memory.data() : &x[x_start_index];
|
||||
// Perform 256 bit vector operations.
|
||||
const int limit_by_8 = h_size >> 3;
|
||||
for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) {
|
||||
// Load the data into 256 bit vectors.
|
||||
__m256 h_k = _mm256_loadu_ps(h_p);
|
||||
__m256 x_k = _mm256_loadu_ps(x_p);
|
||||
// Compute h = h + alpha * x.
|
||||
h_k = _mm256_fmadd_ps(x_k, alpha_256, h_k);
|
||||
|
||||
// Store the result.
|
||||
_mm256_storeu_ps(h_p, h_k);
|
||||
}
|
||||
*filters_updated = true;
|
||||
}
|
||||
|
||||
x_start_index = x_start_index > 0 ? x_start_index - 1 : x_size - 1;
|
||||
}
|
||||
}
|
||||
|
||||
void MatchedFilterCore_AVX2(size_t x_start_index,
|
||||
float x2_sum_threshold,
|
||||
float smoothing,
|
||||
@ -24,7 +143,15 @@ void MatchedFilterCore_AVX2(size_t x_start_index,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float> h,
|
||||
bool* filters_updated,
|
||||
float* error_sum) {
|
||||
float* error_sum,
|
||||
bool compute_accumulated_error,
|
||||
rtc::ArrayView<float> accumulated_error,
|
||||
rtc::ArrayView<float> scratch_memory) {
|
||||
if (compute_accumulated_error) {
|
||||
return MatchedFilterCore_AccumulatedError_AVX2(
|
||||
x_start_index, x2_sum_threshold, smoothing, x, y, h, filters_updated,
|
||||
error_sum, accumulated_error, scratch_memory);
|
||||
}
|
||||
const int h_size = static_cast<int>(h.size());
|
||||
const int x_size = static_cast<int>(x.size());
|
||||
RTC_DCHECK_EQ(0, h_size % 8);
|
||||
@ -81,15 +208,9 @@ void MatchedFilterCore_AVX2(size_t x_start_index,
|
||||
// Sum components together.
|
||||
x2_sum_256 = _mm256_add_ps(x2_sum_256, x2_sum_256_8);
|
||||
s_256 = _mm256_add_ps(s_256, s_256_8);
|
||||
__m128 x2_sum_128 = _mm_add_ps(_mm256_extractf128_ps(x2_sum_256, 0),
|
||||
_mm256_extractf128_ps(x2_sum_256, 1));
|
||||
__m128 s_128 = _mm_add_ps(_mm256_extractf128_ps(s_256, 0),
|
||||
_mm256_extractf128_ps(s_256, 1));
|
||||
// Combine the accumulated vector and scalar values.
|
||||
float* v = reinterpret_cast<float*>(&x2_sum_128);
|
||||
x2_sum += v[0] + v[1] + v[2] + v[3];
|
||||
v = reinterpret_cast<float*>(&s_128);
|
||||
s += v[0] + v[1] + v[2] + v[3];
|
||||
__m128 sum = hsum_ab(x2_sum_256, s_256);
|
||||
x2_sum += sum[0];
|
||||
s += sum[1];
|
||||
|
||||
// Compute the matched filter error.
|
||||
float e = y[i] - s;
|
||||
|
@ -14,84 +14,148 @@
|
||||
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_minmax.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
int GetDownSamplingBlockSizeLog2(int down_sampling_factor) {
|
||||
int down_sampling_factor_log2 = 0;
|
||||
down_sampling_factor >>= 1;
|
||||
while (down_sampling_factor > 0) {
|
||||
down_sampling_factor_log2++;
|
||||
down_sampling_factor >>= 1;
|
||||
}
|
||||
return static_cast<int>(kBlockSizeLog2) > down_sampling_factor_log2
|
||||
? static_cast<int>(kBlockSizeLog2) - down_sampling_factor_log2
|
||||
: 0;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
MatchedFilterLagAggregator::MatchedFilterLagAggregator(
|
||||
ApmDataDumper* data_dumper,
|
||||
size_t max_filter_lag,
|
||||
const EchoCanceller3Config::Delay::DelaySelectionThresholds& thresholds)
|
||||
const EchoCanceller3Config::Delay& delay_config)
|
||||
: data_dumper_(data_dumper),
|
||||
histogram_(max_filter_lag + 1, 0),
|
||||
thresholds_(thresholds) {
|
||||
thresholds_(delay_config.delay_selection_thresholds),
|
||||
headroom_(static_cast<int>(delay_config.delay_headroom_samples /
|
||||
delay_config.down_sampling_factor)),
|
||||
highest_peak_aggregator_(max_filter_lag) {
|
||||
if (delay_config.detect_pre_echo) {
|
||||
pre_echo_lag_aggregator_ = std::make_unique<PreEchoLagAggregator>(
|
||||
max_filter_lag, delay_config.down_sampling_factor);
|
||||
}
|
||||
RTC_DCHECK(data_dumper);
|
||||
RTC_DCHECK_LE(thresholds_.initial, thresholds_.converged);
|
||||
histogram_data_.fill(0);
|
||||
}
|
||||
|
||||
MatchedFilterLagAggregator::~MatchedFilterLagAggregator() = default;
|
||||
|
||||
void MatchedFilterLagAggregator::Reset(bool hard_reset) {
|
||||
std::fill(histogram_.begin(), histogram_.end(), 0);
|
||||
histogram_data_.fill(0);
|
||||
histogram_data_index_ = 0;
|
||||
highest_peak_aggregator_.Reset();
|
||||
if (pre_echo_lag_aggregator_ != nullptr) {
|
||||
pre_echo_lag_aggregator_->Reset();
|
||||
}
|
||||
if (hard_reset) {
|
||||
significant_candidate_found_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
absl::optional<DelayEstimate> MatchedFilterLagAggregator::Aggregate(
|
||||
rtc::ArrayView<const MatchedFilter::LagEstimate> lag_estimates) {
|
||||
// Choose the strongest lag estimate as the best one.
|
||||
float best_accuracy = 0.f;
|
||||
int best_lag_estimate_index = -1;
|
||||
for (size_t k = 0; k < lag_estimates.size(); ++k) {
|
||||
if (lag_estimates[k].updated && lag_estimates[k].reliable) {
|
||||
if (lag_estimates[k].accuracy > best_accuracy) {
|
||||
best_accuracy = lag_estimates[k].accuracy;
|
||||
best_lag_estimate_index = static_cast<int>(k);
|
||||
}
|
||||
}
|
||||
const absl::optional<const MatchedFilter::LagEstimate>& lag_estimate) {
|
||||
if (lag_estimate && pre_echo_lag_aggregator_) {
|
||||
pre_echo_lag_aggregator_->Dump(data_dumper_);
|
||||
pre_echo_lag_aggregator_->Aggregate(
|
||||
std::max(0, static_cast<int>(lag_estimate->pre_echo_lag) - headroom_));
|
||||
}
|
||||
|
||||
// TODO(peah): Remove this logging once all development is done.
|
||||
data_dumper_->DumpRaw("aec3_echo_path_delay_estimator_best_index",
|
||||
best_lag_estimate_index);
|
||||
data_dumper_->DumpRaw("aec3_echo_path_delay_estimator_histogram", histogram_);
|
||||
|
||||
if (best_lag_estimate_index != -1) {
|
||||
RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]);
|
||||
RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]);
|
||||
--histogram_[histogram_data_[histogram_data_index_]];
|
||||
|
||||
histogram_data_[histogram_data_index_] =
|
||||
lag_estimates[best_lag_estimate_index].lag;
|
||||
|
||||
RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]);
|
||||
RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]);
|
||||
++histogram_[histogram_data_[histogram_data_index_]];
|
||||
|
||||
histogram_data_index_ =
|
||||
(histogram_data_index_ + 1) % histogram_data_.size();
|
||||
|
||||
const int candidate =
|
||||
std::distance(histogram_.begin(),
|
||||
std::max_element(histogram_.begin(), histogram_.end()));
|
||||
|
||||
significant_candidate_found_ =
|
||||
significant_candidate_found_ ||
|
||||
histogram_[candidate] > thresholds_.converged;
|
||||
if (histogram_[candidate] > thresholds_.converged ||
|
||||
(histogram_[candidate] > thresholds_.initial &&
|
||||
if (lag_estimate) {
|
||||
highest_peak_aggregator_.Aggregate(
|
||||
std::max(0, static_cast<int>(lag_estimate->lag) - headroom_));
|
||||
rtc::ArrayView<const int> histogram = highest_peak_aggregator_.histogram();
|
||||
int candidate = highest_peak_aggregator_.candidate();
|
||||
significant_candidate_found_ = significant_candidate_found_ ||
|
||||
histogram[candidate] > thresholds_.converged;
|
||||
if (histogram[candidate] > thresholds_.converged ||
|
||||
(histogram[candidate] > thresholds_.initial &&
|
||||
!significant_candidate_found_)) {
|
||||
DelayEstimate::Quality quality = significant_candidate_found_
|
||||
? DelayEstimate::Quality::kRefined
|
||||
: DelayEstimate::Quality::kCoarse;
|
||||
return DelayEstimate(quality, candidate);
|
||||
int reported_delay = pre_echo_lag_aggregator_ != nullptr
|
||||
? pre_echo_lag_aggregator_->pre_echo_candidate()
|
||||
: candidate;
|
||||
return DelayEstimate(quality, reported_delay);
|
||||
}
|
||||
}
|
||||
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
MatchedFilterLagAggregator::HighestPeakAggregator::HighestPeakAggregator(
|
||||
size_t max_filter_lag)
|
||||
: histogram_(max_filter_lag + 1, 0) {
|
||||
histogram_data_.fill(0);
|
||||
}
|
||||
|
||||
void MatchedFilterLagAggregator::HighestPeakAggregator::Reset() {
|
||||
std::fill(histogram_.begin(), histogram_.end(), 0);
|
||||
histogram_data_.fill(0);
|
||||
histogram_data_index_ = 0;
|
||||
}
|
||||
|
||||
void MatchedFilterLagAggregator::HighestPeakAggregator::Aggregate(int lag) {
|
||||
RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]);
|
||||
RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]);
|
||||
--histogram_[histogram_data_[histogram_data_index_]];
|
||||
histogram_data_[histogram_data_index_] = lag;
|
||||
RTC_DCHECK_GT(histogram_.size(), histogram_data_[histogram_data_index_]);
|
||||
RTC_DCHECK_LE(0, histogram_data_[histogram_data_index_]);
|
||||
++histogram_[histogram_data_[histogram_data_index_]];
|
||||
histogram_data_index_ = (histogram_data_index_ + 1) % histogram_data_.size();
|
||||
candidate_ =
|
||||
std::distance(histogram_.begin(),
|
||||
std::max_element(histogram_.begin(), histogram_.end()));
|
||||
}
|
||||
|
||||
MatchedFilterLagAggregator::PreEchoLagAggregator::PreEchoLagAggregator(
|
||||
size_t max_filter_lag,
|
||||
size_t down_sampling_factor)
|
||||
: block_size_log2_(GetDownSamplingBlockSizeLog2(down_sampling_factor)),
|
||||
histogram_(
|
||||
((max_filter_lag + 1) * down_sampling_factor) >> kBlockSizeLog2,
|
||||
0) {
|
||||
Reset();
|
||||
}
|
||||
|
||||
void MatchedFilterLagAggregator::PreEchoLagAggregator::Reset() {
|
||||
std::fill(histogram_.begin(), histogram_.end(), 0);
|
||||
histogram_data_.fill(0);
|
||||
histogram_data_index_ = 0;
|
||||
pre_echo_candidate_ = 0;
|
||||
}
|
||||
|
||||
void MatchedFilterLagAggregator::PreEchoLagAggregator::Aggregate(
|
||||
int pre_echo_lag) {
|
||||
int pre_echo_block_size = pre_echo_lag >> block_size_log2_;
|
||||
RTC_DCHECK(pre_echo_block_size >= 0 &&
|
||||
pre_echo_block_size < static_cast<int>(histogram_.size()));
|
||||
pre_echo_block_size =
|
||||
rtc::SafeClamp(pre_echo_block_size, 0, histogram_.size() - 1);
|
||||
if (histogram_[histogram_data_[histogram_data_index_]] > 0) {
|
||||
--histogram_[histogram_data_[histogram_data_index_]];
|
||||
}
|
||||
histogram_data_[histogram_data_index_] = pre_echo_block_size;
|
||||
++histogram_[histogram_data_[histogram_data_index_]];
|
||||
histogram_data_index_ = (histogram_data_index_ + 1) % histogram_data_.size();
|
||||
int pre_echo_candidate_block_size =
|
||||
std::distance(histogram_.begin(),
|
||||
std::max_element(histogram_.begin(), histogram_.end()));
|
||||
pre_echo_candidate_ = (pre_echo_candidate_block_size << block_size_log2_);
|
||||
}
|
||||
|
||||
void MatchedFilterLagAggregator::PreEchoLagAggregator::Dump(
|
||||
ApmDataDumper* const data_dumper) {
|
||||
data_dumper->DumpRaw("aec3_pre_echo_delay_candidate", pre_echo_candidate_);
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
@@ -26,10 +26,9 @@ class ApmDataDumper;
// reliable combined lag estimate.
class MatchedFilterLagAggregator {
 public:
  MatchedFilterLagAggregator(
      ApmDataDumper* data_dumper,
      size_t max_filter_lag,
      const EchoCanceller3Config::Delay::DelaySelectionThresholds& thresholds);
  MatchedFilterLagAggregator(ApmDataDumper* data_dumper,
                             size_t max_filter_lag,
                             const EchoCanceller3Config::Delay& delay_config);

  MatchedFilterLagAggregator() = delete;
  MatchedFilterLagAggregator(const MatchedFilterLagAggregator&) = delete;

@@ -43,18 +42,55 @@ class MatchedFilterLagAggregator {

  // Aggregates the provided lag estimates.
  absl::optional<DelayEstimate> Aggregate(
      rtc::ArrayView<const MatchedFilter::LagEstimate> lag_estimates);
      const absl::optional<const MatchedFilter::LagEstimate>& lag_estimate);

  // Returns whether a reliable delay estimate has been found.
  bool ReliableDelayFound() const { return significant_candidate_found_; }

  // Returns the delay candidate that is computed by looking at the highest peak
  // on the matched filters.
  int GetDelayAtHighestPeak() const {
    return highest_peak_aggregator_.candidate();
  }

 private:
  class PreEchoLagAggregator {
   public:
    PreEchoLagAggregator(size_t max_filter_lag, size_t down_sampling_factor);
    void Reset();
    void Aggregate(int pre_echo_lag);
    int pre_echo_candidate() const { return pre_echo_candidate_; }
    void Dump(ApmDataDumper* const data_dumper);

   private:
    const int block_size_log2_;
    std::array<int, 250> histogram_data_;
    std::vector<int> histogram_;
    int histogram_data_index_ = 0;
    int pre_echo_candidate_ = 0;
  };

  class HighestPeakAggregator {
   public:
    explicit HighestPeakAggregator(size_t max_filter_lag);
    void Reset();
    void Aggregate(int lag);
    int candidate() const { return candidate_; }
    rtc::ArrayView<const int> histogram() const { return histogram_; }

   private:
    std::vector<int> histogram_;
    std::array<int, 250> histogram_data_;
    int histogram_data_index_ = 0;
    int candidate_ = -1;
  };

  ApmDataDumper* const data_dumper_;
  std::vector<int> histogram_;
  std::array<int, 250> histogram_data_;
  int histogram_data_index_ = 0;
  bool significant_candidate_found_ = false;
  const EchoCanceller3Config::Delay::DelaySelectionThresholds thresholds_;
  const int headroom_;
  HighestPeakAggregator highest_peak_aggregator_;
  std::unique_ptr<PreEchoLagAggregator> pre_echo_lag_aggregator_;
};
} // namespace webrtc
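A worked example of the quantization PreEchoLagAggregator applies (assuming kBlockSize = 64, i.e. kBlockSizeLog2 = 6, which this diff does not restate): for down_sampling_factor = 8,

\[ \text{block\_size\_log2} = k_{\text{BlockSizeLog2}} - \log_2(\text{down\_sampling\_factor}) = 6 - 3 = 3, \]

so each histogram bin spans \( 2^3 = 8 \) down-sampled samples, i.e. exactly one 64-sample block, and the reported `pre_echo_candidate_` is the winning bin shifted back (`bin << block_size_log2_`) into the down-sampled-lag domain.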
@ -27,69 +27,31 @@ constexpr size_t kNumLagsBeforeDetection = 26;

} // namespace

// Verifies that the most accurate lag estimate is chosen.
TEST(MatchedFilterLagAggregator, MostAccurateLagChosen) {
constexpr size_t kLag1 = 5;
constexpr size_t kLag2 = 10;
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
std::vector<MatchedFilter::LagEstimate> lag_estimates(2);
MatchedFilterLagAggregator aggregator(
&data_dumper, std::max(kLag1, kLag2),
config.delay.delay_selection_thresholds);
lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag1, true);
lag_estimates[1] = MatchedFilter::LagEstimate(0.5f, true, kLag2, true);

for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
aggregator.Aggregate(lag_estimates);
}

absl::optional<DelayEstimate> aggregated_lag =
aggregator.Aggregate(lag_estimates);
EXPECT_TRUE(aggregated_lag);
EXPECT_EQ(kLag1, aggregated_lag->delay);

lag_estimates[0] = MatchedFilter::LagEstimate(0.5f, true, kLag1, true);
lag_estimates[1] = MatchedFilter::LagEstimate(1.f, true, kLag2, true);

for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
aggregated_lag = aggregator.Aggregate(lag_estimates);
EXPECT_TRUE(aggregated_lag);
EXPECT_EQ(kLag1, aggregated_lag->delay);
}

aggregated_lag = aggregator.Aggregate(lag_estimates);
aggregated_lag = aggregator.Aggregate(lag_estimates);
EXPECT_TRUE(aggregated_lag);
EXPECT_EQ(kLag2, aggregated_lag->delay);
}

// Verifies that varying lag estimates causes lag estimates to not be deemed
// reliable.
TEST(MatchedFilterLagAggregator,
LagEstimateInvarianceRequiredForAggregatedLag) {
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
MatchedFilterLagAggregator aggregator(
&data_dumper, 100, config.delay.delay_selection_thresholds);
MatchedFilterLagAggregator aggregator(&data_dumper, /*max_filter_lag=*/100,
config.delay);

absl::optional<DelayEstimate> aggregated_lag;
for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, 10, true);
aggregated_lag = aggregator.Aggregate(lag_estimates);
aggregated_lag = aggregator.Aggregate(
MatchedFilter::LagEstimate(/*lag=*/10, /*pre_echo_lag=*/10));
}
EXPECT_TRUE(aggregated_lag);

for (size_t k = 0; k < kNumLagsBeforeDetection * 100; ++k) {
lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, k % 100, true);
aggregated_lag = aggregator.Aggregate(lag_estimates);
aggregated_lag = aggregator.Aggregate(
MatchedFilter::LagEstimate(/*lag=*/k % 100, /*pre_echo_lag=*/k % 100));
}
EXPECT_FALSE(aggregated_lag);

for (size_t k = 0; k < kNumLagsBeforeDetection * 100; ++k) {
lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, k % 100, true);
aggregated_lag = aggregator.Aggregate(lag_estimates);
aggregated_lag = aggregator.Aggregate(
MatchedFilter::LagEstimate(/*lag=*/k % 100, /*pre_echo_lag=*/k % 100));
EXPECT_FALSE(aggregated_lag);
}
}
@ -101,13 +63,11 @@ TEST(MatchedFilterLagAggregator,
constexpr size_t kLag = 5;
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
MatchedFilterLagAggregator aggregator(
&data_dumper, kLag, config.delay.delay_selection_thresholds);
MatchedFilterLagAggregator aggregator(&data_dumper, /*max_filter_lag=*/kLag,
config.delay);
for (size_t k = 0; k < kNumLagsBeforeDetection * 10; ++k) {
lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag, false);
absl::optional<DelayEstimate> aggregated_lag =
aggregator.Aggregate(lag_estimates);
absl::optional<DelayEstimate> aggregated_lag = aggregator.Aggregate(
MatchedFilter::LagEstimate(/*lag=*/kLag, /*pre_echo_lag=*/kLag));
EXPECT_FALSE(aggregated_lag);
EXPECT_EQ(kLag, aggregated_lag->delay);
}
@ -122,20 +82,19 @@ TEST(MatchedFilterLagAggregator, DISABLED_PersistentAggregatedLag) {
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
std::vector<MatchedFilter::LagEstimate> lag_estimates(1);
MatchedFilterLagAggregator aggregator(
&data_dumper, std::max(kLag1, kLag2),
config.delay.delay_selection_thresholds);
MatchedFilterLagAggregator aggregator(&data_dumper, std::max(kLag1, kLag2),
config.delay);
absl::optional<DelayEstimate> aggregated_lag;
for (size_t k = 0; k < kNumLagsBeforeDetection; ++k) {
lag_estimates[0] = MatchedFilter::LagEstimate(1.f, true, kLag1, true);
aggregated_lag = aggregator.Aggregate(lag_estimates);
aggregated_lag = aggregator.Aggregate(
MatchedFilter::LagEstimate(/*lag=*/kLag1, /*pre_echo_lag=*/kLag1));
}
EXPECT_TRUE(aggregated_lag);
EXPECT_EQ(kLag1, aggregated_lag->delay);

for (size_t k = 0; k < kNumLagsBeforeDetection * 40; ++k) {
lag_estimates[0] = MatchedFilter::LagEstimate(1.f, false, kLag2, true);
aggregated_lag = aggregator.Aggregate(lag_estimates);
aggregated_lag = aggregator.Aggregate(
MatchedFilter::LagEstimate(/*lag=*/kLag2, /*pre_echo_lag=*/kLag2));
EXPECT_TRUE(aggregated_lag);
EXPECT_EQ(kLag1, aggregated_lag->delay);
}
@ -146,9 +105,7 @@ TEST(MatchedFilterLagAggregator, DISABLED_PersistentAggregatedLag) {
// Verifies the check for non-null data dumper.
TEST(MatchedFilterLagAggregatorDeathTest, NullDataDumper) {
EchoCanceller3Config config;
EXPECT_DEATH(MatchedFilterLagAggregator(
nullptr, 10, config.delay.delay_selection_thresholds),
"");
EXPECT_DEATH(MatchedFilterLagAggregator(nullptr, 10, config.delay), "");
}

#endif

@ -47,12 +47,15 @@ constexpr size_t kAlignmentShiftSubBlocks = kWindowSizeSubBlocks * 3 / 4;

} // namespace

class MatchedFilterTest : public ::testing::TestWithParam<bool> {};

#if defined(WEBRTC_HAS_NEON)
// Verifies that the optimized methods for NEON are similar to their reference
// counterparts.
TEST(MatchedFilter, TestNeonOptimizations) {
TEST_P(MatchedFilterTest, TestNeonOptimizations) {
Random random_generator(42U);
constexpr float kSmoothing = 0.7f;
const bool kComputeAccumulatederror = GetParam();
for (auto down_sampling_factor : kDownSamplingFactors) {
const size_t sub_block_size = kBlockSize / down_sampling_factor;

@ -61,6 +64,10 @@ TEST(MatchedFilter, TestNeonOptimizations) {
std::vector<float> y(sub_block_size);
std::vector<float> h_NEON(512);
std::vector<float> h(512);
std::vector<float> accumulated_error(512);
std::vector<float> accumulated_error_NEON(512);
std::vector<float> scratch_memory(512);

int x_index = 0;
for (int k = 0; k < 1000; ++k) {
RandomizeSampleVector(&random_generator, y);
@ -71,10 +78,13 @@ TEST(MatchedFilter, TestNeonOptimizations) {
float error_sum_NEON = 0.f;

MatchedFilterCore_NEON(x_index, h.size() * 150.f * 150.f, kSmoothing, x,
y, h_NEON, &filters_updated_NEON, &error_sum_NEON);
y, h_NEON, &filters_updated_NEON, &error_sum_NEON,
kComputeAccumulatederror, accumulated_error_NEON,
scratch_memory);

MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y, h,
&filters_updated, &error_sum);
&filters_updated, &error_sum, kComputeAccumulatederror,
accumulated_error);

EXPECT_EQ(filters_updated, filters_updated_NEON);
EXPECT_NEAR(error_sum, error_sum_NEON, error_sum / 100000.f);
@ -83,6 +93,17 @@ TEST(MatchedFilter, TestNeonOptimizations) {
EXPECT_NEAR(h[j], h_NEON[j], 0.00001f);
}

if (kComputeAccumulatederror) {
for (size_t j = 0; j < accumulated_error.size(); ++j) {
float difference =
std::abs(accumulated_error[j] - accumulated_error_NEON[j]);
float relative_difference = accumulated_error[j] > 0
? difference / accumulated_error[j]
: difference;
EXPECT_NEAR(relative_difference, 0.0f, 0.02f);
}
}

x_index = (x_index + sub_block_size) % x.size();
}
}
@ -92,7 +113,8 @@ TEST(MatchedFilter, TestNeonOptimizations) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
// Verifies that the optimized methods for SSE2 are bitexact to their reference
// counterparts.
TEST(MatchedFilter, TestSse2Optimizations) {
TEST_P(MatchedFilterTest, TestSse2Optimizations) {
const bool kComputeAccumulatederror = GetParam();
bool use_sse2 = (GetCPUInfo(kSSE2) != 0);
if (use_sse2) {
Random random_generator(42U);
@ -104,6 +126,9 @@ TEST(MatchedFilter, TestSse2Optimizations) {
std::vector<float> y(sub_block_size);
std::vector<float> h_SSE2(512);
std::vector<float> h(512);
std::vector<float> accumulated_error(512 / 4);
std::vector<float> accumulated_error_SSE2(512 / 4);
std::vector<float> scratch_memory(512);
int x_index = 0;
for (int k = 0; k < 1000; ++k) {
RandomizeSampleVector(&random_generator, y);
@ -115,10 +140,12 @@ TEST(MatchedFilter, TestSse2Optimizations) {

MatchedFilterCore_SSE2(x_index, h.size() * 150.f * 150.f, kSmoothing, x,
y, h_SSE2, &filters_updated_SSE2,
&error_sum_SSE2);
&error_sum_SSE2, kComputeAccumulatederror,
accumulated_error_SSE2, scratch_memory);

MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y,
h, &filters_updated, &error_sum);
h, &filters_updated, &error_sum,
kComputeAccumulatederror, accumulated_error);

EXPECT_EQ(filters_updated, filters_updated_SSE2);
EXPECT_NEAR(error_sum, error_sum_SSE2, error_sum / 100000.f);
@ -127,14 +154,24 @@ TEST(MatchedFilter, TestSse2Optimizations) {
EXPECT_NEAR(h[j], h_SSE2[j], 0.00001f);
}

for (size_t j = 0; j < accumulated_error.size(); ++j) {
float difference =
std::abs(accumulated_error[j] - accumulated_error_SSE2[j]);
float relative_difference = accumulated_error[j] > 0
? difference / accumulated_error[j]
: difference;
EXPECT_NEAR(relative_difference, 0.0f, 0.00001f);
}

x_index = (x_index + sub_block_size) % x.size();
}
}
}
}

TEST(MatchedFilter, TestAvx2Optimizations) {
TEST_P(MatchedFilterTest, TestAvx2Optimizations) {
bool use_avx2 = (GetCPUInfo(kAVX2) != 0);
const bool kComputeAccumulatederror = GetParam();
if (use_avx2) {
Random random_generator(42U);
constexpr float kSmoothing = 0.7f;
@ -145,29 +182,36 @@ TEST(MatchedFilter, TestAvx2Optimizations) {
std::vector<float> y(sub_block_size);
std::vector<float> h_AVX2(512);
std::vector<float> h(512);
std::vector<float> accumulated_error(512 / 4);
std::vector<float> accumulated_error_AVX2(512 / 4);
std::vector<float> scratch_memory(512);
int x_index = 0;
for (int k = 0; k < 1000; ++k) {
RandomizeSampleVector(&random_generator, y);

bool filters_updated = false;
float error_sum = 0.f;
bool filters_updated_AVX2 = false;
float error_sum_AVX2 = 0.f;

MatchedFilterCore_AVX2(x_index, h.size() * 150.f * 150.f, kSmoothing, x,
y, h_AVX2, &filters_updated_AVX2,
&error_sum_AVX2);

&error_sum_AVX2, kComputeAccumulatederror,
accumulated_error_AVX2, scratch_memory);
MatchedFilterCore(x_index, h.size() * 150.f * 150.f, kSmoothing, x, y,
h, &filters_updated, &error_sum);

h, &filters_updated, &error_sum,
kComputeAccumulatederror, accumulated_error);
EXPECT_EQ(filters_updated, filters_updated_AVX2);
EXPECT_NEAR(error_sum, error_sum_AVX2, error_sum / 100000.f);

for (size_t j = 0; j < h.size(); ++j) {
EXPECT_NEAR(h[j], h_AVX2[j], 0.00001f);
}

for (size_t j = 0; j < accumulated_error.size(); j += 4) {
float difference =
std::abs(accumulated_error[j] - accumulated_error_AVX2[j]);
float relative_difference = accumulated_error[j] > 0
? difference / accumulated_error[j]
: difference;
EXPECT_NEAR(relative_difference, 0.0f, 0.00001f);
}
x_index = (x_index + sub_block_size) % x.size();
}
}
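The NEON, SSE2 and AVX2 comparisons above all apply the same tolerance rule to the accumulated error: compare by relative difference, and fall back to the absolute difference when the reference value is zero. A small helper capturing that rule could look like the sketch below; the function name is illustrative and not part of the patch.

#include <cmath>

// Relative difference used by the SIMD comparisons above; falls back to the
// absolute difference when the reference is zero to avoid dividing by zero.
float RelativeDifference(float reference, float value) {
  const float difference = std::abs(reference - value);
  return reference > 0.f ? difference / reference : difference;
}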
@ -199,9 +243,9 @@ TEST(MatchedFilter, MaxSquarePeakIndex) {
}

// Verifies that the matched filter produces proper lag estimates for
// artificially
// delayed signals.
TEST(MatchedFilter, LagEstimation) {
// artificially delayed signals.
TEST_P(MatchedFilterTest, LagEstimation) {
const bool kDetectPreEcho = GetParam();
Random random_generator(42U);
constexpr size_t kNumChannels = 1;
constexpr int kSampleRateHz = 48000;
@ -222,12 +266,12 @@ TEST(MatchedFilter, LagEstimation) {
Decimator capture_decimator(down_sampling_factor);
DelayBuffer<float> signal_delay_buffer(down_sampling_factor *
delay_samples);
MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
kWindowSizeSubBlocks, kNumMatchedFilters,
kAlignmentShiftSubBlocks, 150,
config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold);
MatchedFilter filter(
&data_dumper, DetectOptimization(), sub_block_size,
kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks,
150, config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold, kDetectPreEcho);

std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
@ -254,62 +298,97 @@ TEST(MatchedFilter, LagEstimation) {
downsampled_capture_data.data(), sub_block_size);
capture_decimator.Decimate(capture[0], downsampled_capture);
filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(),
downsampled_capture, false);
downsampled_capture, /*use_slow_smoothing=*/false);
}

// Obtain the lag estimates.
auto lag_estimates = filter.GetLagEstimates();

// Find which lag estimate should be the most accurate.
absl::optional<size_t> expected_most_accurate_lag_estimate;
size_t alignment_shift_sub_blocks = 0;
for (size_t k = 0; k < config.delay.num_filters; ++k) {
if ((alignment_shift_sub_blocks + 3 * kWindowSizeSubBlocks / 4) *
sub_block_size >
delay_samples) {
expected_most_accurate_lag_estimate = k > 0 ? k - 1 : 0;
break;
}
alignment_shift_sub_blocks += kAlignmentShiftSubBlocks;
}
ASSERT_TRUE(expected_most_accurate_lag_estimate);

// Verify that the expected most accurate lag estimate is the most
// accurate estimate.
for (size_t k = 0; k < kNumMatchedFilters; ++k) {
if (k != *expected_most_accurate_lag_estimate &&
k != (*expected_most_accurate_lag_estimate + 1)) {
EXPECT_TRUE(
lag_estimates[*expected_most_accurate_lag_estimate].accuracy >
lag_estimates[k].accuracy ||
!lag_estimates[k].reliable ||
!lag_estimates[*expected_most_accurate_lag_estimate].reliable);
}
}

// Verify that all lag estimates are updated as expected for signals
// containing strong noise.
for (auto& le : lag_estimates) {
EXPECT_TRUE(le.updated);
}

// Verify that the expected most accurate lag estimate is reliable.
EXPECT_TRUE(
lag_estimates[*expected_most_accurate_lag_estimate].reliable ||
lag_estimates[std::min(*expected_most_accurate_lag_estimate + 1,
lag_estimates.size() - 1)]
.reliable);
auto lag_estimate = filter.GetBestLagEstimate();
EXPECT_TRUE(lag_estimate.has_value());

// Verify that the expected most accurate lag estimate is correct.
if (lag_estimates[*expected_most_accurate_lag_estimate].reliable) {
EXPECT_TRUE(delay_samples ==
lag_estimates[*expected_most_accurate_lag_estimate].lag);
if (lag_estimate.has_value()) {
EXPECT_EQ(delay_samples, lag_estimate->lag);
EXPECT_EQ(delay_samples, lag_estimate->pre_echo_lag);
}
}
}
}

// Test the pre echo estimation.
TEST_P(MatchedFilterTest, PreEchoEstimation) {
const bool kDetectPreEcho = GetParam();
Random random_generator(42U);
constexpr size_t kNumChannels = 1;
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);

for (auto down_sampling_factor : kDownSamplingFactors) {
const size_t sub_block_size = kBlockSize / down_sampling_factor;

Block render(kNumBands, kNumChannels);
std::vector<std::vector<float>> capture(
1, std::vector<float>(kBlockSize, 0.f));
std::vector<float> capture_with_pre_echo(kBlockSize, 0.f);
ApmDataDumper data_dumper(0);
// data_dumper.SetActivated(true);
size_t pre_echo_delay_samples = 20e-3 * 16000 / down_sampling_factor;
size_t echo_delay_samples = 50e-3 * 16000 / down_sampling_factor;
EchoCanceller3Config config;
config.delay.down_sampling_factor = down_sampling_factor;
config.delay.num_filters = kNumMatchedFilters;
Decimator capture_decimator(down_sampling_factor);
DelayBuffer<float> signal_echo_delay_buffer(down_sampling_factor *
echo_delay_samples);
DelayBuffer<float> signal_pre_echo_delay_buffer(down_sampling_factor *
pre_echo_delay_samples);
MatchedFilter filter(
&data_dumper, DetectOptimization(), sub_block_size,
kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150,
config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold, kDetectPreEcho);
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
// Analyze the correlation between render and capture.
for (size_t k = 0; k < (600 + echo_delay_samples / sub_block_size); ++k) {
for (size_t band = 0; band < kNumBands; ++band) {
for (size_t channel = 0; channel < kNumChannels; ++channel) {
RandomizeSampleVector(&random_generator, render.View(band, channel));
}
}
signal_echo_delay_buffer.Delay(render.View(0, 0), capture[0]);
signal_pre_echo_delay_buffer.Delay(render.View(0, 0),
capture_with_pre_echo);
for (size_t k = 0; k < capture[0].size(); ++k) {
constexpr float gain_pre_echo = 0.8f;
capture[0][k] += gain_pre_echo * capture_with_pre_echo[k];
}
render_delay_buffer->Insert(render);
if (k == 0) {
render_delay_buffer->Reset();
}
render_delay_buffer->PrepareCaptureProcessing();
std::array<float, kBlockSize> downsampled_capture_data;
rtc::ArrayView<float> downsampled_capture(downsampled_capture_data.data(),
sub_block_size);
capture_decimator.Decimate(capture[0], downsampled_capture);
filter.Update(render_delay_buffer->GetDownsampledRenderBuffer(),
downsampled_capture, /*use_slow_smoothing=*/false);
}
// Obtain the lag estimates.
auto lag_estimate = filter.GetBestLagEstimate();
EXPECT_TRUE(lag_estimate.has_value());
// Verify that the expected most accurate lag estimate is correct.
if (lag_estimate.has_value()) {
EXPECT_EQ(echo_delay_samples, lag_estimate->lag);
if (kDetectPreEcho) {
// The pre echo delay is estimated in a subsampled domain and a larger
// error is allowed.
EXPECT_NEAR(pre_echo_delay_samples, lag_estimate->pre_echo_lag, 4);
} else {
EXPECT_TRUE(
delay_samples ==
lag_estimates[std::min(*expected_most_accurate_lag_estimate + 1,
lag_estimates.size() - 1)]
.lag);
// The pre echo delay falls back to the highest matched filter peak when
// its detection is disabled.
EXPECT_EQ(echo_delay_samples, lag_estimate->pre_echo_lag);
}
}
}
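The capture signal that the PreEchoEstimation test above feeds to the filter is the render signal delayed by the echo delay plus a weaker copy delayed by the shorter pre-echo delay (gain 0.8 in the test). A minimal single-block sketch of that construction is shown below; the helper name and the flat std::vector interface are illustrative, whereas the test itself builds the signal block by block through DelayBuffer and Decimator.

#include <cstddef>
#include <vector>

// Builds "echo + weaker pre-echo": render delayed by `echo_delay` samples plus
// `pre_echo_gain` times render delayed by the shorter `pre_echo_delay`.
std::vector<float> MakeCaptureWithPreEcho(const std::vector<float>& render,
                                          size_t echo_delay,
                                          size_t pre_echo_delay,
                                          float pre_echo_gain) {
  std::vector<float> capture(render.size(), 0.f);
  for (size_t n = 0; n < render.size(); ++n) {
    float sample = 0.f;
    if (n >= echo_delay) {
      sample += render[n - echo_delay];  // Main echo path.
    }
    if (n >= pre_echo_delay) {
      sample += pre_echo_gain * render[n - pre_echo_delay];  // Early echo.
    }
    capture[n] = sample;
  }
  return capture;
}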
@ -317,7 +396,8 @@ TEST(MatchedFilter, LagEstimation) {

// Verifies that the matched filter does not produce reliable and accurate
// estimates for uncorrelated render and capture signals.
TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) {
TEST_P(MatchedFilterTest, LagNotReliableForUncorrelatedRenderAndCapture) {
const bool kDetectPreEcho = GetParam();
constexpr size_t kNumChannels = 1;
constexpr int kSampleRateHz = 48000;
constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz);
@ -335,12 +415,12 @@ TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) {
ApmDataDumper data_dumper(0);
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels));
MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
kWindowSizeSubBlocks, kNumMatchedFilters,
kAlignmentShiftSubBlocks, 150,
config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold);
MatchedFilter filter(
&data_dumper, DetectOptimization(), sub_block_size,
kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150,
config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold, kDetectPreEcho);

// Analyze the correlation between render and capture.
for (size_t k = 0; k < 100; ++k) {
@ -352,20 +432,17 @@ TEST(MatchedFilter, LagNotReliableForUncorrelatedRenderAndCapture) {
false);
}

// Obtain the lag estimates.
auto lag_estimates = filter.GetLagEstimates();
EXPECT_EQ(kNumMatchedFilters, lag_estimates.size());

// Verify that no lag estimates are reliable.
for (auto& le : lag_estimates) {
EXPECT_FALSE(le.reliable);
}
// Obtain the best lag estimate and verify that no lag estimates are
// reliable.
auto best_lag_estimates = filter.GetBestLagEstimate();
EXPECT_FALSE(best_lag_estimates.has_value());
}
}

// Verifies that the matched filter does not produce updated lag estimates for
// render signals of low level.
TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) {
TEST_P(MatchedFilterTest, LagNotUpdatedForLowLevelRender) {
const bool kDetectPreEcho = GetParam();
Random random_generator(42U);
constexpr size_t kNumChannels = 1;
constexpr int kSampleRateHz = 48000;
@ -374,19 +451,17 @@ TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) {
for (auto down_sampling_factor : kDownSamplingFactors) {
const size_t sub_block_size = kBlockSize / down_sampling_factor;

std::vector<std::vector<std::vector<float>>> render(
kNumBands, std::vector<std::vector<float>>(
kNumChannels, std::vector<float>(kBlockSize, 0.f)));
Block render(kNumBands, kNumChannels);
std::vector<std::vector<float>> capture(
1, std::vector<float>(kBlockSize, 0.f));
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
kWindowSizeSubBlocks, kNumMatchedFilters,
kAlignmentShiftSubBlocks, 150,
config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold);
MatchedFilter filter(
&data_dumper, DetectOptimization(), sub_block_size,
kWindowSizeSubBlocks, kNumMatchedFilters, kAlignmentShiftSubBlocks, 150,
config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold, kDetectPreEcho);
std::unique_ptr<RenderDelayBuffer> render_delay_buffer(
RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz,
kNumChannels));
@ -394,11 +469,11 @@ TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) {

// Analyze the correlation between render and capture.
for (size_t k = 0; k < 100; ++k) {
RandomizeSampleVector(&random_generator, render[0][0]);
for (auto& render_k : render[0][0]) {
RandomizeSampleVector(&random_generator, render.View(0, 0));
for (auto& render_k : render.View(0, 0)) {
render_k *= 149.f / 32767.f;
}
std::copy(render[0][0].begin(), render[0][0].end(), capture[0].begin());
std::copy(render.begin(0, 0), render.end(0, 0), capture[0].begin());
std::array<float, kBlockSize> downsampled_capture_data;
rtc::ArrayView<float> downsampled_capture(downsampled_capture_data.data(),
sub_block_size);
@ -407,86 +482,76 @@ TEST(MatchedFilter, LagNotUpdatedForLowLevelRender) {
downsampled_capture, false);
}

// Obtain the lag estimates.
auto lag_estimates = filter.GetLagEstimates();
EXPECT_EQ(kNumMatchedFilters, lag_estimates.size());

// Verify that no lag estimates are updated and that no lag estimates are
// reliable.
for (auto& le : lag_estimates) {
EXPECT_FALSE(le.updated);
EXPECT_FALSE(le.reliable);
}
// Verify that no lag estimate has been produced.
auto lag_estimate = filter.GetBestLagEstimate();
EXPECT_FALSE(lag_estimate.has_value());
}
}

// Verifies that the correct number of lag estimates are produced for a certain
// number of alignment shifts.
TEST(MatchedFilter, NumberOfLagEstimates) {
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
for (auto down_sampling_factor : kDownSamplingFactors) {
const size_t sub_block_size = kBlockSize / down_sampling_factor;
for (size_t num_matched_filters = 0; num_matched_filters < 10;
++num_matched_filters) {
MatchedFilter filter(&data_dumper, DetectOptimization(), sub_block_size,
32, num_matched_filters, 1, 150,
config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold);
EXPECT_EQ(num_matched_filters, filter.GetLagEstimates().size());
}
}
}
INSTANTIATE_TEST_SUITE_P(_, MatchedFilterTest, testing::Values(true, false));

#if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID)

class MatchedFilterDeathTest : public ::testing::TestWithParam<bool> {};

// Verifies the check for non-zero window size.
TEST(MatchedFilterDeathTest, ZeroWindowSize) {
TEST_P(MatchedFilterDeathTest, ZeroWindowSize) {
const bool kDetectPreEcho = GetParam();
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 16, 0, 1, 1,
150, config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold),
config.delay.delay_candidate_detection_threshold,
kDetectPreEcho),
"");
}

// Verifies the check for non-null data dumper.
TEST(MatchedFilterDeathTest, NullDataDumper) {
TEST_P(MatchedFilterDeathTest, NullDataDumper) {
const bool kDetectPreEcho = GetParam();
EchoCanceller3Config config;
EXPECT_DEATH(MatchedFilter(nullptr, DetectOptimization(), 16, 1, 1, 1, 150,
config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold),
config.delay.delay_candidate_detection_threshold,
kDetectPreEcho),
"");
}

// Verifies the check that the sub block size is a multiple of 4.
// TODO(peah): Activate the unittest once the required code has been landed.
TEST(MatchedFilterDeathTest, DISABLED_BlockSizeMultipleOf4) {
TEST_P(MatchedFilterDeathTest, DISABLED_BlockSizeMultipleOf4) {
const bool kDetectPreEcho = GetParam();
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 15, 1, 1, 1,
150, config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold),
config.delay.delay_candidate_detection_threshold,
kDetectPreEcho),
"");
}

// Verifies the check that there is an integer number of sub blocks that add
// up to a block size.
// TODO(peah): Activate the unittest once the required code has been landed.
TEST(MatchedFilterDeathTest, DISABLED_SubBlockSizeAddsUpToBlockSize) {
TEST_P(MatchedFilterDeathTest, DISABLED_SubBlockSizeAddsUpToBlockSize) {
const bool kDetectPreEcho = GetParam();
ApmDataDumper data_dumper(0);
EchoCanceller3Config config;
EXPECT_DEATH(MatchedFilter(&data_dumper, DetectOptimization(), 12, 1, 1, 1,
150, config.delay.delay_estimate_smoothing,
config.delay.delay_estimate_smoothing_delay_found,
config.delay.delay_candidate_detection_threshold),
config.delay.delay_candidate_detection_threshold,
kDetectPreEcho),
"");
}

INSTANTIATE_TEST_SUITE_P(_,
MatchedFilterDeathTest,
testing::Values(true, false));

#endif

} // namespace aec3

@ -53,7 +53,6 @@ class RenderDelayControllerImpl final : public RenderDelayController {
static std::atomic<int> instance_count_;
std::unique_ptr<ApmDataDumper> data_dumper_;
const int hysteresis_limit_blocks_;
const int delay_headroom_samples_;
absl::optional<DelayEstimate> delay_;
EchoPathDelayEstimator delay_estimator_;
RenderDelayControllerMetrics metrics_;
@ -66,15 +65,9 @@ class RenderDelayControllerImpl final : public RenderDelayController {
DelayEstimate ComputeBufferDelay(
const absl::optional<DelayEstimate>& current_delay,
int hysteresis_limit_blocks,
int delay_headroom_samples,
DelayEstimate estimated_delay) {
// Subtract delay headroom.
const int delay_with_headroom_samples = std::max(
static_cast<int>(estimated_delay.delay) - delay_headroom_samples, 0);

// Compute the buffer delay increase required to achieve the desired latency.
size_t new_delay_blocks = delay_with_headroom_samples >> kBlockSizeLog2;

size_t new_delay_blocks = estimated_delay.delay >> kBlockSizeLog2;
// Add hysteresis.
if (current_delay) {
size_t current_delay_blocks = current_delay->delay;
@ -83,7 +76,6 @@ DelayEstimate ComputeBufferDelay(
new_delay_blocks = current_delay_blocks;
}
}

DelayEstimate new_delay = estimated_delay;
new_delay.delay = new_delay_blocks;
return new_delay;
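For readers following the hysteresis handling in ComputeBufferDelay above, the rule can be read in isolation as the sketch below. The exact comparison is split across the hunks of this diff, so the condition shown (only small increases of the delay, up to hysteresis_limit_blocks, are suppressed) is an assumption, and the standalone function signature is illustrative.

#include <cstddef>

// Keeps the delay currently in use when the newly estimated block delay does
// not exceed it by more than `hysteresis_limit_blocks`; otherwise adopts the
// new estimate. This avoids re-aligning the buffer for small upward jitter.
size_t ApplyDelayHysteresis(size_t current_delay_blocks,
                            size_t new_delay_blocks,
                            size_t hysteresis_limit_blocks) {
  if (new_delay_blocks > current_delay_blocks &&
      new_delay_blocks <= current_delay_blocks + hysteresis_limit_blocks) {
    return current_delay_blocks;
  }
  return new_delay_blocks;
}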
@ -98,7 +90,6 @@ RenderDelayControllerImpl::RenderDelayControllerImpl(
: data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)),
hysteresis_limit_blocks_(
static_cast<int>(config.delay.hysteresis_limit_blocks)),
delay_headroom_samples_(config.delay.delay_headroom_samples),
delay_estimator_(data_dumper_.get(), config, num_capture_channels),
last_delay_estimate_quality_(DelayEstimate::Quality::kCoarse) {
RTC_DCHECK(ValidFullBandRate(sample_rate_hz));
@ -158,9 +149,8 @@ absl::optional<DelayEstimate> RenderDelayControllerImpl::GetDelay(
const bool use_hysteresis =
last_delay_estimate_quality_ == DelayEstimate::Quality::kRefined &&
delay_samples_->quality == DelayEstimate::Quality::kRefined;
delay_ = ComputeBufferDelay(delay_,
use_hysteresis ? hysteresis_limit_blocks_ : 0,
delay_headroom_samples_, *delay_samples_);
delay_ = ComputeBufferDelay(
delay_, use_hysteresis ? hysteresis_limit_blocks_ : 0, *delay_samples_);
last_delay_estimate_quality_ = delay_samples_->quality;
}