Files
platform-external-webrtc/modules/audio_processing/agc2/rnn_vad/pitch_search.cc
Alessio Bazzica fd5dadbea9 RNN VAD: use VectorMath::DotProduct() for pitch search
This CL brings a large improvement to the RNN VAD CPU performance
by finally using `VectorMath::DotProduct()` for pitch search.

The realtime factor improved from about 390x to 570x for SSE2
(+180x, 45% faster) and to 610x for AVX2 (+220x, 56% faster).

RNN VAD benchmark results:
```
+-----+-------+------+------+
| run | none* | SSE2 | AVX2 |
+-----+-------+------+------+
|   1 | 393x  | 572x | 618x |
|   2 | 388x  | 568x | 607x |
|   3 | 393x  | 564x | 599x |
+-----+-------+------+------+
```
*: baseline (no SIMD used for pitch search; SSE2 used for the RNN).

Results obtained as follows:
1. Force SSE2 in `DISABLED_RnnVadPerformance` for the RNN part in
   order to measure the baseline correctly:
```
RnnBasedVad rnn_vad({/*sse2=*/true, /*avx2=*/true, /*neon=*/false});
```
2. Run the test:
```
$ ./out/release/modules_unittests \
  --gtest_filter=*RnnVadTest*DISABLED_RnnVadPerformance* \
  --gtest_also_run_disabled_tests --logs
```

Bug: webrtc:10480
Change-Id: I89a2bd420265540026944b9c0f1fdd4bfda7f475
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/195001
Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32755}
2020-12-03 11:50:09 +00:00

71 lines
2.9 KiB
C++

/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
#include <array>
#include <cstddef>
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {

// Stores the CPU features that are later forwarded to the pitch search helpers
// and sizes the internal scratch buffers up front. `Estimate()` only creates
// fixed-size views over these buffers and never resizes them. The 24 kHz
// energy buffer is explicitly zero-filled.
PitchEstimator::PitchEstimator(const AvailableCpuFeatures& cpu_features)
    : cpu_features_(cpu_features),
      y_energy_24kHz_(kRefineNumLags24kHz, 0.f),
      pitch_buffer_12kHz_(kBufSize12kHz),
      auto_correlation_12kHz_(kNumLags12kHz) {}

PitchEstimator::~PitchEstimator() = default;

// Estimates the pitch period for `pitch_buffer`, which holds 24 kHz samples.
// The estimate is computed in stages:
//  1. Decimate to 12 kHz, compute the auto-correlation and pick two candidate
//     pitch periods.
//  2. Scale the candidates from 12 kHz to 24 kHz.
//  3. Using the 24 kHz sliding frame energies, refine the candidates at 48 kHz.
//  4. Pass the result, together with the previous estimate, to
//     `ComputeExtendedPitchPeriod48kHz()`.
// Stores the outcome in `last_pitch_48kHz_` (so it is reused by the next call)
// and returns its `period`.
int PitchEstimator::Estimate(
    rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
  // Fixed-size views over the scratch buffers allocated in the constructor.
  // The checks guard against a mismatch between the buffer sizes and the
  // compile-time view sizes.
  rtc::ArrayView<float, kBufSize12kHz> pitch_buffer_12kHz_view(
      pitch_buffer_12kHz_.data(), kBufSize12kHz);
  RTC_DCHECK_EQ(pitch_buffer_12kHz_.size(), pitch_buffer_12kHz_view.size());
  rtc::ArrayView<float, kNumLags12kHz> auto_correlation_12kHz_view(
      auto_correlation_12kHz_.data(), kNumLags12kHz);
  RTC_DCHECK_EQ(auto_correlation_12kHz_.size(),
                auto_correlation_12kHz_view.size());

  // TODO(bugs.webrtc.org/10480): Also use `cpu_features_` for the decimation
  // and the 12 kHz auto-correlation, which currently do not receive it.
  // Perform the initial pitch search at 12 kHz.
  Decimate2x(pitch_buffer, pitch_buffer_12kHz_view);
  auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buffer_12kHz_view,
                                             auto_correlation_12kHz_view);
  CandidatePitchPeriods pitch_periods = ComputePitchPeriod12kHz(
      pitch_buffer_12kHz_view, auto_correlation_12kHz_view, cpu_features_);

  // The refinement is done using the pitch buffer that contains 24 kHz samples.
  // Therefore, scale the candidate periods in `pitch_periods` from 12 kHz to
  // 24 kHz.
  pitch_periods.best *= 2;
  pitch_periods.second_best *= 2;

  // Refine the initial pitch period estimation from 12 kHz to 48 kHz.
  // Pre-compute frame energies at 24 kHz.
  rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_24kHz_view(
      y_energy_24kHz_.data(), kRefineNumLags24kHz);
  RTC_DCHECK_EQ(y_energy_24kHz_.size(), y_energy_24kHz_view.size());
  ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, y_energy_24kHz_view,
                                         cpu_features_);

  // Estimation at 48 kHz.
  const int pitch_lag_48kHz = ComputePitchPeriod48kHz(
      pitch_buffer, y_energy_24kHz_view, pitch_periods, cpu_features_);
  // NOTE(review): the lag is turned into a period as `kMaxPitch48kHz - lag`.
  // This suggests `ComputePitchPeriod48kHz()` returns an inverted lag; confirm
  // against its declaration.
  last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz(
      pitch_buffer, y_energy_24kHz_view,
      /*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_lag_48kHz,
      last_pitch_48kHz_, cpu_features_);
  return last_pitch_48kHz_.period;
}

}  // namespace rnn_vad
}  // namespace webrtc