diff --git a/modules/audio_processing/aec3/adaptive_fir_filter.cc b/modules/audio_processing/aec3/adaptive_fir_filter.cc index 00fa884aeb..6a0f531663 100644 --- a/modules/audio_processing/aec3/adaptive_fir_filter.cc +++ b/modules/audio_processing/aec3/adaptive_fir_filter.cc @@ -19,6 +19,8 @@ #if defined(WEBRTC_ARCH_X86_FAMILY) #include #endif +#include + #include #include @@ -30,207 +32,255 @@ namespace webrtc { namespace aec3 { // Computes and stores the frequency response of the filter. -void UpdateFrequencyResponse( - rtc::ArrayView H, +void ComputeFrequencyResponse( + size_t num_partitions, + const std::vector>& H, std::vector>* H2) { - RTC_DCHECK_EQ(H.size(), H2->size()); - for (size_t k = 0; k < H.size(); ++k) { - std::transform(H[k].re.begin(), H[k].re.end(), H[k].im.begin(), - (*H2)[k].begin(), - [](float a, float b) { return a * a + b * b; }); + for (auto& H2_ch : *H2) { + H2_ch.fill(0.f); + } + + const size_t num_render_channels = H[0].size(); + RTC_DCHECK_EQ(H.size(), H2->capacity()); + for (size_t p = 0; p < num_partitions; ++p) { + RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size()); + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t j = 0; j < kFftLengthBy2Plus1; ++j) { + float tmp = + H[p][ch].re[j] * H[p][ch].re[j] + H[p][ch].im[j] * H[p][ch].im[j]; + (*H2)[p][j] = std::max((*H2)[p][j], tmp); + } + } } } #if defined(WEBRTC_HAS_NEON) // Computes and stores the frequency response of the filter. 
-void UpdateFrequencyResponse_NEON( - rtc::ArrayView H, +void ComputeFrequencyResponse_Neon( + size_t num_partitions, + const std::vector>& H, std::vector>* H2) { - RTC_DCHECK_EQ(H.size(), H2->size()); - for (size_t k = 0; k < H.size(); ++k) { - for (size_t j = 0; j < kFftLengthBy2; j += 4) { - const float32x4_t re = vld1q_f32(&H[k].re[j]); - const float32x4_t im = vld1q_f32(&H[k].im[j]); - float32x4_t H2_k_j = vmulq_f32(re, re); - H2_k_j = vmlaq_f32(H2_k_j, im, im); - vst1q_f32(&(*H2)[k][j], H2_k_j); + for (auto& H2_ch : *H2) { + H2_ch.fill(0.f); + } + + const size_t num_render_channels = H[0].size(); + RTC_DCHECK_EQ(H.size(), H2->capacity()); + for (size_t p = 0; p < num_partitions; ++p) { + RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size()); + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t j = 0; j < kFftLengthBy2; j += 4) { + const float32x4_t re = vld1q_f32(&H[p][ch].re[j]); + const float32x4_t im = vld1q_f32(&H[p][ch].im[j]); + float32x4_t H2_new = vmulq_f32(re, re); + H2_new = vmlaq_f32(H2_new, im, im); + float32x4_t H2_p_j = vld1q_f32(&(*H2)[p][j]); + H2_p_j = vmaxq_f32(H2_p_j, H2_new); + vst1q_f32(&(*H2)[p][j], H2_p_j); + } + float H2_new = H[p][ch].re[kFftLengthBy2] * H[p][ch].re[kFftLengthBy2] + + H[p][ch].im[kFftLengthBy2] * H[p][ch].im[kFftLengthBy2]; + (*H2)[p][kFftLengthBy2] = std::max((*H2)[p][kFftLengthBy2], H2_new); } - (*H2)[k][kFftLengthBy2] = H[k].re[kFftLengthBy2] * H[k].re[kFftLengthBy2] + - H[k].im[kFftLengthBy2] * H[k].im[kFftLengthBy2]; } } #endif #if defined(WEBRTC_ARCH_X86_FAMILY) // Computes and stores the frequency response of the filter. 
-void UpdateFrequencyResponse_SSE2( - rtc::ArrayView H, +void ComputeFrequencyResponse_Sse2( + size_t num_partitions, + const std::vector>& H, std::vector>* H2) { - RTC_DCHECK_EQ(H.size(), H2->size()); - for (size_t k = 0; k < H.size(); ++k) { - for (size_t j = 0; j < kFftLengthBy2; j += 4) { - const __m128 re = _mm_loadu_ps(&H[k].re[j]); - const __m128 re2 = _mm_mul_ps(re, re); - const __m128 im = _mm_loadu_ps(&H[k].im[j]); - const __m128 im2 = _mm_mul_ps(im, im); - const __m128 H2_k_j = _mm_add_ps(re2, im2); - _mm_storeu_ps(&(*H2)[k][j], H2_k_j); + for (auto& H2_ch : *H2) { + H2_ch.fill(0.f); + } + + const size_t num_render_channels = H[0].size(); + RTC_DCHECK_EQ(H.size(), H2->capacity()); + // constexpr __mmmask8 kMaxMask = static_cast<__mmmask8>(256u); + for (size_t p = 0; p < num_partitions; ++p) { + RTC_DCHECK_EQ(kFftLengthBy2Plus1, (*H2)[p].size()); + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t j = 0; j < kFftLengthBy2; j += 4) { + const __m128 re = _mm_loadu_ps(&H[p][ch].re[j]); + const __m128 re2 = _mm_mul_ps(re, re); + const __m128 im = _mm_loadu_ps(&H[p][ch].im[j]); + const __m128 im2 = _mm_mul_ps(im, im); + const __m128 H2_new = _mm_add_ps(re2, im2); + __m128 H2_k_j = _mm_loadu_ps(&(*H2)[p][j]); + H2_k_j = _mm_max_ps(H2_k_j, H2_new); + _mm_storeu_ps(&(*H2)[p][j], H2_k_j); + } + float H2_new = H[p][ch].re[kFftLengthBy2] * H[p][ch].re[kFftLengthBy2] + + H[p][ch].im[kFftLengthBy2] * H[p][ch].im[kFftLengthBy2]; + (*H2)[p][kFftLengthBy2] = std::max((*H2)[p][kFftLengthBy2], H2_new); } - (*H2)[k][kFftLengthBy2] = H[k].re[kFftLengthBy2] * H[k].re[kFftLengthBy2] + - H[k].im[kFftLengthBy2] * H[k].im[kFftLengthBy2]; } } #endif - // Adapts the filter partitions as H(t+1)=H(t)+G(t)*conj(X(t)). 
void AdaptPartitions(const RenderBuffer& render_buffer, const FftData& G, - rtc::ArrayView H) { + size_t num_partitions, + std::vector>* H) { rtc::ArrayView> render_buffer_data = render_buffer.GetFftBuffer(); size_t index = render_buffer.Position(); - for (auto& H_j : H) { - const FftData& X = render_buffer_data[index][/*channel=*/0]; - for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { - H_j.re[k] += X.re[k] * G.re[k] + X.im[k] * G.im[k]; - H_j.im[k] += X.re[k] * G.im[k] - X.im[k] * G.re[k]; + const size_t num_render_channels = render_buffer_data[index].size(); + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& X_p_ch = render_buffer_data[index][ch]; + FftData& H_p_ch = (*H)[p][ch]; + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + H_p_ch.re[k] += X_p_ch.re[k] * G.re[k] + X_p_ch.im[k] * G.im[k]; + H_p_ch.im[k] += X_p_ch.re[k] * G.im[k] - X_p_ch.im[k] * G.re[k]; + } } - index = index < (render_buffer_data.size() - 1) ? index + 1 : 0; } } #if defined(WEBRTC_HAS_NEON) -// Adapts the filter partitions. (NEON variant) -void AdaptPartitions_NEON(const RenderBuffer& render_buffer, +// Adapts the filter partitions. 
(Neon variant) +void AdaptPartitions_Neon(const RenderBuffer& render_buffer, const FftData& G, - rtc::ArrayView H) { + size_t num_partitions, + std::vector>* H) { rtc::ArrayView> render_buffer_data = render_buffer.GetFftBuffer(); - const int lim1 = - std::min(render_buffer_data.size() - render_buffer.Position(), H.size()); - const int lim2 = H.size(); - constexpr int kNumFourBinBands = kFftLengthBy2 / 4; - FftData* H_j = &H[0]; - const std::vector* X_channels = - &render_buffer_data[render_buffer.Position()]; - int limit = lim1; - int j = 0; - do { - for (; j < limit; ++j, ++H_j, ++X_channels) { - const FftData& X = (*X_channels)[/*channel=*/0]; - for (int k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) { - const float32x4_t G_re = vld1q_f32(&G.re[k]); - const float32x4_t G_im = vld1q_f32(&G.im[k]); - const float32x4_t X_re = vld1q_f32(&X.re[k]); - const float32x4_t X_im = vld1q_f32(&X.im[k]); - const float32x4_t H_re = vld1q_f32(&H_j->re[k]); - const float32x4_t H_im = vld1q_f32(&H_j->im[k]); - const float32x4_t a = vmulq_f32(X_re, G_re); - const float32x4_t e = vmlaq_f32(a, X_im, G_im); - const float32x4_t c = vmulq_f32(X_re, G_im); - const float32x4_t f = vmlsq_f32(c, X_im, G_re); - const float32x4_t g = vaddq_f32(H_re, e); - const float32x4_t h = vaddq_f32(H_im, f); + const size_t num_render_channels = render_buffer_data[0].size(); + const size_t lim1 = std::min( + render_buffer_data.size() - render_buffer.Position(), num_partitions); + const size_t lim2 = num_partitions; + constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4; - vst1q_f32(&H_j->re[k], g); - vst1q_f32(&H_j->im[k], h); + size_t X_partition = render_buffer.Position(); + size_t limit = lim1; + size_t p = 0; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + FftData& H_p_ch = (*H)[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) { + const float32x4_t G_re = 
vld1q_f32(&G.re[k]); + const float32x4_t G_im = vld1q_f32(&G.im[k]); + const float32x4_t X_re = vld1q_f32(&X.re[k]); + const float32x4_t X_im = vld1q_f32(&X.im[k]); + const float32x4_t H_re = vld1q_f32(&H_p_ch.re[k]); + const float32x4_t H_im = vld1q_f32(&H_p_ch.im[k]); + const float32x4_t a = vmulq_f32(X_re, G_re); + const float32x4_t e = vmlaq_f32(a, X_im, G_im); + const float32x4_t c = vmulq_f32(X_re, G_im); + const float32x4_t f = vmlsq_f32(c, X_im, G_re); + const float32x4_t g = vaddq_f32(H_re, e); + const float32x4_t h = vaddq_f32(H_im, f); + vst1q_f32(&H_p_ch.re[k], g); + vst1q_f32(&H_p_ch.im[k], h); + } } } - X_channels = &render_buffer_data[0]; + X_partition = 0; limit = lim2; - } while (j < lim2); + } while (p < lim2); - H_j = &H[0]; - X_channels = &render_buffer_data[render_buffer.Position()]; + X_partition = render_buffer.Position(); limit = lim1; - j = 0; + p = 0; do { - for (; j < limit; ++j, ++H_j, ++X_channels) { - const FftData& X = (*X_channels)[/*channel=*/0]; - H_j->re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] + - X.im[kFftLengthBy2] * G.im[kFftLengthBy2]; - H_j->im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] - - X.im[kFftLengthBy2] * G.re[kFftLengthBy2]; - } + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + FftData& H_p_ch = (*H)[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; - X_channels = &render_buffer_data[0]; + H_p_ch.re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] + + X.im[kFftLengthBy2] * G.im[kFftLengthBy2]; + H_p_ch.im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] - + X.im[kFftLengthBy2] * G.re[kFftLengthBy2]; + } + } + X_partition = 0; limit = lim2; - } while (j < lim2); + } while (p < lim2); } #endif #if defined(WEBRTC_ARCH_X86_FAMILY) // Adapts the filter partitions. 
(SSE2 variant) -void AdaptPartitions_SSE2(const RenderBuffer& render_buffer, +void AdaptPartitions_Sse2(const RenderBuffer& render_buffer, const FftData& G, - rtc::ArrayView H) { + size_t num_partitions, + std::vector>* H) { rtc::ArrayView> render_buffer_data = render_buffer.GetFftBuffer(); - const int lim1 = - std::min(render_buffer_data.size() - render_buffer.Position(), H.size()); - const int lim2 = H.size(); - constexpr int kNumFourBinBands = kFftLengthBy2 / 4; - FftData* H_j; - const std::vector* X_channels; - int limit; - int j; - for (int k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) { - const __m128 G_re = _mm_loadu_ps(&G.re[k]); - const __m128 G_im = _mm_loadu_ps(&G.im[k]); + const size_t num_render_channels = render_buffer_data[0].size(); + const size_t lim1 = std::min( + render_buffer_data.size() - render_buffer.Position(), num_partitions); + const size_t lim2 = num_partitions; + constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4; - H_j = &H[0]; - X_channels = &render_buffer_data[render_buffer.Position()]; - limit = lim1; - j = 0; - do { - for (; j < limit; ++j, ++H_j, ++X_channels) { - const FftData& X = (*X_channels)[/*channel=*/0]; - const __m128 X_re = _mm_loadu_ps(&X.re[k]); - const __m128 X_im = _mm_loadu_ps(&X.im[k]); - const __m128 H_re = _mm_loadu_ps(&H_j->re[k]); - const __m128 H_im = _mm_loadu_ps(&H_j->im[k]); - const __m128 a = _mm_mul_ps(X_re, G_re); - const __m128 b = _mm_mul_ps(X_im, G_im); - const __m128 c = _mm_mul_ps(X_re, G_im); - const __m128 d = _mm_mul_ps(X_im, G_re); - const __m128 e = _mm_add_ps(a, b); - const __m128 f = _mm_sub_ps(c, d); - const __m128 g = _mm_add_ps(H_re, e); - const __m128 h = _mm_add_ps(H_im, f); - _mm_storeu_ps(&H_j->re[k], g); - _mm_storeu_ps(&H_j->im[k], h); - } - - X_channels = &render_buffer_data[0]; - limit = lim2; - } while (j < lim2); - } - - H_j = &H[0]; - X_channels = &render_buffer_data[render_buffer.Position()]; - limit = lim1; - j = 0; + size_t X_partition = render_buffer.Position(); + 
size_t limit = lim1; + size_t p = 0; do { - for (; j < limit; ++j, ++H_j, ++X_channels) { - const FftData& X = (*X_channels)[/*channel=*/0]; - H_j->re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] + - X.im[kFftLengthBy2] * G.im[kFftLengthBy2]; - H_j->im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] - - X.im[kFftLengthBy2] * G.re[kFftLengthBy2]; + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + FftData& H_p_ch = (*H)[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + + for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) { + const __m128 G_re = _mm_loadu_ps(&G.re[k]); + const __m128 G_im = _mm_loadu_ps(&G.im[k]); + const __m128 X_re = _mm_loadu_ps(&X.re[k]); + const __m128 X_im = _mm_loadu_ps(&X.im[k]); + const __m128 H_re = _mm_loadu_ps(&H_p_ch.re[k]); + const __m128 H_im = _mm_loadu_ps(&H_p_ch.im[k]); + const __m128 a = _mm_mul_ps(X_re, G_re); + const __m128 b = _mm_mul_ps(X_im, G_im); + const __m128 c = _mm_mul_ps(X_re, G_im); + const __m128 d = _mm_mul_ps(X_im, G_re); + const __m128 e = _mm_add_ps(a, b); + const __m128 f = _mm_sub_ps(c, d); + const __m128 g = _mm_add_ps(H_re, e); + const __m128 h = _mm_add_ps(H_im, f); + _mm_storeu_ps(&H_p_ch.re[k], g); + _mm_storeu_ps(&H_p_ch.im[k], h); + } + } + } + X_partition = 0; + limit = lim2; + } while (p < lim2); + + X_partition = render_buffer.Position(); + limit = lim1; + p = 0; + do { + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + FftData& H_p_ch = (*H)[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + + H_p_ch.re[kFftLengthBy2] += X.re[kFftLengthBy2] * G.re[kFftLengthBy2] + + X.im[kFftLengthBy2] * G.im[kFftLengthBy2]; + H_p_ch.im[kFftLengthBy2] += X.re[kFftLengthBy2] * G.im[kFftLengthBy2] - + X.im[kFftLengthBy2] * G.re[kFftLengthBy2]; + } } - X_channels = &render_buffer_data[0]; + X_partition = 0; limit = lim2; - } while (j < lim2); + } while (p 
< lim2); } #endif // Produces the filter output. void ApplyFilter(const RenderBuffer& render_buffer, - rtc::ArrayView H, + size_t num_partitions, + const std::vector>& H, FftData* S) { S->re.fill(0.f); S->im.fill(0.f); @@ -238,184 +288,219 @@ void ApplyFilter(const RenderBuffer& render_buffer, rtc::ArrayView> render_buffer_data = render_buffer.GetFftBuffer(); size_t index = render_buffer.Position(); - for (auto& H_j : H) { - const FftData& X = render_buffer_data[index][0]; - for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { - S->re[k] += X.re[k] * H_j.re[k] - X.im[k] * H_j.im[k]; - S->im[k] += X.re[k] * H_j.im[k] + X.im[k] * H_j.re[k]; + const size_t num_render_channels = render_buffer_data[index].size(); + for (size_t p = 0; p < num_partitions; ++p) { + RTC_DCHECK_EQ(num_render_channels, H[p].size()); + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& X_p_ch = render_buffer_data[index][ch]; + const FftData& H_p_ch = H[p][ch]; + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + S->re[k] += X_p_ch.re[k] * H_p_ch.re[k] - X_p_ch.im[k] * H_p_ch.im[k]; + S->im[k] += X_p_ch.re[k] * H_p_ch.im[k] + X_p_ch.im[k] * H_p_ch.re[k]; + } } index = index < (render_buffer_data.size() - 1) ? index + 1 : 0; } } #if defined(WEBRTC_HAS_NEON) -// Produces the filter output (NEON variant). -void ApplyFilter_NEON(const RenderBuffer& render_buffer, - rtc::ArrayView H, +// Produces the filter output (Neon variant). 
+void ApplyFilter_Neon(const RenderBuffer& render_buffer, + size_t num_partitions, + const std::vector>& H, FftData* S) { + // const RenderBuffer& render_buffer, + // rtc::ArrayView H, + // FftData* S) { RTC_DCHECK_GE(H.size(), H.size() - 1); S->Clear(); rtc::ArrayView> render_buffer_data = render_buffer.GetFftBuffer(); - const int lim1 = - std::min(render_buffer_data.size() - render_buffer.Position(), H.size()); - const int lim2 = H.size(); - constexpr int kNumFourBinBands = kFftLengthBy2 / 4; - const FftData* H_j = &H[0]; - const std::vector* X_channels = - &render_buffer_data[render_buffer.Position()]; + const size_t num_render_channels = render_buffer_data[0].size(); + const size_t lim1 = std::min( + render_buffer_data.size() - render_buffer.Position(), num_partitions); + const size_t lim2 = num_partitions; + constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4; - int j = 0; - int limit = lim1; + size_t X_partition = render_buffer.Position(); + size_t p = 0; + size_t limit = lim1; do { - for (; j < limit; ++j, ++H_j, ++X_channels) { - const FftData& X = (*X_channels)[/*channel=*/0]; - for (int k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) { - const float32x4_t X_re = vld1q_f32(&X.re[k]); - const float32x4_t X_im = vld1q_f32(&X.im[k]); - const float32x4_t H_re = vld1q_f32(&H_j->re[k]); - const float32x4_t H_im = vld1q_f32(&H_j->im[k]); - const float32x4_t S_re = vld1q_f32(&S->re[k]); - const float32x4_t S_im = vld1q_f32(&S->im[k]); - const float32x4_t a = vmulq_f32(X_re, H_re); - const float32x4_t e = vmlsq_f32(a, X_im, H_im); - const float32x4_t c = vmulq_f32(X_re, H_im); - const float32x4_t f = vmlaq_f32(c, X_im, H_re); - const float32x4_t g = vaddq_f32(S_re, e); - const float32x4_t h = vaddq_f32(S_im, f); - vst1q_f32(&S->re[k], g); - vst1q_f32(&S->im[k], h); + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& H_p_ch = H[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + for 
(size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) { + const float32x4_t X_re = vld1q_f32(&X.re[k]); + const float32x4_t X_im = vld1q_f32(&X.im[k]); + const float32x4_t H_re = vld1q_f32(&H_p_ch.re[k]); + const float32x4_t H_im = vld1q_f32(&H_p_ch.im[k]); + const float32x4_t S_re = vld1q_f32(&S->re[k]); + const float32x4_t S_im = vld1q_f32(&S->im[k]); + const float32x4_t a = vmulq_f32(X_re, H_re); + const float32x4_t e = vmlsq_f32(a, X_im, H_im); + const float32x4_t c = vmulq_f32(X_re, H_im); + const float32x4_t f = vmlaq_f32(c, X_im, H_re); + const float32x4_t g = vaddq_f32(S_re, e); + const float32x4_t h = vaddq_f32(S_im, f); + vst1q_f32(&S->re[k], g); + vst1q_f32(&S->im[k], h); + } } } limit = lim2; - X_channels = &render_buffer_data[0]; - } while (j < lim2); + X_partition = 0; + } while (p < lim2); - H_j = &H[0]; - X_channels = &render_buffer_data[render_buffer.Position()]; - j = 0; + X_partition = render_buffer.Position(); + p = 0; limit = lim1; do { - for (; j < limit; ++j, ++H_j, ++X_channels) { - const FftData& X = (*X_channels)[/*channel=*/0]; - S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_j->re[kFftLengthBy2] - - X.im[kFftLengthBy2] * H_j->im[kFftLengthBy2]; - S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_j->im[kFftLengthBy2] + - X.im[kFftLengthBy2] * H_j->re[kFftLengthBy2]; + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& H_p_ch = H[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2] - + X.im[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2]; + S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2] + + X.im[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2]; + } } limit = lim2; - X_channels = &render_buffer_data[0]; - } while (j < lim2); + X_partition = 0; + } while (p < lim2); } #endif #if defined(WEBRTC_ARCH_X86_FAMILY) // Produces the filter output (SSE2 variant). 
-void ApplyFilter_SSE2(const RenderBuffer& render_buffer, - rtc::ArrayView H, +void ApplyFilter_Sse2(const RenderBuffer& render_buffer, + size_t num_partitions, + const std::vector>& H, FftData* S) { + // const RenderBuffer& render_buffer, + // rtc::ArrayView H, + // FftData* S) { RTC_DCHECK_GE(H.size(), H.size() - 1); S->re.fill(0.f); S->im.fill(0.f); rtc::ArrayView> render_buffer_data = render_buffer.GetFftBuffer(); - const int lim1 = - std::min(render_buffer_data.size() - render_buffer.Position(), H.size()); - const int lim2 = H.size(); - constexpr int kNumFourBinBands = kFftLengthBy2 / 4; - const FftData* H_j = &H[0]; - const std::vector* X_channels = - &render_buffer_data[render_buffer.Position()]; + const size_t num_render_channels = render_buffer_data[0].size(); + const size_t lim1 = std::min( + render_buffer_data.size() - render_buffer.Position(), num_partitions); + const size_t lim2 = num_partitions; + constexpr size_t kNumFourBinBands = kFftLengthBy2 / 4; - int j = 0; - int limit = lim1; + size_t X_partition = render_buffer.Position(); + size_t p = 0; + size_t limit = lim1; do { - for (; j < limit; ++j, ++H_j, ++X_channels) { - const FftData& X = (*X_channels)[/*channel=*/0]; - for (int k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) { - const __m128 X_re = _mm_loadu_ps(&X.re[k]); - const __m128 X_im = _mm_loadu_ps(&X.im[k]); - const __m128 H_re = _mm_loadu_ps(&H_j->re[k]); - const __m128 H_im = _mm_loadu_ps(&H_j->im[k]); - const __m128 S_re = _mm_loadu_ps(&S->re[k]); - const __m128 S_im = _mm_loadu_ps(&S->im[k]); - const __m128 a = _mm_mul_ps(X_re, H_re); - const __m128 b = _mm_mul_ps(X_im, H_im); - const __m128 c = _mm_mul_ps(X_re, H_im); - const __m128 d = _mm_mul_ps(X_im, H_re); - const __m128 e = _mm_sub_ps(a, b); - const __m128 f = _mm_add_ps(c, d); - const __m128 g = _mm_add_ps(S_re, e); - const __m128 h = _mm_add_ps(S_im, f); - _mm_storeu_ps(&S->re[k], g); - _mm_storeu_ps(&S->im[k], h); + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 
0; ch < num_render_channels; ++ch) { + const FftData& H_p_ch = H[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + for (size_t k = 0, n = 0; n < kNumFourBinBands; ++n, k += 4) { + const __m128 X_re = _mm_loadu_ps(&X.re[k]); + const __m128 X_im = _mm_loadu_ps(&X.im[k]); + const __m128 H_re = _mm_loadu_ps(&H_p_ch.re[k]); + const __m128 H_im = _mm_loadu_ps(&H_p_ch.im[k]); + const __m128 S_re = _mm_loadu_ps(&S->re[k]); + const __m128 S_im = _mm_loadu_ps(&S->im[k]); + const __m128 a = _mm_mul_ps(X_re, H_re); + const __m128 b = _mm_mul_ps(X_im, H_im); + const __m128 c = _mm_mul_ps(X_re, H_im); + const __m128 d = _mm_mul_ps(X_im, H_re); + const __m128 e = _mm_sub_ps(a, b); + const __m128 f = _mm_add_ps(c, d); + const __m128 g = _mm_add_ps(S_re, e); + const __m128 h = _mm_add_ps(S_im, f); + _mm_storeu_ps(&S->re[k], g); + _mm_storeu_ps(&S->im[k], h); + } } } limit = lim2; - X_channels = &render_buffer_data[0]; - } while (j < lim2); + X_partition = 0; + } while (p < lim2); - H_j = &H[0]; - X_channels = &render_buffer_data[render_buffer.Position()]; - j = 0; + X_partition = render_buffer.Position(); + p = 0; limit = lim1; do { - for (; j < limit; ++j, ++H_j, ++X_channels) { - const FftData& X = (*X_channels)[/*channel=*/0]; - S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_j->re[kFftLengthBy2] - - X.im[kFftLengthBy2] * H_j->im[kFftLengthBy2]; - S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_j->im[kFftLengthBy2] + - X.im[kFftLengthBy2] * H_j->re[kFftLengthBy2]; + for (; p < limit; ++p, ++X_partition) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + const FftData& H_p_ch = H[p][ch]; + const FftData& X = render_buffer_data[X_partition][ch]; + S->re[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2] - + X.im[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2]; + S->im[kFftLengthBy2] += X.re[kFftLengthBy2] * H_p_ch.im[kFftLengthBy2] + + X.im[kFftLengthBy2] * H_p_ch.re[kFftLengthBy2]; + } } limit = lim2; - X_channels = &render_buffer_data[0]; - } 
while (j < lim2); + X_partition = 0; + } while (p < lim2); } #endif } // namespace aec3 +namespace { + +// Ensures that the newly added filter partitions after a size increase are set +// to zero. +void ZeroFilter(size_t old_size, + size_t new_size, + std::vector>* H) { + RTC_DCHECK_GE(H->size(), old_size); + RTC_DCHECK_GE(H->size(), new_size); + + for (size_t p = old_size; p < new_size; ++p) { + RTC_DCHECK_EQ((*H)[p].size(), (*H)[0].size()); + for (size_t ch = 0; ch < (*H)[0].size(); ++ch) { + (*H)[p][ch].Clear(); + } + } +} + +} // namespace + AdaptiveFirFilter::AdaptiveFirFilter(size_t max_size_partitions, size_t initial_size_partitions, size_t size_change_duration_blocks, size_t num_render_channels, - size_t num_capture_channels, Aec3Optimization optimization, ApmDataDumper* data_dumper) : data_dumper_(data_dumper), fft_(), optimization_(optimization), + num_render_channels_(num_render_channels), max_size_partitions_(max_size_partitions), size_change_duration_blocks_( static_cast(size_change_duration_blocks)), current_size_partitions_(initial_size_partitions), target_size_partitions_(initial_size_partitions), old_target_size_partitions_(initial_size_partitions), - H_(max_size_partitions_) { + H_(max_size_partitions_, std::vector(num_render_channels_)) { RTC_DCHECK(data_dumper_); RTC_DCHECK_GE(max_size_partitions, initial_size_partitions); RTC_DCHECK_LT(0, size_change_duration_blocks_); one_by_size_change_duration_blocks_ = 1.f / size_change_duration_blocks_; - for (auto& H_j : H_) { - H_j.Clear(); - } + ZeroFilter(0, max_size_partitions_, &H_); + SetSizePartitions(current_size_partitions_, true); } AdaptiveFirFilter::~AdaptiveFirFilter() = default; void AdaptiveFirFilter::HandleEchoPathChange() { - size_t current_size_partitions = H_.size(); - H_.resize(max_size_partitions_); - - for (size_t k = current_size_partitions; k < max_size_partitions_; ++k) { - H_[k].Clear(); - } - H_.resize(current_size_partitions); + // TODO(peah): Check the value and purpose of the 
code below. + ZeroFilter(current_size_partitions_, max_size_partitions_, &H_); } void AdaptiveFirFilter::SetSizePartitions(size_t size, bool immediate_effect) { @@ -424,24 +509,22 @@ void AdaptiveFirFilter::SetSizePartitions(size_t size, bool immediate_effect) { target_size_partitions_ = std::min(max_size_partitions_, size); if (immediate_effect) { + size_t old_size_partitions_ = current_size_partitions_; current_size_partitions_ = old_target_size_partitions_ = target_size_partitions_; - ResetFilterBuffersToCurrentSize(); + ZeroFilter(old_size_partitions_, current_size_partitions_, &H_); + + partition_to_constrain_ = + std::min(partition_to_constrain_, current_size_partitions_ - 1); size_change_counter_ = 0; } else { size_change_counter_ = size_change_duration_blocks_; } } -void AdaptiveFirFilter::ResetFilterBuffersToCurrentSize() { - H_.resize(current_size_partitions_); - RTC_DCHECK_LT(0, current_size_partitions_); - partition_to_constrain_ = - std::min(partition_to_constrain_, current_size_partitions_ - 1); -} - void AdaptiveFirFilter::UpdateSize() { RTC_DCHECK_GE(size_change_duration_blocks_, size_change_counter_); + size_t old_size_partitions_ = current_size_partitions_; if (size_change_counter_ > 0) { --size_change_counter_; @@ -455,11 +538,13 @@ void AdaptiveFirFilter::UpdateSize() { current_size_partitions_ = average(old_target_size_partitions_, target_size_partitions_, change_factor); - ResetFilterBuffersToCurrentSize(); + partition_to_constrain_ = + std::min(partition_to_constrain_, current_size_partitions_ - 1); } else { current_size_partitions_ = old_target_size_partitions_ = target_size_partitions_; } + ZeroFilter(old_size_partitions_, current_size_partitions_, &H_); RTC_DCHECK_LE(0, size_change_counter_); } @@ -469,16 +554,16 @@ void AdaptiveFirFilter::Filter(const RenderBuffer& render_buffer, switch (optimization_) { #if defined(WEBRTC_ARCH_X86_FAMILY) case Aec3Optimization::kSse2: - aec3::ApplyFilter_SSE2(render_buffer, H_, S); + 
aec3::ApplyFilter_Sse2(render_buffer, current_size_partitions_, H_, S); break; #endif #if defined(WEBRTC_HAS_NEON) case Aec3Optimization::kNeon: - aec3::ApplyFilter_NEON(render_buffer, H_, S); + aec3::ApplyFilter_Neon(render_buffer, current_size_partitions_, H_, S); break; #endif default: - aec3::ApplyFilter(render_buffer, H_, S); + aec3::ApplyFilter(render_buffer, current_size_partitions_, H_, S); } } @@ -503,28 +588,23 @@ void AdaptiveFirFilter::Adapt(const RenderBuffer& render_buffer, void AdaptiveFirFilter::ComputeFrequencyResponse( std::vector>* H2) const { - RTC_DCHECK_EQ(max_size_partitions_, H2->capacity()); + RTC_DCHECK_GE(max_size_partitions_, H2->capacity()); - if (H2->size() > H_.size()) { - for (size_t k = H_.size(); k < H2->size(); ++k) { - (*H2)[k].fill(0.f); - } - } - H2->resize(H_.size()); + H2->resize(current_size_partitions_); switch (optimization_) { #if defined(WEBRTC_ARCH_X86_FAMILY) case Aec3Optimization::kSse2: - aec3::UpdateFrequencyResponse_SSE2(H_, H2); + aec3::ComputeFrequencyResponse_Sse2(current_size_partitions_, H_, H2); break; #endif #if defined(WEBRTC_HAS_NEON) case Aec3Optimization::kNeon: - aec3::UpdateFrequencyResponse_NEON(H_, H2); + aec3::ComputeFrequencyResponse_Neon(current_size_partitions_, H_, H2); break; #endif default: - aec3::UpdateFrequencyResponse(H_, H2); + aec3::ComputeFrequencyResponse(current_size_partitions_, H_, H2); } } @@ -537,16 +617,18 @@ void AdaptiveFirFilter::AdaptAndUpdateSize(const RenderBuffer& render_buffer, switch (optimization_) { #if defined(WEBRTC_ARCH_X86_FAMILY) case Aec3Optimization::kSse2: - aec3::AdaptPartitions_SSE2(render_buffer, G, H_); + aec3::AdaptPartitions_Sse2(render_buffer, G, current_size_partitions_, + &H_); break; #endif #if defined(WEBRTC_HAS_NEON) case Aec3Optimization::kNeon: - aec3::AdaptPartitions_NEON(render_buffer, G, H_); + aec3::AdaptPartitions_Neon(render_buffer, G, current_size_partitions_, + &H_); break; #endif default: - aec3::AdaptPartitions(render_buffer, G, H_); + 
aec3::AdaptPartitions(render_buffer, G, current_size_partitions_, &H_); } } @@ -557,62 +639,91 @@ void AdaptiveFirFilter::ConstrainAndUpdateImpulseResponse( std::vector* impulse_response) { RTC_DCHECK_EQ(GetTimeDomainLength(max_size_partitions_), impulse_response->capacity()); - impulse_response->resize(GetTimeDomainLength(current_size_partitions_)); std::array h; - fft_.Ifft(H_[partition_to_constrain_], &h); + impulse_response->resize(GetTimeDomainLength(current_size_partitions_)); + std::fill( + impulse_response->begin() + partition_to_constrain_ * kFftLengthBy2, + impulse_response->begin() + (partition_to_constrain_ + 1) * kFftLengthBy2, + 0.f); - static constexpr float kScale = 1.0f / kFftLengthBy2; - std::for_each(h.begin(), h.begin() + kFftLengthBy2, - [](float& a) { a *= kScale; }); - std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f); + for (size_t ch = 0; ch < num_render_channels_; ++ch) { + fft_.Ifft(H_[partition_to_constrain_][ch], &h); - std::copy( - h.begin(), h.begin() + kFftLengthBy2, - impulse_response->begin() + partition_to_constrain_ * kFftLengthBy2); + static constexpr float kScale = 1.0f / kFftLengthBy2; + std::for_each(h.begin(), h.begin() + kFftLengthBy2, + [](float& a) { a *= kScale; }); + std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f); - fft_.Fft(&h, &H_[partition_to_constrain_]); + if (ch == 0) { + std::copy( + h.begin(), h.begin() + kFftLengthBy2, + impulse_response->begin() + partition_to_constrain_ * kFftLengthBy2); + } else { + for (size_t k = 0, j = partition_to_constrain_ * kFftLengthBy2; + k < kFftLengthBy2; ++k, ++j) { + if (fabsf((*impulse_response)[j]) < fabsf(h[k])) { + (*impulse_response)[j] = h[k]; + } + } + } - partition_to_constrain_ = partition_to_constrain_ < (H_.size() - 1) - ? partition_to_constrain_ + 1 - : 0; + fft_.Fft(&h, &H_[partition_to_constrain_][ch]); + } + + partition_to_constrain_ = + partition_to_constrain_ < (current_size_partitions_ - 1) + ? 
partition_to_constrain_ + 1 + : 0; } // Constrains the a partiton of the frequency domain filter to be limited in // time via setting the relevant time-domain coefficients to zero. void AdaptiveFirFilter::Constrain() { std::array h; - fft_.Ifft(H_[partition_to_constrain_], &h); + for (size_t ch = 0; ch < num_render_channels_; ++ch) { + fft_.Ifft(H_[partition_to_constrain_][ch], &h); - static constexpr float kScale = 1.0f / kFftLengthBy2; - std::for_each(h.begin(), h.begin() + kFftLengthBy2, - [](float& a) { a *= kScale; }); - std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f); + static constexpr float kScale = 1.0f / kFftLengthBy2; + std::for_each(h.begin(), h.begin() + kFftLengthBy2, + [](float& a) { a *= kScale; }); + std::fill(h.begin() + kFftLengthBy2, h.end(), 0.f); - fft_.Fft(&h, &H_[partition_to_constrain_]); + fft_.Fft(&h, &H_[partition_to_constrain_][ch]); + } - partition_to_constrain_ = partition_to_constrain_ < (H_.size() - 1) - ? partition_to_constrain_ + 1 - : 0; + partition_to_constrain_ = + partition_to_constrain_ < (current_size_partitions_ - 1) + ? partition_to_constrain_ + 1 + : 0; } void AdaptiveFirFilter::ScaleFilter(float factor) { - for (auto& H : H_) { - for (auto& re : H.re) { - re *= factor; - } - for (auto& im : H.im) { - im *= factor; + for (auto& H_p : H_) { + for (auto& H_p_ch : H_p) { + for (auto& re : H_p_ch.re) { + re *= factor; + } + for (auto& im : H_p_ch.im) { + im *= factor; + } } } } // Set the filter coefficients. 
-void AdaptiveFirFilter::SetFilter(const std::vector& H) { - const size_t num_partitions = std::min(H_.size(), H.size()); - for (size_t k = 0; k < num_partitions; ++k) { - std::copy(H[k].re.begin(), H[k].re.end(), H_[k].re.begin()); - std::copy(H[k].im.begin(), H[k].im.end(), H_[k].im.begin()); +void AdaptiveFirFilter::SetFilter(size_t num_partitions, + const std::vector>& H) { + const size_t min_num_partitions = + std::min(current_size_partitions_, num_partitions); + for (size_t p = 0; p < min_num_partitions; ++p) { + RTC_DCHECK_EQ(H_[p].size(), H[p].size()); + RTC_DCHECK_EQ(num_render_channels_, H_[p].size()); + + for (size_t ch = 0; ch < num_render_channels_; ++ch) { + std::copy(H[p][ch].re.begin(), H[p][ch].re.end(), H_[p][ch].re.begin()); + std::copy(H[p][ch].im.begin(), H[p][ch].im.end(), H_[p][ch].im.begin()); + } } } diff --git a/modules/audio_processing/aec3/adaptive_fir_filter.h b/modules/audio_processing/aec3/adaptive_fir_filter.h index aec83aabd4..2f6485340f 100644 --- a/modules/audio_processing/aec3/adaptive_fir_filter.h +++ b/modules/audio_processing/aec3/adaptive_fir_filter.h @@ -27,47 +27,56 @@ namespace webrtc { namespace aec3 { // Computes and stores the frequency response of the filter. -void UpdateFrequencyResponse( - rtc::ArrayView H, +void ComputeFrequencyResponse( + size_t num_partitions, + const std::vector>& H, std::vector>* H2); #if defined(WEBRTC_HAS_NEON) -void UpdateFrequencyResponse_NEON( - rtc::ArrayView H, +void ComputeFrequencyResponse_Neon( + size_t num_partitions, + const std::vector>& H, std::vector>* H2); #endif #if defined(WEBRTC_ARCH_X86_FAMILY) -void UpdateFrequencyResponse_SSE2( - rtc::ArrayView H, +void ComputeFrequencyResponse_Sse2( + size_t num_partitions, + const std::vector>& H, std::vector>* H2); #endif // Adapts the filter partitions. 
void AdaptPartitions(const RenderBuffer& render_buffer, const FftData& G, - rtc::ArrayView H); + size_t num_partitions, + std::vector>* H); #if defined(WEBRTC_HAS_NEON) -void AdaptPartitions_NEON(const RenderBuffer& render_buffer, +void AdaptPartitions_Neon(const RenderBuffer& render_buffer, const FftData& G, - rtc::ArrayView H); + size_t num_partitions, + std::vector>* H); #endif #if defined(WEBRTC_ARCH_X86_FAMILY) -void AdaptPartitions_SSE2(const RenderBuffer& render_buffer, +void AdaptPartitions_Sse2(const RenderBuffer& render_buffer, const FftData& G, - rtc::ArrayView H); + size_t num_partitions, + std::vector>* H); #endif // Produces the filter output. void ApplyFilter(const RenderBuffer& render_buffer, - rtc::ArrayView H, + size_t num_partitions, + const std::vector>& H, FftData* S); #if defined(WEBRTC_HAS_NEON) -void ApplyFilter_NEON(const RenderBuffer& render_buffer, - rtc::ArrayView H, +void ApplyFilter_Neon(const RenderBuffer& render_buffer, + size_t num_partitions, + const std::vector>& H, FftData* S); #endif #if defined(WEBRTC_ARCH_X86_FAMILY) -void ApplyFilter_SSE2(const RenderBuffer& render_buffer, - rtc::ArrayView H, +void ApplyFilter_Sse2(const RenderBuffer& render_buffer, + size_t num_partitions, + const std::vector>& H, FftData* S); #endif @@ -80,7 +89,6 @@ class AdaptiveFirFilter { size_t initial_size_partitions, size_t size_change_duration_blocks, size_t num_render_channels, - size_t num_capture_channels, Aec3Optimization optimization, ApmDataDumper* data_dumper); @@ -106,7 +114,7 @@ class AdaptiveFirFilter { void HandleEchoPathChange(); // Returns the filter size. - size_t SizePartitions() const { return H_.size(); } + size_t SizePartitions() const { return current_size_partitions_; } // Sets the filter size. 
void SetSizePartitions(size_t size, bool immediate_effect); @@ -119,23 +127,21 @@ class AdaptiveFirFilter { size_t max_filter_size_partitions() const { return max_size_partitions_; } void DumpFilter(const char* name_frequency_domain) { - size_t current_size = H_.size(); - H_.resize(max_size_partitions_); - for (auto& H : H_) { - data_dumper_->DumpRaw(name_frequency_domain, H.re); - data_dumper_->DumpRaw(name_frequency_domain, H.im); + for (size_t p = 0; p < max_size_partitions_; ++p) { + data_dumper_->DumpRaw(name_frequency_domain, H_[p][0].re); + data_dumper_->DumpRaw(name_frequency_domain, H_[p][0].im); } - H_.resize(current_size); } // Scale the filter impulse response and spectrum by a factor. void ScaleFilter(float factor); // Set the filter coefficients. - void SetFilter(const std::vector& H); + void SetFilter(size_t num_partitions, + const std::vector>& H); // Gets the filter coefficients. - const std::vector& GetFilter() const { return H_; } + const std::vector>& GetFilter() const { return H_; } private: // Adapts the filter and updates the filter size. @@ -147,15 +153,13 @@ class AdaptiveFirFilter { // values in the supplied impulse response. void ConstrainAndUpdateImpulseResponse(std::vector* impulse_response); - // Resets the filter buffers to use the current size. - void ResetFilterBuffersToCurrentSize(); - // Gradually Updates the current filter size towards the target size. 
void UpdateSize(); ApmDataDumper* const data_dumper_; const Aec3Fft fft_; const Aec3Optimization optimization_; + const size_t num_render_channels_; const size_t max_size_partitions_; const int size_change_duration_blocks_; float one_by_size_change_duration_blocks_; @@ -163,7 +167,7 @@ class AdaptiveFirFilter { size_t target_size_partitions_; size_t old_target_size_partitions_; int size_change_counter_ = 0; - std::vector H_; + std::vector> H_; size_t partition_to_constrain_ = 0; }; diff --git a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc index 36e31ebe73..6f1635fa60 100644 --- a/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc +++ b/modules/audio_processing/aec3/adaptive_fir_filter_unittest.cc @@ -42,9 +42,10 @@ namespace webrtc { namespace aec3 { namespace { -std::string ProduceDebugText(size_t delay) { +std::string ProduceDebugText(size_t num_render_channels, size_t delay) { rtc::StringBuilder ss; - ss << ", Delay: " << delay; + ss << "delay: " << delay << ", "; + ss << "num_render_channels:" << num_render_channels; return ss.Release(); } @@ -54,163 +55,184 @@ std::string ProduceDebugText(size_t delay) { // Verifies that the optimized methods for filter adaptation are similar to // their reference counterparts. 
TEST(AdaptiveFirFilter, FilterAdaptationNeonOptimizations) { - constexpr size_t kNumRenderChannels = 1; - constexpr int kSampleRateHz = 48000; - constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + for (size_t num_partitions : {2, 5, 12, 30, 50}) { + for (size_t num_render_channels : {1, 2, 4, 8}) { + constexpr int kSampleRateHz = 48000; + constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); - std::unique_ptr render_delay_buffer( - RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz, - kNumRenderChannels)); - Random random_generator(42U); - std::vector>> x( - kNumBands, std::vector>( - kNumRenderChannels, std::vector(kBlockSize, 0.f))); - FftData S_C; - FftData S_NEON; - FftData G; - Aec3Fft fft; - std::vector H_C(10); - std::vector H_NEON(10); - for (auto& H_j : H_C) { - H_j.Clear(); - } - for (auto& H_j : H_NEON) { - H_j.Clear(); - } + std::unique_ptr render_delay_buffer( + RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz, + num_render_channels)); + Random random_generator(42U); + std::vector>> x( + kNumBands, + std::vector>(num_render_channels, + std::vector(kBlockSize, 0.f))); + FftData S_C; + FftData S_Neon; + FftData G; + Aec3Fft fft; + std::vector> H_C( + num_partitions, std::vector(num_render_channels)); + std::vector> H_Neon( + num_partitions, std::vector(num_render_channels)); + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + H_C[p][ch].Clear(); + H_Neon[p][ch].Clear(); + } + } - for (size_t k = 0; k < 30; ++k) { - for (size_t band = 0; band < x.size(); ++band) { - for (size_t channel = 0; channel < x[band].size(); ++channel) { - RandomizeSampleVector(&random_generator, x[band][channel]); + for (size_t k = 0; k < 30; ++k) { + for (size_t band = 0; band < x.size(); ++band) { + for (size_t ch = 0; ch < x[band].size(); ++ch) { + RandomizeSampleVector(&random_generator, x[band][ch]); + } + } + render_delay_buffer->Insert(x); + if (k == 0) { + 
render_delay_buffer->Reset(); + } + render_delay_buffer->PrepareCaptureProcessing(); + } + auto* const render_buffer = render_delay_buffer->GetRenderBuffer(); + + for (size_t j = 0; j < G.re.size(); ++j) { + G.re[j] = j / 10001.f; + } + for (size_t j = 1; j < G.im.size() - 1; ++j) { + G.im[j] = j / 20001.f; + } + G.im[0] = 0.f; + G.im[G.im.size() - 1] = 0.f; + + AdaptPartitions_Neon(*render_buffer, G, num_partitions, &H_Neon); + AdaptPartitions(*render_buffer, G, num_partitions, &H_C); + AdaptPartitions_Neon(*render_buffer, G, num_partitions, &H_Neon); + AdaptPartitions(*render_buffer, G, num_partitions, &H_C); + + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t j = 0; j < H_C[p][ch].re.size(); ++j) { + EXPECT_FLOAT_EQ(H_C[p][ch].re[j], H_Neon[p][ch].re[j]); + EXPECT_FLOAT_EQ(H_C[p][ch].im[j], H_Neon[p][ch].im[j]); + } + } + } + + ApplyFilter_Neon(*render_buffer, num_partitions, H_Neon, &S_Neon); + ApplyFilter(*render_buffer, num_partitions, H_C, &S_C); + for (size_t j = 0; j < S_C.re.size(); ++j) { + EXPECT_NEAR(S_C.re[j], S_Neon.re[j], fabs(S_C.re[j] * 0.00001f)); + EXPECT_NEAR(S_C.im[j], S_Neon.im[j], fabs(S_C.re[j] * 0.00001f)); } } - render_delay_buffer->Insert(x); - if (k == 0) { - render_delay_buffer->Reset(); - } - render_delay_buffer->PrepareCaptureProcessing(); - } - auto* const render_buffer = render_delay_buffer->GetRenderBuffer(); - - for (size_t j = 0; j < G.re.size(); ++j) { - G.re[j] = j / 10001.f; - } - for (size_t j = 1; j < G.im.size() - 1; ++j) { - G.im[j] = j / 20001.f; - } - G.im[0] = 0.f; - G.im[G.im.size() - 1] = 0.f; - - AdaptPartitions_NEON(*render_buffer, G, H_NEON); - AdaptPartitions(*render_buffer, G, H_C); - AdaptPartitions_NEON(*render_buffer, G, H_NEON); - AdaptPartitions(*render_buffer, G, H_C); - - for (size_t l = 0; l < H_C.size(); ++l) { - for (size_t j = 0; j < H_C[l].im.size(); ++j) { - EXPECT_NEAR(H_C[l].re[j], H_NEON[l].re[j], fabs(H_C[l].re[j] * 0.00001f)); 
- EXPECT_NEAR(H_C[l].im[j], H_NEON[l].im[j], fabs(H_C[l].im[j] * 0.00001f)); - } - } - - ApplyFilter_NEON(*render_buffer, H_NEON, &S_NEON); - ApplyFilter(*render_buffer, H_C, &S_C); - for (size_t j = 0; j < S_C.re.size(); ++j) { - EXPECT_NEAR(S_C.re[j], S_NEON.re[j], fabs(S_C.re[j] * 0.00001f)); - EXPECT_NEAR(S_C.im[j], S_NEON.im[j], fabs(S_C.re[j] * 0.00001f)); } } // Verifies that the optimized method for frequency response computation is // bitexact to the reference counterpart. -TEST(AdaptiveFirFilter, UpdateFrequencyResponseNeonOptimization) { - const size_t kNumPartitions = 12; - std::vector H(kNumPartitions); - std::vector> H2(kNumPartitions); - std::vector> H2_NEON(kNumPartitions); +TEST(AdaptiveFirFilter, ComputeFrequencyResponseNeonOptimization) { + for (size_t num_partitions : {2, 5, 12, 30, 50}) { + for (size_t num_render_channels : {1, 2, 4, 8}) { + std::vector> H( + num_partitions, std::vector(num_render_channels)); + std::vector> H2(num_partitions); + std::vector> H2_Neon( + num_partitions); - for (size_t j = 0; j < H.size(); ++j) { - for (size_t k = 0; k < H[j].re.size(); ++k) { - H[j].re[k] = k + j / 3.f; - H[j].im[k] = j + k / 7.f; - } - } + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t k = 0; k < H[p][ch].re.size(); ++k) { + H[p][ch].re[k] = k + p / 3.f + ch; + H[p][ch].im[k] = p + k / 7.f - ch; + } + } + } - UpdateFrequencyResponse(H, &H2); - UpdateFrequencyResponse_NEON(H, &H2_NEON); + ComputeFrequencyResponse(num_partitions, H, &H2); + ComputeFrequencyResponse_Neon(num_partitions, H, &H2_Neon); - for (size_t j = 0; j < H2.size(); ++j) { - for (size_t k = 0; k < H[j].re.size(); ++k) { - EXPECT_FLOAT_EQ(H2[j][k], H2_NEON[j][k]); + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t k = 0; k < H2[p].size(); ++k) { + EXPECT_FLOAT_EQ(H2[p][k], H2_Neon[p][k]); + } + } } } } - #endif #if defined(WEBRTC_ARCH_X86_FAMILY) // Verifies that the optimized methods for filter 
adaptation are bitexact to // their reference counterparts. TEST(AdaptiveFirFilter, FilterAdaptationSse2Optimizations) { - constexpr size_t kNumRenderChannels = 1; constexpr int kSampleRateHz = 48000; constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0); if (use_sse2) { - std::unique_ptr render_delay_buffer( - RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz, - kNumRenderChannels)); - Random random_generator(42U); - std::vector>> x( - kNumBands, - std::vector>(kNumRenderChannels, - std::vector(kBlockSize, 0.f))); - FftData S_C; - FftData S_SSE2; - FftData G; - Aec3Fft fft; - std::vector H_C(10); - std::vector H_SSE2(10); - for (auto& H_j : H_C) { - H_j.Clear(); - } - for (auto& H_j : H_SSE2) { - H_j.Clear(); - } - - for (size_t k = 0; k < 500; ++k) { - for (size_t band = 0; band < x.size(); ++band) { - for (size_t channel = 0; channel < x[band].size(); ++channel) { - RandomizeSampleVector(&random_generator, x[band][channel]); + for (size_t num_partitions : {2, 5, 12, 30, 50}) { + for (size_t num_render_channels : {1, 2, 4, 8}) { + std::unique_ptr render_delay_buffer( + RenderDelayBuffer::Create(EchoCanceller3Config(), kSampleRateHz, + num_render_channels)); + Random random_generator(42U); + std::vector>> x( + kNumBands, + std::vector>( + num_render_channels, std::vector(kBlockSize, 0.f))); + FftData S_C; + FftData S_Sse2; + FftData G; + Aec3Fft fft; + std::vector> H_C( + num_partitions, std::vector(num_render_channels)); + std::vector> H_Sse2( + num_partitions, std::vector(num_render_channels)); + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + H_C[p][ch].Clear(); + H_Sse2[p][ch].Clear(); + } } - } - render_delay_buffer->Insert(x); - if (k == 0) { - render_delay_buffer->Reset(); - } - render_delay_buffer->PrepareCaptureProcessing(); - auto* const render_buffer = render_delay_buffer->GetRenderBuffer(); - ApplyFilter_SSE2(*render_buffer, 
H_SSE2, &S_SSE2); - ApplyFilter(*render_buffer, H_C, &S_C); - for (size_t j = 0; j < S_C.re.size(); ++j) { - EXPECT_FLOAT_EQ(S_C.re[j], S_SSE2.re[j]); - EXPECT_FLOAT_EQ(S_C.im[j], S_SSE2.im[j]); - } + for (size_t k = 0; k < 500; ++k) { + for (size_t band = 0; band < x.size(); ++band) { + for (size_t ch = 0; ch < x[band].size(); ++ch) { + RandomizeSampleVector(&random_generator, x[band][ch]); + } + } + render_delay_buffer->Insert(x); + if (k == 0) { + render_delay_buffer->Reset(); + } + render_delay_buffer->PrepareCaptureProcessing(); + auto* const render_buffer = render_delay_buffer->GetRenderBuffer(); - std::for_each(G.re.begin(), G.re.end(), - [&](float& a) { a = random_generator.Rand(); }); - std::for_each(G.im.begin(), G.im.end(), - [&](float& a) { a = random_generator.Rand(); }); + ApplyFilter_Sse2(*render_buffer, num_partitions, H_Sse2, &S_Sse2); + ApplyFilter(*render_buffer, num_partitions, H_C, &S_C); + for (size_t j = 0; j < S_C.re.size(); ++j) { + EXPECT_FLOAT_EQ(S_C.re[j], S_Sse2.re[j]); + EXPECT_FLOAT_EQ(S_C.im[j], S_Sse2.im[j]); + } - AdaptPartitions_SSE2(*render_buffer, G, H_SSE2); - AdaptPartitions(*render_buffer, G, H_C); + std::for_each(G.re.begin(), G.re.end(), + [&](float& a) { a = random_generator.Rand(); }); + std::for_each(G.im.begin(), G.im.end(), + [&](float& a) { a = random_generator.Rand(); }); - for (size_t k = 0; k < H_C.size(); ++k) { - for (size_t j = 0; j < H_C[k].re.size(); ++j) { - EXPECT_FLOAT_EQ(H_C[k].re[j], H_SSE2[k].re[j]); - EXPECT_FLOAT_EQ(H_C[k].im[j], H_SSE2[k].im[j]); + AdaptPartitions_Sse2(*render_buffer, G, num_partitions, &H_Sse2); + AdaptPartitions(*render_buffer, G, num_partitions, &H_C); + + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t j = 0; j < H_C[p][ch].re.size(); ++j) { + EXPECT_FLOAT_EQ(H_C[p][ch].re[j], H_Sse2[p][ch].re[j]); + EXPECT_FLOAT_EQ(H_C[p][ch].im[j], H_Sse2[p][ch].im[j]); + } + } + } } } } @@ -219,27 +241,34 @@ 
TEST(AdaptiveFirFilter, FilterAdaptationSse2Optimizations) { // Verifies that the optimized method for frequency response computation is // bitexact to the reference counterpart. -TEST(AdaptiveFirFilter, UpdateFrequencyResponseSse2Optimization) { +TEST(AdaptiveFirFilter, ComputeFrequencyResponseSse2Optimization) { bool use_sse2 = (WebRtc_GetCPUInfo(kSSE2) != 0); if (use_sse2) { - const size_t kNumPartitions = 12; - std::vector H(kNumPartitions); - std::vector> H2(kNumPartitions); - std::vector> H2_SSE2(kNumPartitions); + for (size_t num_partitions : {2, 5, 12, 30, 50}) { + for (size_t num_render_channels : {1, 2, 4, 8}) { + std::vector> H( + num_partitions, std::vector(num_render_channels)); + std::vector> H2(num_partitions); + std::vector> H2_Sse2( + num_partitions); - for (size_t j = 0; j < H.size(); ++j) { - for (size_t k = 0; k < H[j].re.size(); ++k) { - H[j].re[k] = k + j / 3.f; - H[j].im[k] = j + k / 7.f; - } - } + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t ch = 0; ch < num_render_channels; ++ch) { + for (size_t k = 0; k < H[p][ch].re.size(); ++k) { + H[p][ch].re[k] = k + p / 3.f + ch; + H[p][ch].im[k] = p + k / 7.f - ch; + } + } + } - UpdateFrequencyResponse(H, &H2); - UpdateFrequencyResponse_SSE2(H, &H2_SSE2); + ComputeFrequencyResponse(num_partitions, H, &H2); + ComputeFrequencyResponse_Sse2(num_partitions, H, &H2_Sse2); - for (size_t j = 0; j < H2.size(); ++j) { - for (size_t k = 0; k < H[j].re.size(); ++k) { - EXPECT_FLOAT_EQ(H2[j][k], H2_SSE2[j][k]); + for (size_t p = 0; p < num_partitions; ++p) { + for (size_t k = 0; k < H2[p].size(); ++k) { + EXPECT_FLOAT_EQ(H2[p][k], H2_Sse2[p][k]); + } + } } } } @@ -250,14 +279,14 @@ TEST(AdaptiveFirFilter, UpdateFrequencyResponseSse2Optimization) { #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) // Verifies that the check for non-null data dumper works. 
TEST(AdaptiveFirFilter, NullDataDumper) { - EXPECT_DEATH( - AdaptiveFirFilter(9, 9, 250, 1, 1, DetectOptimization(), nullptr), ""); + EXPECT_DEATH(AdaptiveFirFilter(9, 9, 250, 1, DetectOptimization(), nullptr), + ""); } // Verifies that the check for non-null filter output works. TEST(AdaptiveFirFilter, NullFilterOutput) { ApmDataDumper data_dumper(42); - AdaptiveFirFilter filter(9, 9, 250, 1, 1, DetectOptimization(), &data_dumper); + AdaptiveFirFilter filter(9, 9, 250, 1, DetectOptimization(), &data_dumper); std::unique_ptr render_delay_buffer( RenderDelayBuffer::Create(EchoCanceller3Config(), 48000, 1)); EXPECT_DEATH(filter.Filter(*render_delay_buffer->GetRenderBuffer(), nullptr), @@ -271,7 +300,7 @@ TEST(AdaptiveFirFilter, NullFilterOutput) { TEST(AdaptiveFirFilter, FilterStatisticsAccess) { ApmDataDumper data_dumper(42); Aec3Optimization optimization = DetectOptimization(); - AdaptiveFirFilter filter(9, 9, 250, 1, 1, optimization, &data_dumper); + AdaptiveFirFilter filter(9, 9, 250, 1, optimization, &data_dumper); std::vector> H2( filter.max_filter_size_partitions(), std::array()); @@ -288,7 +317,7 @@ TEST(AdaptiveFirFilter, FilterStatisticsAccess) { TEST(AdaptiveFirFilter, FilterSize) { ApmDataDumper data_dumper(42); for (size_t filter_size = 1; filter_size < 5; ++filter_size) { - AdaptiveFirFilter filter(filter_size, filter_size, 250, 1, 1, + AdaptiveFirFilter filter(filter_size, filter_size, 250, 1, DetectOptimization(), &data_dumper); EXPECT_EQ(filter_size, filter.SizePartitions()); } @@ -297,115 +326,146 @@ TEST(AdaptiveFirFilter, FilterSize) { // Verifies that the filter is being able to properly filter a signal and to // adapt its coefficients. 
TEST(AdaptiveFirFilter, FilterAndAdapt) { - constexpr size_t kNumRenderChannels = 1; - constexpr size_t kNumCaptureChannels = 1; constexpr int kSampleRateHz = 48000; constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); + constexpr size_t kNumBlocksToProcessPerRenderChannel = 1000; + constexpr size_t kNumCaptureChannels = 1; - constexpr size_t kNumBlocksToProcess = 1000; - ApmDataDumper data_dumper(42); - EchoCanceller3Config config; - AdaptiveFirFilter filter(config.filter.main.length_blocks, - config.filter.main.length_blocks, - config.filter.config_change_duration_blocks, 1, 1, - DetectOptimization(), &data_dumper); - std::vector> H2( - filter.max_filter_size_partitions(), - std::array()); - std::vector h(GetTimeDomainLength(filter.max_filter_size_partitions()), - 0.f); - Aec3Fft fft; - config.delay.default_delay = 1; - std::unique_ptr render_delay_buffer( - RenderDelayBuffer::Create(config, kSampleRateHz, kNumRenderChannels)); - ShadowFilterUpdateGain gain(config.filter.shadow, - config.filter.config_change_duration_blocks); - Random random_generator(42U); - std::vector>> x( - kNumBands, std::vector>( - kNumRenderChannels, std::vector(kBlockSize, 0.f))); - std::vector n(kBlockSize, 0.f); - std::vector y(kBlockSize, 0.f); - AecState aec_state(EchoCanceller3Config{}, kNumCaptureChannels); - RenderSignalAnalyzer render_signal_analyzer(config); - absl::optional delay_estimate; - std::vector e(kBlockSize, 0.f); - std::array s_scratch; - std::vector output(kNumCaptureChannels); - FftData S; - FftData G; - FftData E; - std::array Y2; - std::array E2_main; - std::array E2_shadow; - // [B,A] = butter(2,100/8000,'high') - constexpr CascadedBiQuadFilter::BiQuadCoefficients - kHighPassFilterCoefficients = {{0.97261f, -1.94523f, 0.97261f}, - {-1.94448f, 0.94598f}}; - Y2.fill(0.f); - E2_main.fill(0.f); - E2_shadow.fill(0.f); - for (auto& subtractor_output : output) { - subtractor_output.Reset(); - } + for (size_t num_render_channels : {1, 2, 3, 6, 8}) { + 
ApmDataDumper data_dumper(42); + EchoCanceller3Config config; - constexpr float kScale = 1.0f / kFftLengthBy2; - - for (size_t delay_samples : {0, 64, 150, 200, 301}) { - DelayBuffer delay_buffer(delay_samples); - CascadedBiQuadFilter x_hp_filter(kHighPassFilterCoefficients, 1); - CascadedBiQuadFilter y_hp_filter(kHighPassFilterCoefficients, 1); - - SCOPED_TRACE(ProduceDebugText(delay_samples)); - for (size_t j = 0; j < kNumBlocksToProcess; ++j) { - RandomizeSampleVector(&random_generator, x[0][0]); - delay_buffer.Delay(x[0][0], y); - - RandomizeSampleVector(&random_generator, n); - static constexpr float kNoiseScaling = 1.f / 100.f; - std::transform(y.begin(), y.end(), n.begin(), y.begin(), - [](float a, float b) { return a + b * kNoiseScaling; }); - - x_hp_filter.Process(x[0][0]); - y_hp_filter.Process(y); - - render_delay_buffer->Insert(x); - if (j == 0) { - render_delay_buffer->Reset(); - } - render_delay_buffer->PrepareCaptureProcessing(); - auto* const render_buffer = render_delay_buffer->GetRenderBuffer(); - - render_signal_analyzer.Update(*render_buffer, - aec_state.FilterDelayBlocks()); - - filter.Filter(*render_buffer, &S); - fft.Ifft(S, &s_scratch); - std::transform(y.begin(), y.end(), s_scratch.begin() + kFftLengthBy2, - e.begin(), - [&](float a, float b) { return a - b * kScale; }); - std::for_each(e.begin(), e.end(), - [](float& a) { a = rtc::SafeClamp(a, -32768.f, 32767.f); }); - fft.ZeroPaddedFft(e, Aec3Fft::Window::kRectangular, &E); - for (size_t k = 0; k < kBlockSize; ++k) { - output[0].s_main[k] = kScale * s_scratch[k + kFftLengthBy2]; - } - - std::array render_power; - render_buffer->SpectralSum(filter.SizePartitions(), &render_power); - gain.Compute(render_power, render_signal_analyzer, E, - filter.SizePartitions(), false, &G); - filter.Adapt(*render_buffer, G, &h); - aec_state.HandleEchoPathChange(EchoPathVariability( - false, EchoPathVariability::DelayAdjustment::kNone, false)); - - filter.ComputeFrequencyResponse(&H2); - 
aec_state.Update(delay_estimate, H2, h, *render_buffer, E2_main, Y2, - output); + if (num_render_channels == 33) { + config.filter.main = {13, 0.00005f, 0.0005f, 0.0001f, 2.f, 20075344.f}; + config.filter.shadow = {13, 0.1f, 20075344.f}; + config.filter.main_initial = {12, 0.005f, 0.5f, 0.001f, 2.f, 20075344.f}; + config.filter.shadow_initial = {12, 0.7f, 20075344.f}; + } + + AdaptiveFirFilter filter( + config.filter.main.length_blocks, config.filter.main.length_blocks, + config.filter.config_change_duration_blocks, num_render_channels, + DetectOptimization(), &data_dumper); + std::vector> H2( + filter.max_filter_size_partitions(), + std::array()); + std::vector h( + GetTimeDomainLength(filter.max_filter_size_partitions()), 0.f); + Aec3Fft fft; + config.delay.default_delay = 1; + std::unique_ptr render_delay_buffer( + RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels)); + ShadowFilterUpdateGain gain(config.filter.shadow, + config.filter.config_change_duration_blocks); + Random random_generator(42U); + std::vector>> x( + kNumBands, + std::vector>(num_render_channels, + std::vector(kBlockSize, 0.f))); + std::vector n(kBlockSize, 0.f); + std::vector y(kBlockSize, 0.f); + AecState aec_state(EchoCanceller3Config{}, kNumCaptureChannels); + RenderSignalAnalyzer render_signal_analyzer(config); + absl::optional delay_estimate; + std::vector e(kBlockSize, 0.f); + std::array s_scratch; + std::vector output(kNumCaptureChannels); + FftData S; + FftData G; + FftData E; + std::array Y2; + std::array E2_main; + std::array E2_shadow; + // [B,A] = butter(2,100/8000,'high') + constexpr CascadedBiQuadFilter::BiQuadCoefficients + kHighPassFilterCoefficients = {{0.97261f, -1.94523f, 0.97261f}, + {-1.94448f, 0.94598f}}; + Y2.fill(0.f); + E2_main.fill(0.f); + E2_shadow.fill(0.f); + for (auto& subtractor_output : output) { + subtractor_output.Reset(); + } + + constexpr float kScale = 1.0f / kFftLengthBy2; + + for (size_t delay_samples : {0, 64, 150, 200, 301}) { + 
std::vector> delay_buffer( + num_render_channels, DelayBuffer(delay_samples)); + std::vector> x_hp_filter( + num_render_channels); + for (size_t ch = 0; ch < num_render_channels; ++ch) { + x_hp_filter[ch] = std::make_unique( + kHighPassFilterCoefficients, 1); + } + CascadedBiQuadFilter y_hp_filter(kHighPassFilterCoefficients, 1); + + SCOPED_TRACE(ProduceDebugText(num_render_channels, delay_samples)); + const size_t num_blocks_to_process = + kNumBlocksToProcessPerRenderChannel * num_render_channels; + for (size_t j = 0; j < num_blocks_to_process; ++j) { + std::fill(y.begin(), y.end(), 0.f); + for (size_t ch = 0; ch < num_render_channels; ++ch) { + RandomizeSampleVector(&random_generator, x[0][ch]); + std::array y_channel; + delay_buffer[ch].Delay(x[0][ch], y_channel); + for (size_t k = 0; k < y.size(); ++k) { + y[k] += y_channel[k] / num_render_channels; + } + } + + RandomizeSampleVector(&random_generator, n); + const float noise_scaling = 1.f / 100.f / num_render_channels; + for (size_t k = 0; k < y.size(); ++k) { + y[k] += n[k] * noise_scaling; + } + + for (size_t ch = 0; ch < num_render_channels; ++ch) { + x_hp_filter[ch]->Process(x[0][ch]); + } + y_hp_filter.Process(y); + + render_delay_buffer->Insert(x); + if (j == 0) { + render_delay_buffer->Reset(); + } + render_delay_buffer->PrepareCaptureProcessing(); + auto* const render_buffer = render_delay_buffer->GetRenderBuffer(); + + render_signal_analyzer.Update(*render_buffer, + aec_state.FilterDelayBlocks()); + + filter.Filter(*render_buffer, &S); + fft.Ifft(S, &s_scratch); + std::transform(y.begin(), y.end(), s_scratch.begin() + kFftLengthBy2, + e.begin(), + [&](float a, float b) { return a - b * kScale; }); + std::for_each(e.begin(), e.end(), [](float& a) { + a = rtc::SafeClamp(a, -32768.f, 32767.f); + }); + fft.ZeroPaddedFft(e, Aec3Fft::Window::kRectangular, &E); + for (auto& o : output) { + for (size_t k = 0; k < kBlockSize; ++k) { + o.s_main[k] = kScale * s_scratch[k + kFftLengthBy2]; + } + } + + std::array 
render_power; + render_buffer->SpectralSum(filter.SizePartitions(), &render_power); + gain.Compute(render_power, render_signal_analyzer, E, + filter.SizePartitions(), false, &G); + filter.Adapt(*render_buffer, G, &h); + aec_state.HandleEchoPathChange(EchoPathVariability( + false, EchoPathVariability::DelayAdjustment::kNone, false)); + + filter.ComputeFrequencyResponse(&H2); + aec_state.Update(delay_estimate, H2, h, *render_buffer, E2_main, Y2, + output); + } + // Verify that the filter is able to perform well. + EXPECT_LT(1000 * std::inner_product(e.begin(), e.end(), e.begin(), 0.f), + std::inner_product(y.begin(), y.end(), y.begin(), 0.f)); } - // Verify that the filter is able to perform well. - EXPECT_LT(1000 * std::inner_product(e.begin(), e.end(), e.begin(), 0.f), - std::inner_product(y.begin(), y.end(), y.begin(), 0.f)); } } } // namespace aec3 diff --git a/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc b/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc index 94aa039f78..7abbb794b7 100644 --- a/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc +++ b/modules/audio_processing/aec3/comfort_noise_generator_unittest.cc @@ -13,6 +13,7 @@ #include #include +#include "modules/audio_processing/aec3/aec_state.h" #include "rtc_base/random.h" #include "rtc_base/system/arch.h" #include "system_wrappers/include/cpu_features_wrapper.h" diff --git a/modules/audio_processing/aec3/echo_remover.cc b/modules/audio_processing/aec3/echo_remover.cc index 2df9cfda0c..c33b39c049 100644 --- a/modules/audio_processing/aec3/echo_remover.cc +++ b/modules/audio_processing/aec3/echo_remover.cc @@ -386,9 +386,9 @@ void EchoRemoverImpl::ProcessCapture( // Update the AEC state information. // TODO(bugs.webrtc.org/10913): Take all subtractors into account. 
- aec_state_.Update(external_delay, subtractor_.FilterFrequencyResponse(), - subtractor_.FilterImpulseResponse(), *render_buffer, E2[0], - Y2[0], subtractor_output); + aec_state_.Update(external_delay, subtractor_.FilterFrequencyResponse()[0], + subtractor_.FilterImpulseResponse()[0], *render_buffer, + E2[0], Y2[0], subtractor_output); // Choose the linear output. const auto& Y_fft = aec_state_.UseLinearFilterOutput() ? E : Y; diff --git a/modules/audio_processing/aec3/main_filter_update_gain.cc b/modules/audio_processing/aec3/main_filter_update_gain.cc index c2cfd2c447..43f37b0cf4 100644 --- a/modules/audio_processing/aec3/main_filter_update_gain.cc +++ b/modules/audio_processing/aec3/main_filter_update_gain.cc @@ -80,7 +80,7 @@ void MainFilterUpdateGain::Compute( const auto& E2_main = subtractor_output.E2_main; const auto& E2_shadow = subtractor_output.E2_shadow; FftData* G = gain_fft; - auto X2 = render_power; + const auto& X2 = render_power; ++call_counter_; @@ -100,43 +100,40 @@ void MainFilterUpdateGain::Compute( std::array mu; // mu = H_error / (0.5* H_error* X2 + n * E2). for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { - mu[k] = X2[k] > current_config_.noise_gate - ? H_error_[k] / (0.5f * H_error_[k] * X2[k] + - size_partitions * E2_main[k]) - : 0.f; + if (X2[k] >= current_config_.noise_gate) { + mu[k] = H_error_[k] / + (0.5f * H_error_[k] * X2[k] + size_partitions * E2_main[k]); + } else { + mu[k] = 0.f; + } } // Avoid updating the filter close to narrow bands in the render signals. render_signal_analyzer.MaskRegionsAroundNarrowBands(&mu); // H_error = H_error - 0.5 * mu * X2 * H_error. - for (size_t k = 0; k < H_error_.size(); ++k) { + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { H_error_[k] -= 0.5f * mu[k] * X2[k] * H_error_[k]; } // G = mu * E. 
- std::transform(mu.begin(), mu.end(), E_main.re.begin(), G->re.begin(), - std::multiplies()); - std::transform(mu.begin(), mu.end(), E_main.im.begin(), G->im.begin(), - std::multiplies()); + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + G->re[k] = mu[k] * E_main.re[k]; + G->im[k] = mu[k] * E_main.im[k]; + } } // H_error = H_error + factor * erl. - std::array H_error_increase; - std::transform(E2_shadow.begin(), E2_shadow.end(), E2_main.begin(), - H_error_increase.begin(), [&](float a, float b) { - return a >= b ? current_config_.leakage_converged - : current_config_.leakage_diverged; - }); - std::transform(erl.begin(), erl.end(), H_error_increase.begin(), - H_error_increase.begin(), std::multiplies()); - std::transform(H_error_.begin(), H_error_.end(), H_error_increase.begin(), - H_error_.begin(), [&](float a, float b) { - float error = a + b; - error = std::max(error, current_config_.error_floor); - error = std::min(error, current_config_.error_ceil); - return error; - }); + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + if (E2_shadow[k] >= E2_main[k]) { + H_error_[k] += current_config_.leakage_converged * erl[k]; + } else { + H_error_[k] += current_config_.leakage_diverged * erl[k]; + } + + H_error_[k] = std::max(H_error_[k], current_config_.error_floor); + H_error_[k] = std::min(H_error_[k], current_config_.error_ceil); + } data_dumper_->DumpRaw("aec3_main_gain_H_error", H_error_); } diff --git a/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc b/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc index 20714cea93..1a9e7929e7 100644 --- a/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc +++ b/modules/audio_processing/aec3/main_filter_update_gain_unittest.cc @@ -54,11 +54,11 @@ void RunFilterUpdateTest(int num_blocks_to_process, AdaptiveFirFilter main_filter(config.filter.main.length_blocks, config.filter.main.length_blocks, config.filter.config_change_duration_blocks, 1, - 1, optimization, 
&data_dumper); + optimization, &data_dumper); AdaptiveFirFilter shadow_filter(config.filter.shadow.length_blocks, config.filter.shadow.length_blocks, config.filter.config_change_duration_blocks, - 1, 1, optimization, &data_dumper); + 1, optimization, &data_dumper); std::vector> H2( main_filter.max_filter_size_partitions(), std::array()); diff --git a/modules/audio_processing/aec3/shadow_filter_update_gain.cc b/modules/audio_processing/aec3/shadow_filter_update_gain.cc index e27437aff2..51ead2e540 100644 --- a/modules/audio_processing/aec3/shadow_filter_update_gain.cc +++ b/modules/audio_processing/aec3/shadow_filter_update_gain.cc @@ -28,8 +28,6 @@ ShadowFilterUpdateGain::ShadowFilterUpdateGain( } void ShadowFilterUpdateGain::HandleEchoPathChange() { - // TODO(peah): Check whether this counter should instead be initialized to a - // large value. poor_signal_excitation_counter_ = 0; call_counter_ = 0; } @@ -60,19 +58,23 @@ void ShadowFilterUpdateGain::Compute( // Compute mu. std::array mu; - auto X2 = render_power; - std::transform(X2.begin(), X2.end(), mu.begin(), [&](float a) { - return a > current_config_.noise_gate ? current_config_.rate / a : 0.f; - }); + const auto& X2 = render_power; + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + if (X2[k] > current_config_.noise_gate) { + mu[k] = current_config_.rate / X2[k]; + } else { + mu[k] = 0.f; + } + } // Avoid updating the filter close to narrow bands in the render signals. render_signal_analyzer.MaskRegionsAroundNarrowBands(&mu); // G = mu * E * X2. 
- std::transform(mu.begin(), mu.end(), E_shadow.re.begin(), G->re.begin(), - std::multiplies()); - std::transform(mu.begin(), mu.end(), E_shadow.im.begin(), G->im.begin(), - std::multiplies()); + for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { + G->re[k] = mu[k] * E_shadow.re[k]; + G->im[k] = mu[k] * E_shadow.im[k]; + } } void ShadowFilterUpdateGain::UpdateCurrentConfig() { diff --git a/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc b/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc index 605f5701dd..a73a539c74 100644 --- a/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc +++ b/modules/audio_processing/aec3/shadow_filter_update_gain_unittest.cc @@ -44,11 +44,11 @@ void RunFilterUpdateTest(int num_blocks_to_process, AdaptiveFirFilter main_filter(config.filter.main.length_blocks, config.filter.main.length_blocks, config.filter.config_change_duration_blocks, 1, - 1, DetectOptimization(), &data_dumper); + DetectOptimization(), &data_dumper); AdaptiveFirFilter shadow_filter(config.filter.shadow.length_blocks, config.filter.shadow.length_blocks, config.filter.config_change_duration_blocks, - 1, 1, DetectOptimization(), &data_dumper); + 1, DetectOptimization(), &data_dumper); Aec3Fft fft; constexpr int kSampleRateHz = 48000; diff --git a/modules/audio_processing/aec3/subtractor.cc b/modules/audio_processing/aec3/subtractor.cc index 0c52ed64a1..5e995656e8 100644 --- a/modules/audio_processing/aec3/subtractor.cc +++ b/modules/audio_processing/aec3/subtractor.cc @@ -89,13 +89,13 @@ Subtractor::Subtractor(const EchoCanceller3Config& config, config_.filter.main.length_blocks, config_.filter.main_initial.length_blocks, config.filter.config_change_duration_blocks, num_render_channels, - num_capture_channels, optimization, data_dumper_); + optimization, data_dumper_); shadow_filter_[ch] = std::make_unique( config_.filter.shadow.length_blocks, config_.filter.shadow_initial.length_blocks, 
config.filter.config_change_duration_blocks, num_render_channels, - num_capture_channels, optimization, data_dumper_); + optimization, data_dumper_); G_main_[ch] = std::make_unique( config_.filter.main_initial, config_.filter.config_change_duration_blocks); @@ -162,14 +162,12 @@ void Subtractor::Process(const RenderBuffer& render_buffer, RTC_DCHECK_EQ(num_capture_channels_, capture.size()); // Compute the render powers. + const bool same_filter_sizes = + main_filter_[0]->SizePartitions() == shadow_filter_[0]->SizePartitions(); std::array X2_main; std::array X2_shadow_data; - std::array& X2_shadow = - main_filter_[0]->SizePartitions() == shadow_filter_[0]->SizePartitions() - ? X2_main - : X2_shadow_data; - if (main_filter_[0]->SizePartitions() == - shadow_filter_[0]->SizePartitions()) { + auto& X2_shadow = same_filter_sizes ? X2_main : X2_shadow_data; + if (same_filter_sizes) { render_buffer.SpectralSum(main_filter_[0]->SizePartitions(), &X2_main); } else if (main_filter_[0]->SizePartitions() > shadow_filter_[0]->SizePartitions()) { @@ -256,7 +254,8 @@ void Subtractor::Process(const RenderBuffer& render_buffer, aec_state.SaturatedCapture(), &G); } else { poor_shadow_filter_counter_[ch] = 0; - shadow_filter_[ch]->SetFilter(main_filter_[ch]->GetFilter()); + shadow_filter_[ch]->SetFilter(main_filter_[ch]->SizePartitions(), + main_filter_[ch]->GetFilter()); G_shadow_[ch]->Compute(X2_shadow, render_signal_analyzer, E_main, shadow_filter_[ch]->SizePartitions(), aec_state.SaturatedCapture(), &G); diff --git a/modules/audio_processing/aec3/subtractor.h b/modules/audio_processing/aec3/subtractor.h index c5fb765e54..01d2eef403 100644 --- a/modules/audio_processing/aec3/subtractor.h +++ b/modules/audio_processing/aec3/subtractor.h @@ -59,26 +59,24 @@ class Subtractor { void ExitInitialState(); // Returns the block-wise frequency responses for the main adaptive filters. - // TODO(bugs.webrtc.org/10913): Return the frequency responses for all capture - // channels. 
- const std::vector>& + const std::vector>>& FilterFrequencyResponse() const { - return main_frequency_response_[0]; + return main_frequency_response_; } // Returns the estimates of the impulse responses for the main adaptive // filters. - // TODO(bugs.webrtc.org/10913): Return the impulse responses for all capture - // channels. - const std::vector& FilterImpulseResponse() const { - return main_impulse_response_[0]; + const std::vector>& FilterImpulseResponse() const { + return main_impulse_response_; } void DumpFilters() { - size_t current_size = main_impulse_response_[0].size(); - main_impulse_response_[0].resize(main_impulse_response_[0].capacity()); - data_dumper_->DumpRaw("aec3_subtractor_h_main", main_impulse_response_[0]); - main_impulse_response_[0].resize(current_size); + data_dumper_->DumpRaw( + "aec3_subtractor_h_main", + rtc::ArrayView( + main_impulse_response_[0].data(), + GetTimeDomainLength( + main_filter_[0]->max_filter_size_partitions()))); main_filter_[0]->DumpFilter("aec3_subtractor_H_main"); shadow_filter_[0]->DumpFilter("aec3_subtractor_H_shadow"); diff --git a/modules/audio_processing/aec3/subtractor_unittest.cc b/modules/audio_processing/aec3/subtractor_unittest.cc index b5635f4b84..23e7ead41d 100644 --- a/modules/audio_processing/aec3/subtractor_unittest.cc +++ b/modules/audio_processing/aec3/subtractor_unittest.cc @@ -11,12 +11,14 @@ #include "modules/audio_processing/aec3/subtractor.h" #include +#include #include #include #include "modules/audio_processing/aec3/aec_state.h" #include "modules/audio_processing/aec3/render_delay_buffer.h" #include "modules/audio_processing/test/echo_canceller_test_tools.h" +#include "modules/audio_processing/utility/cascaded_biquad_filter.h" #include "rtc_base/random.h" #include "rtc_base/strings/string_builder.h" #include "test/gtest.h" @@ -24,51 +26,104 @@ namespace webrtc { namespace { -float RunSubtractorTest(int num_blocks_to_process, - int delay_samples, - int main_filter_length_blocks, - int 
shadow_filter_length_blocks, - bool uncorrelated_inputs, - const std::vector& blocks_with_echo_path_changes) { +std::vector RunSubtractorTest( + size_t num_render_channels, + size_t num_capture_channels, + int num_blocks_to_process, + int delay_samples, + int main_filter_length_blocks, + int shadow_filter_length_blocks, + bool uncorrelated_inputs, + const std::vector& blocks_with_echo_path_changes) { ApmDataDumper data_dumper(42); - constexpr size_t kNumChannels = 1; constexpr int kSampleRateHz = 48000; constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); EchoCanceller3Config config; config.filter.main.length_blocks = main_filter_length_blocks; config.filter.shadow.length_blocks = shadow_filter_length_blocks; - Subtractor subtractor(config, 1, 1, &data_dumper, DetectOptimization()); + Subtractor subtractor(config, num_render_channels, num_capture_channels, + &data_dumper, DetectOptimization()); absl::optional delay_estimate; std::vector>> x( kNumBands, std::vector>( - kNumChannels, std::vector(kBlockSize, 0.f))); - std::vector> y(1, std::vector(kBlockSize, 0.f)); + num_render_channels, std::vector(kBlockSize, 0.f))); + std::vector> y(num_capture_channels, + std::vector(kBlockSize, 0.f)); std::array x_old; - std::array output; + std::vector output(num_capture_channels); config.delay.default_delay = 1; std::unique_ptr render_delay_buffer( - RenderDelayBuffer::Create(config, kSampleRateHz, kNumChannels)); + RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels)); RenderSignalAnalyzer render_signal_analyzer(config); Random random_generator(42U); Aec3Fft fft; std::array Y2; std::array E2_main; std::array E2_shadow; - AecState aec_state(config, kNumChannels); + AecState aec_state(config, num_capture_channels); x_old.fill(0.f); Y2.fill(0.f); E2_main.fill(0.f); E2_shadow.fill(0.f); - DelayBuffer delay_buffer(delay_samples); - for (int k = 0; k < num_blocks_to_process; ++k) { - RandomizeSampleVector(&random_generator, x[0][0]); - if 
(uncorrelated_inputs) { - RandomizeSampleVector(&random_generator, y[0]); - } else { - delay_buffer.Delay(x[0][0], y[0]); + std::vector>>> delay_buffer( + num_capture_channels); + for (size_t capture_ch = 0; capture_ch < num_capture_channels; ++capture_ch) { + delay_buffer[capture_ch].resize(num_render_channels); + for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) { + delay_buffer[capture_ch][render_ch] = + std::make_unique>(delay_samples); } + } + + // [B,A] = butter(2,100/8000,'high') + constexpr CascadedBiQuadFilter::BiQuadCoefficients + kHighPassFilterCoefficients = {{0.97261f, -1.94523f, 0.97261f}, + {-1.94448f, 0.94598f}}; + std::vector> x_hp_filter( + num_render_channels); + for (size_t ch = 0; ch < num_render_channels; ++ch) { + x_hp_filter[ch] = + std::make_unique(kHighPassFilterCoefficients, 1); + } + std::vector> y_hp_filter( + num_capture_channels); + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + y_hp_filter[ch] = + std::make_unique(kHighPassFilterCoefficients, 1); + } + + for (int k = 0; k < num_blocks_to_process; ++k) { + for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) { + RandomizeSampleVector(&random_generator, x[0][render_ch]); + } + if (uncorrelated_inputs) { + for (size_t capture_ch = 0; capture_ch < num_capture_channels; + ++capture_ch) { + RandomizeSampleVector(&random_generator, y[capture_ch]); + } + } else { + for (size_t capture_ch = 0; capture_ch < num_capture_channels; + ++capture_ch) { + for (size_t render_ch = 0; render_ch < num_render_channels; + ++render_ch) { + std::array y_ch; + delay_buffer[capture_ch][render_ch]->Delay(x[0][render_ch], y_ch); + for (size_t i = 0; i < kBlockSize; ++i) { + const float previous_mix = render_ch == 0 ? 0.f : y[capture_ch][i]; + y[capture_ch][i] = previous_mix + y_ch[i] / num_render_channels; + } + } + } + } + for (size_t ch = 0; ch < num_render_channels; ++ch) { + x_hp_filter[ch]->Process(x[0][ch]); + } + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + y_hp_filter[ch]->Process(y[ch]); + } + 
render_delay_buffer->Insert(x); if (k == 0) { render_delay_buffer->Reset(); @@ -90,28 +145,37 @@ float RunSubtractorTest(int num_blocks_to_process, aec_state.HandleEchoPathChange(EchoPathVariability( false, EchoPathVariability::DelayAdjustment::kNone, false)); - aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(), - subtractor.FilterImpulseResponse(), + aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse()[0], + subtractor.FilterImpulseResponse()[0], *render_delay_buffer->GetRenderBuffer(), E2_main, Y2, output); } - const float output_power = - std::inner_product(output[0].e_main.begin(), output[0].e_main.end(), - output[0].e_main.begin(), 0.f); - const float y_power = - std::inner_product(y[0].begin(), y[0].end(), y[0].begin(), 0.f); - if (y_power == 0.f) { - ADD_FAILURE(); - return -1.0; + std::vector results(num_capture_channels); + for (size_t ch = 0; ch < num_capture_channels; ++ch) { + const float output_power = + std::inner_product(output[ch].e_main.begin(), output[ch].e_main.end(), + output[ch].e_main.begin(), 0.f); + const float y_power = + std::inner_product(y[ch].begin(), y[ch].end(), y[ch].begin(), 0.f); + if (y_power == 0.f) { + ADD_FAILURE(); + } + // Report -1 instead of dividing by zero when the capture is silent. + results[ch] = y_power == 0.f ? -1.f : output_power / y_power; } - return output_power / y_power; + return results; } -std::string ProduceDebugText(size_t delay, int filter_length_blocks) { +std::string ProduceDebugText(size_t num_render_channels, + size_t num_capture_channels, + size_t delay, + int filter_length_blocks) { rtc::StringBuilder ss; - ss << "Delay: " << delay << ", "; - ss << "Length: " << filter_length_blocks; + ss << "delay: " << delay << ", "; + ss << "filter_length_blocks:" << filter_length_blocks << ", "; + ss << "num_render_channels:" << num_render_channels << ", "; + ss << "num_capture_channels:" << num_capture_channels; return ss.Release(); } @@ -150,17 +214,32 @@ TEST(Subtractor, Convergence) { std::vector blocks_with_echo_path_changes; for (size_t 
filter_length_blocks : {12, 20, 30}) { for (size_t delay_samples : {0, 64, 150, 200, 301}) { - SCOPED_TRACE(ProduceDebugText(delay_samples, filter_length_blocks)); + SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks)); + std::vector echo_to_nearend_powers = RunSubtractorTest( + 1, 1, 2500, delay_samples, filter_length_blocks, filter_length_blocks, + false, blocks_with_echo_path_changes); - float echo_to_nearend_power = RunSubtractorTest( - 400, delay_samples, filter_length_blocks, filter_length_blocks, false, - blocks_with_echo_path_changes); - - // Use different criteria to take overmodelling into account. - if (filter_length_blocks == 12) { + for (float echo_to_nearend_power : echo_to_nearend_powers) { + EXPECT_GT(0.1f, echo_to_nearend_power); + } + } + } +} + +// Verifies that the subtractor converges on correlated multi-channel data. +TEST(Subtractor, ConvergenceMultiChannel) { + std::vector blocks_with_echo_path_changes; + for (size_t num_render_channels : {1, 2, 4, 8}) { + for (size_t num_capture_channels : {1, 2, 4}) { + SCOPED_TRACE( + ProduceDebugText(num_render_channels, num_capture_channels, 64, 20)); + size_t num_blocks_to_process = 2500 * num_render_channels; + std::vector echo_to_nearend_powers = RunSubtractorTest( + num_render_channels, num_capture_channels, num_blocks_to_process, 64, + 20, 20, false, blocks_with_echo_path_changes); + + for (float echo_to_nearend_power : echo_to_nearend_powers) { EXPECT_GT(0.1f, echo_to_nearend_power); - } else { - EXPECT_GT(1.f, echo_to_nearend_power); } } } @@ -170,18 +249,22 @@ TEST(Subtractor, Convergence) { // is longer than the shadow filter. 
TEST(Subtractor, MainFilterLongerThanShadowFilter) { std::vector blocks_with_echo_path_changes; - float echo_to_nearend_power = - RunSubtractorTest(400, 64, 20, 15, false, blocks_with_echo_path_changes); - EXPECT_GT(0.5f, echo_to_nearend_power); + std::vector echo_to_nearend_powers = RunSubtractorTest( + 1, 1, 400, 64, 20, 15, false, blocks_with_echo_path_changes); + for (float echo_to_nearend_power : echo_to_nearend_powers) { + EXPECT_GT(0.5f, echo_to_nearend_power); + } } // Verifies that the subtractor is able to handle the case when the shadow // filter is longer than the main filter. TEST(Subtractor, ShadowFilterLongerThanMainFilter) { std::vector blocks_with_echo_path_changes; - float echo_to_nearend_power = - RunSubtractorTest(400, 64, 15, 20, false, blocks_with_echo_path_changes); - EXPECT_GT(0.5f, echo_to_nearend_power); + std::vector echo_to_nearend_powers = RunSubtractorTest( + 1, 1, 400, 64, 15, 20, false, blocks_with_echo_path_changes); + for (float echo_to_nearend_power : echo_to_nearend_powers) { + EXPECT_GT(0.5f, echo_to_nearend_power); + } } // Verifies that the subtractor does not converge on uncorrelated signals. 
@@ -189,12 +272,33 @@ TEST(Subtractor, NonConvergenceOnUncorrelatedSignals) { std::vector blocks_with_echo_path_changes; for (size_t filter_length_blocks : {12, 20, 30}) { for (size_t delay_samples : {0, 64, 150, 200, 301}) { - SCOPED_TRACE(ProduceDebugText(delay_samples, filter_length_blocks)); + SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks)); - float echo_to_nearend_power = RunSubtractorTest( - 300, delay_samples, filter_length_blocks, filter_length_blocks, true, - blocks_with_echo_path_changes); - EXPECT_NEAR(1.f, echo_to_nearend_power, 0.1); + std::vector echo_to_nearend_powers = RunSubtractorTest( + 1, 1, 3000, delay_samples, filter_length_blocks, filter_length_blocks, + true, blocks_with_echo_path_changes); + for (float echo_to_nearend_power : echo_to_nearend_powers) { + EXPECT_NEAR(1.f, echo_to_nearend_power, 0.1); + } + } + } +} + +// Verifies non-convergence on uncorrelated multi-channel signals. +TEST(Subtractor, NonConvergenceOnUncorrelatedSignalsMultiChannel) { + std::vector blocks_with_echo_path_changes; + for (size_t num_render_channels : {1, 2, 4}) { + for (size_t num_capture_channels : {1, 2, 4}) { + SCOPED_TRACE( + ProduceDebugText(num_render_channels, num_capture_channels, 64, 20)); + size_t num_blocks_to_process = 5000 * num_render_channels; + std::vector echo_to_nearend_powers = RunSubtractorTest( + num_render_channels, num_capture_channels, num_blocks_to_process, 64, + 20, 20, true, blocks_with_echo_path_changes); + for (float echo_to_nearend_power : echo_to_nearend_powers) { + EXPECT_LT(.8f, echo_to_nearend_power); + EXPECT_NEAR(1.f, echo_to_nearend_power, 0.25f); + } } } } diff --git a/modules/audio_processing/aec3/suppression_gain_unittest.cc b/modules/audio_processing/aec3/suppression_gain_unittest.cc index 490c7ec0cd..465227ccec 100644 --- a/modules/audio_processing/aec3/suppression_gain_unittest.cc +++ b/modules/audio_processing/aec3/suppression_gain_unittest.cc @@ -97,14 +97,14 @@ TEST(SuppressionGain, 
BasicGainComputation) { // Ensure that the gain is no longer forced to zero. for (int k = 0; k <= kNumBlocksPerSecond / 5 + 1; ++k) { - aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(), - subtractor.FilterImpulseResponse(), + aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse()[0], + subtractor.FilterImpulseResponse()[0], *render_delay_buffer->GetRenderBuffer(), E2, Y2, output); } for (int k = 0; k < 100; ++k) { - aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(), - subtractor.FilterImpulseResponse(), + aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse()[0], + subtractor.FilterImpulseResponse()[0], *render_delay_buffer->GetRenderBuffer(), E2, Y2, output); suppression_gain.GetGain(E2, S2, R2, N2, analyzer, aec_state, x, &high_bands_gain, &g); @@ -120,8 +120,8 @@ TEST(SuppressionGain, BasicGainComputation) { N2.fill(0.f); for (int k = 0; k < 100; ++k) { - aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse(), - subtractor.FilterImpulseResponse(), + aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponse()[0], + subtractor.FilterImpulseResponse()[0], *render_delay_buffer->GetRenderBuffer(), E2, Y2, output); suppression_gain.GetGain(E2, S2, R2, N2, analyzer, aec_state, x, &high_bands_gain, &g); diff --git a/modules/audio_processing/test/echo_canceller_test_tools.h b/modules/audio_processing/test/echo_canceller_test_tools.h index bab7f273e9..0d70cd39c6 100644 --- a/modules/audio_processing/test/echo_canceller_test_tools.h +++ b/modules/audio_processing/test/echo_canceller_test_tools.h @@ -15,7 +15,6 @@ #include #include "api/array_view.h" -#include "rtc_base/constructor_magic.h" #include "rtc_base/random.h" namespace webrtc { @@ -41,7 +40,6 @@ class DelayBuffer { private: std::vector buffer_; size_t next_insert_index_ = 0; - RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(DelayBuffer); }; } // namespace webrtc