RNN VAD: pitch search optimizations (part 2)

This CL brings a large improvement to the VAD by precomputing the
energy for the sliding frame `y` in the pitch buffer instead of
computing them twice in two different places. The realtime factor
has improved by about +16x.

There is room for additional improvement (TODOs added), but that will
be done in a follow up CL since the change won't be bit-exact and
careful testing is needed.

Benchmarked as follows:
```
out/release/modules_unittests \
  --gtest_filter=*RnnVadTest.DISABLED_RnnVadPerformance* \
  --gtest_also_run_disabled_tests --logs
```

Results:

      | baseline             | this CL
------+----------------------+------------------------
run 1 | 23.568 +/- 0.990788  | 22.8319 +/- 1.46554
      | 377.207x             | 389.367x
------+----------------------+------------------------
run 2 | 23.3714 +/- 0.857523 | 22.4286 +/- 0.726449
      | 380.379x             | 396.369x
------+----------------------+------------------------
run 2 | 23.709 +/- 1.04477   | 22.5688 +/- 0.831341
      | 374.963x             | 393.906x

Bug: webrtc:10480
Change-Id: I599a4dda2bde16dc6c2f42cf89e96afbd4630311
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/191484
Reviewed-by: Per Åhgren <peah@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32571}
This commit is contained in:
Alessio Bazzica
2020-11-09 15:40:14 +01:00
committed by Commit Bot
parent 1f99551775
commit 2f7d1c62e2
5 changed files with 87 additions and 77 deletions

View File

