Optimize MatchedFilter.
Changing to an indexed for-loop (instead of using std::max_element & std::distance) that tracks even & odd elements separately allows the compiler to produce code with fewer pipeline stalls. Bug: None Change-Id: Iaa3e820a3a3b61e2eb276f0dac9106c848db1891 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/240061 Reviewed-by: Per Åhgren <peah@webrtc.org> Commit-Queue: Christian Schuldt <cschuldt@google.com> Cr-Commit-Position: refs/heads/main@{#35729}
This commit is contained in:
@ -308,6 +308,41 @@ void MatchedFilterCore(size_t x_start_index,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the index of the element of `h` that has the largest squared value.
// When several elements share the largest squared value, the lowest such index
// is returned, which is the same result as
//   std::distance(h.begin(), std::max_element(h.begin(), h.end(), cmp))
// with cmp(a, b) = a * a < b * b (for NaN-free input).
// Returns 0 when `h` holds fewer than two elements.
size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h) {
  if (h.size() < 2) {
    return 0;
  }
  float max_element1 = h[0] * h[0];
  float max_element2 = h[1] * h[1];
  size_t lag_estimate1 = 0;
  size_t lag_estimate2 = 1;
  const size_t last_index = h.size() - 1;
  // Keeping track of even & odd max elements separately typically allows the
  // compiler to produce more efficient code.
  for (size_t k = 2; k < last_index; k += 2) {
    const float element1 = h[k] * h[k];
    const float element2 = h[k + 1] * h[k + 1];
    // Strict comparisons keep the earliest index within each parity class.
    if (element1 > max_element1) {
      max_element1 = element1;
      lag_estimate1 = k;
    }
    if (element2 > max_element2) {
      max_element2 = element2;
      lag_estimate2 = k + 1;
    }
  }
  // Merge the even and odd candidates. On an exact tie the candidate with the
  // lower index must win; otherwise an even-indexed peak would shadow an
  // equally large odd-indexed peak that occurs earlier in `h`
  // (e.g. {0, 1, 1, 0} would yield 2 instead of 1).
  if (max_element2 > max_element1 ||
      (max_element2 == max_element1 && lag_estimate2 < lag_estimate1)) {
    max_element1 = max_element2;
    lag_estimate1 = lag_estimate2;
  }
  // In case of odd h size, we have not yet checked the last element. For even
  // sizes this re-checks an element already seen, which is harmless because
  // the comparison is strict.
  const float last_element = h[last_index] * h[last_index];
  if (last_element > max_element1) {
    return last_index;
  }
  return lag_estimate1;
}
|
||||||
|
|
||||||
} // namespace aec3
|
} // namespace aec3
|
||||||
|
|
||||||
MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
|
MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
|
||||||
@ -400,17 +435,15 @@ void MatchedFilter::Update(const DownsampledRenderBuffer& render_buffer,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Compute anchor for the matched filter error.
|
// Compute anchor for the matched filter error.
|
||||||
const float error_sum_anchor =
|
float error_sum_anchor = 0.0f;
|
||||||
std::inner_product(y.begin(), y.end(), y.begin(), 0.f);
|
for (size_t k = 0; k < y.size(); ++k) {
|
||||||
|
error_sum_anchor += y[k] * y[k];
|
||||||
|
}
|
||||||
|
|
||||||
// Estimate the lag in the matched filter as the distance to the portion in
|
// Estimate the lag in the matched filter as the distance to the portion in
|
||||||
// the filter that contributes the most to the matched filter output. This
|
// the filter that contributes the most to the matched filter output. This
|
||||||
// is detected as the peak of the matched filter.
|
// is detected as the peak of the matched filter.
|
||||||
const size_t lag_estimate = std::distance(
|
const size_t lag_estimate = aec3::MaxSquarePeakIndex(filters_[n]);
|
||||||
filters_[n].begin(),
|
|
||||||
std::max_element(
|
|
||||||
filters_[n].begin(), filters_[n].end(),
|
|
||||||
[](float a, float b) -> bool { return a * a < b * b; }));
|
|
||||||
|
|
||||||
// Update the lag estimates for the matched filter.
|
// Update the lag estimates for the matched filter.
|
||||||
lag_estimates_[n] = LagEstimate(
|
lag_estimates_[n] = LagEstimate(
|
||||||
|
@ -74,6 +74,9 @@ void MatchedFilterCore(size_t x_start_index,
|
|||||||
bool* filters_updated,
|
bool* filters_updated,
|
||||||
float* error_sum);
|
float* error_sum);
|
||||||
|
|
||||||
|
// Returns the index of the element in `h` with the largest squared value.
|
||||||
|
size_t MaxSquarePeakIndex(rtc::ArrayView<const float> h);
|
||||||
|
|
||||||
} // namespace aec3
|
} // namespace aec3
|
||||||
|
|
||||||
// Produces recursively updated cross-correlation estimates for several signal
|
// Produces recursively updated cross-correlation estimates for several signal
|
||||||
|
@ -176,6 +176,28 @@ TEST(MatchedFilter, TestAvx2Optimizations) {
|
|||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Verifies that the (optimized) function MaxSquarePeakIndex() produces output
|
||||||
|
// equal to the corresponding std-functions.
|
||||||
|
TEST(MatchedFilter, MaxSquarePeakIndex) {
|
||||||
|
Random random_generator(42U);
|
||||||
|
constexpr int kMaxLength = 128;
|
||||||
|
constexpr int kNumIterationsPerLength = 256;
|
||||||
|
for (int length = 1; length < kMaxLength; ++length) {
|
||||||
|
std::vector<float> y(length);
|
||||||
|
for (int i = 0; i < kNumIterationsPerLength; ++i) {
|
||||||
|
RandomizeSampleVector(&random_generator, y);
|
||||||
|
|
||||||
|
size_t lag_from_function = MaxSquarePeakIndex(y);
|
||||||
|
size_t lag_from_std = std::distance(
|
||||||
|
y.begin(),
|
||||||
|
std::max_element(y.begin(), y.end(), [](float a, float b) -> bool {
|
||||||
|
return a * a < b * b;
|
||||||
|
}));
|
||||||
|
EXPECT_EQ(lag_from_function, lag_from_std);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Verifies that the matched filter produces proper lag estimates for
|
// Verifies that the matched filter produces proper lag estimates for
|
||||||
// artificially
|
// artificially
|
||||||
// delayed signals.
|
// delayed signals.
|
||||||
|
Reference in New Issue
Block a user