Optimize the SSE2 and AVX2 parts of the matched filter further.
Manually unrolling the multiply-and-accumulate loop of the matched filter allows instructions to be interleaved, which gives a significant saving.

Bug: None
Change-Id: Ie7a7d92bd453d81e9dd61812781a7b6d62e1f1f4
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/240321
Reviewed-by: Ivo Creusen <ivoc@webrtc.org>
Commit-Queue: Christian Schuldt <cschuldt@google.com>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/main@{#35566}
This commit is contained in:
@ -166,7 +166,9 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
|
||||
|
||||
// Initialize values for the accumulation.
|
||||
__m128 s_128 = _mm_set1_ps(0);
|
||||
__m128 s_128_4 = _mm_set1_ps(0);
|
||||
__m128 x2_sum_128 = _mm_set1_ps(0);
|
||||
__m128 x2_sum_128_4 = _mm_set1_ps(0);
|
||||
float x2_sum = 0.f;
|
||||
float s = 0;
|
||||
|
||||
@ -179,20 +181,26 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
|
||||
const int chunk2 = h_size - chunk1;
|
||||
for (int limit : {chunk1, chunk2}) {
|
||||
// Perform 128 bit vector operations.
|
||||
const int limit_by_4 = limit >> 2;
|
||||
for (int k = limit_by_4; k > 0; --k, h_p += 4, x_p += 4) {
|
||||
const int limit_by_8 = limit >> 3;
|
||||
for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) {
|
||||
// Load the data into 128 bit vectors.
|
||||
const __m128 x_k = _mm_loadu_ps(x_p);
|
||||
const __m128 h_k = _mm_loadu_ps(h_p);
|
||||
const __m128 x_k_4 = _mm_loadu_ps(x_p + 4);
|
||||
const __m128 h_k_4 = _mm_loadu_ps(h_p + 4);
|
||||
const __m128 xx = _mm_mul_ps(x_k, x_k);
|
||||
const __m128 xx_4 = _mm_mul_ps(x_k_4, x_k_4);
|
||||
// Compute and accumulate x * x and h * x.
|
||||
x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
|
||||
x2_sum_128_4 = _mm_add_ps(x2_sum_128_4, xx_4);
|
||||
const __m128 hx = _mm_mul_ps(h_k, x_k);
|
||||
const __m128 hx_4 = _mm_mul_ps(h_k_4, x_k_4);
|
||||
s_128 = _mm_add_ps(s_128, hx);
|
||||
s_128_4 = _mm_add_ps(s_128_4, hx_4);
|
||||
}
|
||||
|
||||
// Perform non-vector operations for any remaining items.
|
||||
for (int k = limit - limit_by_4 * 4; k > 0; --k, ++h_p, ++x_p) {
|
||||
for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) {
|
||||
const float x_k = *x_p;
|
||||
x2_sum += x_k * x_k;
|
||||
s += *h_p * x_k;
|
||||
@ -202,8 +210,10 @@ void MatchedFilterCore_SSE2(size_t x_start_index,
|
||||
}
|
||||
|
||||
// Combine the accumulated vector and scalar values.
|
||||
x2_sum_128 = _mm_add_ps(x2_sum_128, x2_sum_128_4);
|
||||
float* v = reinterpret_cast<float*>(&x2_sum_128);
|
||||
x2_sum += v[0] + v[1] + v[2] + v[3];
|
||||
s_128 = _mm_add_ps(s_128, s_128_4);
|
||||
v = reinterpret_cast<float*>(&s_128);
|
||||
s += v[0] + v[1] + v[2] + v[3];
|
||||
|
||||
|
||||
@ -39,7 +39,9 @@ void MatchedFilterCore_AVX2(size_t x_start_index,
|
||||
|
||||
// Initialize values for the accumulation.
|
||||
__m256 s_256 = _mm256_set1_ps(0);
|
||||
__m256 s_256_8 = _mm256_set1_ps(0);
|
||||
__m256 x2_sum_256 = _mm256_set1_ps(0);
|
||||
__m256 x2_sum_256_8 = _mm256_set1_ps(0);
|
||||
float x2_sum = 0.f;
|
||||
float s = 0;
|
||||
|
||||
@ -52,18 +54,22 @@ void MatchedFilterCore_AVX2(size_t x_start_index,
|
||||
const int chunk2 = h_size - chunk1;
|
||||
for (int limit : {chunk1, chunk2}) {
|
||||
// Perform 256 bit vector operations.
|
||||
const int limit_by_8 = limit >> 3;
|
||||
for (int k = limit_by_8; k > 0; --k, h_p += 8, x_p += 8) {
|
||||
const int limit_by_16 = limit >> 4;
|
||||
for (int k = limit_by_16; k > 0; --k, h_p += 16, x_p += 16) {
|
||||
// Load the data into 256 bit vectors.
|
||||
__m256 x_k = _mm256_loadu_ps(x_p);
|
||||
__m256 h_k = _mm256_loadu_ps(h_p);
|
||||
__m256 x_k_8 = _mm256_loadu_ps(x_p + 8);
|
||||
__m256 h_k_8 = _mm256_loadu_ps(h_p + 8);
|
||||
// Compute and accumulate x * x and h * x.
|
||||
x2_sum_256 = _mm256_fmadd_ps(x_k, x_k, x2_sum_256);
|
||||
x2_sum_256_8 = _mm256_fmadd_ps(x_k_8, x_k_8, x2_sum_256_8);
|
||||
s_256 = _mm256_fmadd_ps(h_k, x_k, s_256);
|
||||
s_256_8 = _mm256_fmadd_ps(h_k_8, x_k_8, s_256_8);
|
||||
}
|
||||
|
||||
// Perform non-vector operations for any remaining items.
|
||||
for (int k = limit - limit_by_8 * 8; k > 0; --k, ++h_p, ++x_p) {
|
||||
for (int k = limit - limit_by_16 * 16; k > 0; --k, ++h_p, ++x_p) {
|
||||
const float x_k = *x_p;
|
||||
x2_sum += x_k * x_k;
|
||||
s += *h_p * x_k;
|
||||
@ -73,6 +79,8 @@ void MatchedFilterCore_AVX2(size_t x_start_index,
|
||||
}
|
||||
|
||||
// Sum components together.
|
||||
x2_sum_256 = _mm256_add_ps(x2_sum_256, x2_sum_256_8);
|
||||
s_256 = _mm256_add_ps(s_256, s_256_8);
|
||||
__m128 x2_sum_128 = _mm_add_ps(_mm256_extractf128_ps(x2_sum_256, 0),
|
||||
_mm256_extractf128_ps(x2_sum_256, 1));
|
||||
__m128 s_128 = _mm_add_ps(_mm256_extractf128_ps(s_256, 0),
|
||||
|
||||
Reference in New Issue
Block a user