@ -19,37 +19,46 @@ namespace webrtc {
namespace rnn_vad {
PitchEstimator::PitchEstimator()
: pitch_buf_decimated_(kBufSize12kHz),
pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz),
auto_corr_(kNumLags12kHz),
auto_corr_view_(auto_corr_.data(), kNumLags12kHz) {
RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size());
RTC_DCHECK_EQ(kNumLags12kHz, auto_corr_view_.size());
}
: y_energy_24kHz_(kRefineNumLags24kHz, 0.f),
pitch_buffer_12kHz_(kBufSize12kHz),
auto_correlation_12kHz_(kNumLags12kHz) {}
PitchEstimator::~PitchEstimator() = default;
int PitchEstimator::Estimate(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
rtc::ArrayView<float, kBufSize12kHz> pitch_buffer_12kHz_view(
pitch_buffer_12kHz_.data(), kBufSize12kHz);
RTC_DCHECK_EQ(pitch_buffer_12kHz_.size(), pitch_buffer_12kHz_view.size());
rtc::ArrayView<float, kNumLags12kHz> auto_correlation_12kHz_view(
auto_correlation_12kHz_.data(), kNumLags12kHz);
RTC_DCHECK_EQ(auto_correlation_12kHz_.size(),
auto_correlation_12kHz_view.size());
// Perform the initial pitch search at 12 kHz.
Decimate2x(pitch_buffer, pitch_buf_decimated_view_);
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_,
auto_corr_view_);
CandidatePitchPeriods pitch_candidates_inverted_lags =
ComputePitchPeriod12kHz(pitch_buf_decimated_view_, auto_corr_view_);
// Refine the pitch period estimation.
Decimate2x(pitch_buffer, pitch_buffer_12kHz_view);
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buffer_12kHz_view,
auto_correlation_12kHz_view);
CandidatePitchPeriods pitch_periods = ComputePitchPeriod12kHz(
pitch_buffer_12kHz_view, auto_correlation_12kHz_view);
// The refinement is done using the pitch buffer that contains 24 kHz samples.
// Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12
// to 24 kHz.
pitch_candidates_inverted_lags.best *= 2;
pitch_candidates_inverted_lags.second_best *= 2;
const int pitch_inv_lag_48kHz =
ComputePitchPeriod48kHz(pitch_buffer, pitch_candidates_inverted_lags);
// Look for stronger harmonics to find the final pitch period and its gain.
RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz);
pitch_periods.best *= 2;
pitch_periods.second_best *= 2;
// Refine the initial pitch period estimation from 12 kHz to 48 kHz.
// Pre-compute frame energies at 24 kHz.
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_24kHz_view(
y_energy_24kHz_.data(), kRefineNumLags24kHz);
RTC_DCHECK_EQ(y_energy_24kHz_.size(), y_energy_24kHz_view.size());
ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, y_energy_24kHz_view);
// Estimation at 48 kHz.
const int pitch_lag_48kHz =
ComputePitchPeriod48kHz(pitch_buffer, y_energy_24kHz_view, pitch_periods);
last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz(
pitch_buffer,
/*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_inv_lag_48kHz,
pitch_buffer, y_energy_24kHz_view,
/*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_lag_48kHz,
last_pitch_48kHz_);
return last_pitch_48kHz_.period;
}

View File

@ -41,10 +41,9 @@ class PitchEstimator {
PitchInfo last_pitch_48kHz_{};
AutoCorrelationCalculator auto_corr_calculator_;
std::vector<float> pitch_buf_decimated_;
rtc::ArrayView<float, kBufSize12kHz> pitch_buf_decimated_view_;
std::vector<float> auto_corr_;
rtc::ArrayView<float, kNumLags12kHz> auto_corr_view_;
std::vector<float> y_energy_24kHz_;
std::vector<float> pitch_buffer_12kHz_;
std::vector<float> auto_correlation_12kHz_;
};
} // namespace rnn_vad

View File

@ -153,17 +153,12 @@ void ComputeAutoCorrelation(
}
}
int FindBestPitchPeriods24kHz(
int ComputePitchPeriod24kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation,
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy) {
static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, "");
static_assert(kMaxPitch24kHz < kBufSize24kHz, "");
// Initialize the sliding 20 ms frame energy.
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
float denominator = std::inner_product(
pitch_buffer.begin(), pitch_buffer.begin() + kFrameSize20ms24kHz + 1,
pitch_buffer.begin(), 1.f);
// Search best pitch by looking at the scaled auto-correlation.
int best_inverted_lag = 0; // Pitch period.
float best_numerator = -1.f; // Pitch strength numerator.
float best_denominator = 0.f; // Pitch strength denominator.
@ -171,8 +166,10 @@ int FindBestPitchPeriods24kHz(
++inverted_lag) {
// A pitch candidate must have positive correlation.
if (auto_correlation[inverted_lag] > 0.f) {
// Auto-correlation energy normalized by frame energy.
const float numerator =
auto_correlation[inverted_lag] * auto_correlation[inverted_lag];
const float denominator = y_energy[kMaxPitch24kHz - inverted_lag];
// Compare numerator/denominator ratios without using divisions.
if (numerator * best_denominator > best_numerator * denominator) {
best_inverted_lag = inverted_lag;
@ -180,14 +177,6 @@ int FindBestPitchPeriods24kHz(
best_denominator = denominator;
}
}
// Update |denominator| for the next inverted lag.
static_assert(kInitialNumLags24kHz + kFrameSize20ms24kHz < kBufSize24kHz,
"");
const float y_old = pitch_buffer[inverted_lag];
const float y_new = pitch_buffer[inverted_lag + kFrameSize20ms24kHz];
denominator -= y_old * y_old;
denominator += y_new * y_new;
denominator = std::max(0.f, denominator);
}
return best_inverted_lag;
}
@ -338,6 +327,7 @@ CandidatePitchPeriods ComputePitchPeriod12kHz(
int ComputePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
CandidatePitchPeriods pitch_candidates) {
// Compute the auto-correlation terms only for neighbors of the given pitch
// candidates (similar to what is done in ComputePitchAutoCorrelation(), but
@ -362,7 +352,7 @@ int ComputePitchPeriod48kHz(
}
// Find best pitch at 24 kHz.
const int pitch_candidate_24kHz =
FindBestPitchPeriods24kHz(auto_correlation, pitch_buffer);
ComputePitchPeriod24kHz(pitch_buffer, auto_correlation, y_energy);
// Pseudo-interpolation.
return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidate_24kHz,
auto_correlation);
@ -370,6 +360,7 @@ int ComputePitchPeriod48kHz(
PitchInfo ComputeExtendedPitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
int initial_pitch_period_48kHz,
PitchInfo last_pitch_48kHz) {
RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
@ -379,34 +370,30 @@ PitchInfo ComputeExtendedPitchPeriod48kHz(
struct RefinedPitchCandidate {
int period;
float strength;
// Additional strength data used for the final estimation of the strength.
float xy; // Cross-correlation.
float yy; // Auto-correlation.
// Additional strength data used for the final pitch estimation.
float xy; // Auto-correlation.
float y_energy; // Energy of the sliding frame `y`.
};
// Initialize.
std::array<float, kRefineNumLags24kHz> yy_values;
// TODO(bugs.webrtc.org/9076): Reuse values from FindBestPitchPeriods24kHz().
ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, yy_values);
const float xx = yy_values[0];
const auto pitch_strength = [](float xy, float yy, float xx) {
RTC_DCHECK_GE(xx * yy, 0.f);
return xy / std::sqrt(1.f + xx * yy);
const float x_energy = y_energy[0];
const auto pitch_strength = [x_energy](float xy, float y_energy) {
RTC_DCHECK_GE(x_energy * y_energy, 0.f);
return xy / std::sqrt(1.f + x_energy * y_energy);
};
// Initial pitch candidate.
// Initialize the best pitch candidate with `initial_pitch_period_48kHz`.
RefinedPitchCandidate best_pitch;
best_pitch.period =
std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1);
best_pitch.xy =
ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period, pitch_buffer);
best_pitch.yy = yy_values[best_pitch.period];
best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.yy, xx);
// 24 kHz version of the last estimated pitch and copy of the initial
// estimation.
best_pitch.y_energy = y_energy[best_pitch.period];
best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.y_energy);
// Keep a copy of the initial pitch candidate.
const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength};
// 24 kHz version of the last estimated pitch.
const PitchInfo last_pitch{last_pitch_48kHz.period / 2,
last_pitch_48kHz.strength};
const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength};
// Find `max_period_divisor` such that the result of
// `GetAlternativePitchPeriod(initial_pitch_period, 1, max_period_divisor)`
@ -436,14 +423,14 @@ PitchInfo ComputeExtendedPitchPeriod48kHz(
// Compute an auto-correlation score for the primary pitch candidate
// |alternative_pitch.period| by also looking at its possible sub-harmonic
// |dual_alternative_period|.
float xy_primary_period = ComputeAutoCorrelation(
const float xy_primary_period = ComputeAutoCorrelation(
kMaxPitch24kHz - alternative_pitch.period, pitch_buffer);
float xy_secondary_period = ComputeAutoCorrelation(
const float xy_secondary_period = ComputeAutoCorrelation(
kMaxPitch24kHz - dual_alternative_period, pitch_buffer);
float xy = 0.5f * (xy_primary_period + xy_secondary_period);
float yy = 0.5f * (yy_values[alternative_pitch.period] +
yy_values[dual_alternative_period]);
alternative_pitch.strength = pitch_strength(xy, yy, xx);
const float xy = 0.5f * (xy_primary_period + xy_secondary_period);
const float yy = 0.5f * (y_energy[alternative_pitch.period] +
y_energy[dual_alternative_period]);
alternative_pitch.strength = pitch_strength(xy, yy);
// Maybe update best period.
if (IsAlternativePitchStrongerThanInitial(
@ -455,10 +442,11 @@ PitchInfo ComputeExtendedPitchPeriod48kHz(
// Final pitch strength and period.
best_pitch.xy = std::max(0.f, best_pitch.xy);
RTC_DCHECK_LE(0.f, best_pitch.yy);
float final_pitch_strength = (best_pitch.yy <= best_pitch.xy)
? 1.f
: best_pitch.xy / (best_pitch.yy + 1.f);
RTC_DCHECK_LE(0.f, best_pitch.y_energy);
float final_pitch_strength =
(best_pitch.y_energy <= best_pitch.xy)
? 1.f
: best_pitch.xy / (best_pitch.y_energy + 1.f);
final_pitch_strength = std::min(best_pitch.strength, final_pitch_strength);
int final_pitch_period_48kHz = std::max(
kMinPitch48kHz,

View File

@ -80,10 +80,12 @@ CandidatePitchPeriods ComputePitchPeriod12kHz(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buffer,
rtc::ArrayView<const float, kNumLags12kHz> auto_correlation);
// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer
// and the pitch period candidates at 24 kHz (encoded as inverted lag).
// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer,
// the energies for the sliding frames `y` at 24 kHz and the pitch period
// candidates at 24 kHz (encoded as inverted lag).
int ComputePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
CandidatePitchPeriods pitch_candidates_24kHz);
struct PitchInfo {
@ -92,10 +94,12 @@ struct PitchInfo {
};
// Computes the pitch period at 48 kHz searching in an extended pitch range
// given a view on the 24 kHz pitch buffer, the initial 48 kHz estimation
// (computed by `ComputePitchPeriod48kHz()`) and the last estimated pitch.
// given a view on the 24 kHz pitch buffer, the energies for the sliding frames
// `y` at 24 kHz, the initial 48 kHz estimation (computed by
// `ComputePitchPeriod48kHz()`) and the last estimated pitch.
PitchInfo ComputeExtendedPitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
int initial_pitch_period_48kHz,
PitchInfo last_pitch_48kHz);

View File

@ -63,12 +63,17 @@ TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) {
// Checks that the refined pitch period is bit-exact given test input data.
TEST(RnnVadTest, ComputePitchPeriod48kHzBitExactness) {
PitchTestData test_data;
std::vector<float> y_energy(kMaxPitch24kHz + 1);
rtc::ArrayView<float, kMaxPitch24kHz + 1> y_energy_view(y_energy.data(),
kMaxPitch24kHz + 1);
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
y_energy_view);
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(),
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
/*pitch_candidates=*/{280, 284}),
560);
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(),
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
/*pitch_candidates=*/{260, 284}),
568);
}
@ -90,10 +95,15 @@ class ComputeExtendedPitchPeriod48kHzTest
TEST_P(ComputeExtendedPitchPeriod48kHzTest,
PeriodBitExactnessGainWithinTolerance) {
PitchTestData test_data;
std::vector<float> y_energy(kMaxPitch24kHz + 1);
rtc::ArrayView<float, kMaxPitch24kHz + 1> y_energy_view(y_energy.data(),
kMaxPitch24kHz + 1);
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
y_energy_view);
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
const auto computed_output = ComputeExtendedPitchPeriod48kHz(
test_data.GetPitchBufView(), GetInitialPitchPeriod(),
test_data.GetPitchBufView(), y_energy_view, GetInitialPitchPeriod(),
{GetLastPitchPeriod(), GetLastPitchStrength()});
EXPECT_EQ(GetExpectedPitchPeriod(), computed_output.period);
EXPECT_NEAR(GetExpectedPitchStrength(), computed_output.strength, 1e-6f);