Optimize the SSE2 and AVX2 parts of the matched filter further.

Manually unrolling the multiply-and-accumulate loop of the matched filter allows the instructions to be interleaved, which gives a significant saving.

Bug: None
Change-Id: Ie7a7d92bd453d81e9dd61812781a7b6d62e1f1f4
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/240321
Reviewed-by: Ivo Creusen <ivoc@webrtc.org>
Commit-Queue: Christian Schuldt <cschuldt@google.com>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35566}
This commit is contained in:
cschuldt
2021-12-08 09:49:00 +01:00
committed by WebRTC LUCI CQ
parent 396e1baa46
commit ce702dbbe4
2 changed files with 24 additions and 6 deletions

View File

@ -166,7 +166,9 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
// Initialize values for the accumulation.
__m128 s_128 = _mm_set1_ps(0);
__m128 s_128_4 = _mm_set1_ps(0);
__m128 x2_sum_128 = _mm_set1_ps(0);
__m128 x2_sum_128_4 = _mm_set1_ps(0);
float x2_sum = 0.f;
float s = 0;
@ -179,20 +181,26 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
const int chunk2 = h_size - chunk1;
for (int limit : {chunk1, chunk2}) {
// Perform 128 bit vector operations.
const int limit_by_4 = limit >> 2;
for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
const int limit_by_8 = limit >> 3;
for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) {
// Load the data into 128 bit vectors.
const __m128 x_k = _mm_loadu_ps(x_p);
const __m128 h_k = _mm_loadu_ps(h_p);
const __m128 x_k_4 = _mm_loadu_ps(x_p + 4);
const __m128 h_k_4 = _mm_loadu_ps(h_p + 4);
const __m128 xx = _mm_mul_ps(x_k, x_k);
const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4);
// Compute and accumulate x * x and h * x.
x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4);
const __m128 hx = _mm_mul_ps(h_k, x_k);
const __m128 hx_4 = _mm_mul_ps(h_k_4, x_k_4);
s_128 = _mm_add_ps(s_128, hx);
s_128_4 = _mm_add_ps(s_128_4, hx_4);
}
// Perform non-vector operations for any remaining items.
for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) {
const float x_k = *x_p;
x2_sum += x_k * x_k;
s += *h_p * x_k;
@ -202,8 +210,10 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
}
// Combine the accumulated vector and scalar values.
x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4);
float* v = reinterpret_cast<float*>(&x2_sum_128);
x2_sum += v[0] + v[1] + v[2] + v[3];
s_128 = _mm_add_ps(s_128, s_128_4);
v = reinterpret_cast<float*>(&s_128);
s += v[0] + v[1] + v[2] + v[3];

View File

@ -39,7 +39,9 @@ void MatchedFilterCore_AVX2(size_t x_start_index,
// Initialize values for the accumulation.
__m256 s_256 = _mm256_set1_ps(0);
__m256 s_256_8 = _mm256_set1_ps(0);
__m256 x2_sum_256 = _mm256_set1_ps(0);
__m256 x2_sum_256_8 = _mm256_set1_ps(0);
float x2_sum = 0.f;
float s = 0;
@ -52,18 +54,22 @@ void MatchedFilterCore_AVX2(size_t x_start_index,
const int chunk2 = h_size - chunk1;
for (int limit : {chunk1, chunk2}) {
// Perform 256 bit vector operations.
const int limit_by_8 = limit >> 3;
for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) {
const int limit_by_16 = limit >> 4;
for (int k = limit_by_16; k > 0; --k, h_p += 16, x_p += 16) {
// Load the data into 256 bit vectors.
__m256 x_k = _mm256_loadu_ps(x_p);
__m256 h_k = _mm256_loadu_ps(h_p);
__m256 x_k_8 = _mm256_loadu_ps(x_p + 8);
__m256 h_k_8 = _mm256_loadu_ps(h_p + 8);
// Compute and accumulate x * x and h * x.
x2_sum_256 = _mm256_fmadd_ps(x_k, x_k, x2_sum_256);
x2_sum_256_8 = _mm256_fmadd_ps(x_k_8, x_k_8, x2_sum_256_8);
s_256 = _mm256_fmadd_ps(h_k, x_k, s_256);
s_256_8 = _mm256_fmadd_ps(h_k_8, x_k_8, s_256_8);
}
// Perform non-vector operations for any remaining items.
for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) {
for (int k = limit - limit_by_16 * 16; k > 0; --k, ++h_p, ++x_p) {
const float x_k = *x_p;
x2_sum += x_k * x_k;
s += *h_p * x_k;
@ -73,6 +79,8 @@ void MatchedFilterCore_AVX2(size_t x_start_index,
}
// Sum components together.
x2_sum_256 = _mm256_add_ps(x2_sum_256, x2_sum_256_8);
s_256 = _mm256_add_ps(s_256, s_256_8);
__m128 x2_sum_128 = _mm_add_ps(_mm256_extractf128_ps(x2_sum_256, 0),
_mm256_extractf128_ps(x2_sum_256, 1));
__m128 s_128 = _mm_add_ps(_mm256_extractf128_ps(s_256, 0),