RNN VAD: pitch search optimizations (part 2)
This CL brings a large improvement to the VAD by precomputing the energies of the sliding frames `y` in the pitch buffer instead of computing them twice in two different places. The realtime factor has improved by about +16x in absolute terms (from roughly 377x to roughly 393x on average).

There is room for additional improvement (TODOs added), but that will be done in a follow-up CL since the change won't be bit-exact and careful testing is needed.

Benchmarked as follows:
```
out/release/modules_unittests \
  --gtest_filter=*RnnVadTest.DISABLED_RnnVadPerformance* \
  --gtest_also_run_disabled_tests --logs
```

Results:

      | baseline             | this CL
------+----------------------+------------------------
run 1 | 23.568 +/- 0.990788  | 22.8319 +/- 1.46554
      | 377.207x             | 389.367x
------+----------------------+------------------------
run 2 | 23.3714 +/- 0.857523 | 22.4286 +/- 0.726449
      | 380.379x             | 396.369x
------+----------------------+------------------------
run 3 | 23.709 +/- 1.04477   | 22.5688 +/- 0.831341
      | 374.963x             | 393.906x

Bug: webrtc:10480
Change-Id: I599a4dda2bde16dc6c2f42cf89e96afbd4630311
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/191484
Reviewed-by: Per Åhgren <peah@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32571}
This commit is contained in:

committed by
Commit Bot

parent
1f99551775
commit
2f7d1c62e2
@ -19,37 +19,46 @@ namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
PitchEstimator::PitchEstimator()
|
||||
: pitch_buf_decimated_(kBufSize12kHz),
|
||||
pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz),
|
||||
auto_corr_(kNumLags12kHz),
|
||||
auto_corr_view_(auto_corr_.data(), kNumLags12kHz) {
|
||||
RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size());
|
||||
RTC_DCHECK_EQ(kNumLags12kHz, auto_corr_view_.size());
|
||||
}
|
||||
: y_energy_24kHz_(kRefineNumLags24kHz, 0.f),
|
||||
pitch_buffer_12kHz_(kBufSize12kHz),
|
||||
auto_correlation_12kHz_(kNumLags12kHz) {}
|
||||
|
||||
PitchEstimator::~PitchEstimator() = default;
|
||||
|
||||
int PitchEstimator::Estimate(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
|
||||
rtc::ArrayView<float, kBufSize12kHz> pitch_buffer_12kHz_view(
|
||||
pitch_buffer_12kHz_.data(), kBufSize12kHz);
|
||||
RTC_DCHECK_EQ(pitch_buffer_12kHz_.size(), pitch_buffer_12kHz_view.size());
|
||||
rtc::ArrayView<float, kNumLags12kHz> auto_correlation_12kHz_view(
|
||||
auto_correlation_12kHz_.data(), kNumLags12kHz);
|
||||
RTC_DCHECK_EQ(auto_correlation_12kHz_.size(),
|
||||
auto_correlation_12kHz_view.size());
|
||||
|
||||
// Perform the initial pitch search at 12 kHz.
|
||||
Decimate2x(pitch_buffer, pitch_buf_decimated_view_);
|
||||
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_,
|
||||
auto_corr_view_);
|
||||
CandidatePitchPeriods pitch_candidates_inverted_lags =
|
||||
ComputePitchPeriod12kHz(pitch_buf_decimated_view_, auto_corr_view_);
|
||||
// Refine the pitch period estimation.
|
||||
Decimate2x(pitch_buffer, pitch_buffer_12kHz_view);
|
||||
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buffer_12kHz_view,
|
||||
auto_correlation_12kHz_view);
|
||||
CandidatePitchPeriods pitch_periods = ComputePitchPeriod12kHz(
|
||||
pitch_buffer_12kHz_view, auto_correlation_12kHz_view);
|
||||
// The refinement is done using the pitch buffer that contains 24 kHz samples.
|
||||
// Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12
|
||||
// to 24 kHz.
|
||||
pitch_candidates_inverted_lags.best *= 2;
|
||||
pitch_candidates_inverted_lags.second_best *= 2;
|
||||
const int pitch_inv_lag_48kHz =
|
||||
ComputePitchPeriod48kHz(pitch_buffer, pitch_candidates_inverted_lags);
|
||||
// Look for stronger harmonics to find the final pitch period and its gain.
|
||||
RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz);
|
||||
pitch_periods.best *= 2;
|
||||
pitch_periods.second_best *= 2;
|
||||
|
||||
// Refine the initial pitch period estimation from 12 kHz to 48 kHz.
|
||||
// Pre-compute frame energies at 24 kHz.
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_24kHz_view(
|
||||
y_energy_24kHz_.data(), kRefineNumLags24kHz);
|
||||
RTC_DCHECK_EQ(y_energy_24kHz_.size(), y_energy_24kHz_view.size());
|
||||
ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, y_energy_24kHz_view);
|
||||
// Estimation at 48 kHz.
|
||||
const int pitch_lag_48kHz =
|
||||
ComputePitchPeriod48kHz(pitch_buffer, y_energy_24kHz_view, pitch_periods);
|
||||
last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz(
|
||||
pitch_buffer,
|
||||
/*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_inv_lag_48kHz,
|
||||
pitch_buffer, y_energy_24kHz_view,
|
||||
/*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_lag_48kHz,
|
||||
last_pitch_48kHz_);
|
||||
return last_pitch_48kHz_.period;
|
||||
}
|
||||
|
@ -41,10 +41,9 @@ class PitchEstimator {
|
||||
|
||||
PitchInfo last_pitch_48kHz_{};
|
||||
AutoCorrelationCalculator auto_corr_calculator_;
|
||||
std::vector<float> pitch_buf_decimated_;
|
||||
rtc::ArrayView<float, kBufSize12kHz> pitch_buf_decimated_view_;
|
||||
std::vector<float> auto_corr_;
|
||||
rtc::ArrayView<float, kNumLags12kHz> auto_corr_view_;
|
||||
std::vector<float> y_energy_24kHz_;
|
||||
std::vector<float> pitch_buffer_12kHz_;
|
||||
std::vector<float> auto_correlation_12kHz_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
|
@ -153,17 +153,12 @@ void ComputeAutoCorrelation(
|
||||
}
|
||||
}
|
||||
|
||||
int FindBestPitchPeriods24kHz(
|
||||
int ComputePitchPeriod24kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation,
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy) {
|
||||
static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, "");
|
||||
static_assert(kMaxPitch24kHz < kBufSize24kHz, "");
|
||||
// Initialize the sliding 20 ms frame energy.
|
||||
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
|
||||
float denominator = std::inner_product(
|
||||
pitch_buffer.begin(), pitch_buffer.begin() + kFrameSize20ms24kHz + 1,
|
||||
pitch_buffer.begin(), 1.f);
|
||||
// Search best pitch by looking at the scaled auto-correlation.
|
||||
int best_inverted_lag = 0; // Pitch period.
|
||||
float best_numerator = -1.f; // Pitch strength numerator.
|
||||
float best_denominator = 0.f; // Pitch strength denominator.
|
||||
@ -171,8 +166,10 @@ int FindBestPitchPeriods24kHz(
|
||||
++inverted_lag) {
|
||||
// A pitch candidate must have positive correlation.
|
||||
if (auto_correlation[inverted_lag] > 0.f) {
|
||||
// Auto-correlation energy normalized by frame energy.
|
||||
const float numerator =
|
||||
auto_correlation[inverted_lag] * auto_correlation[inverted_lag];
|
||||
const float denominator = y_energy[kMaxPitch24kHz - inverted_lag];
|
||||
// Compare numerator/denominator ratios without using divisions.
|
||||
if (numerator * best_denominator > best_numerator * denominator) {
|
||||
best_inverted_lag = inverted_lag;
|
||||
@ -180,14 +177,6 @@ int FindBestPitchPeriods24kHz(
|
||||
best_denominator = denominator;
|
||||
}
|
||||
}
|
||||
// Update |denominator| for the next inverted lag.
|
||||
static_assert(kInitialNumLags24kHz + kFrameSize20ms24kHz < kBufSize24kHz,
|
||||
"");
|
||||
const float y_old = pitch_buffer[inverted_lag];
|
||||
const float y_new = pitch_buffer[inverted_lag + kFrameSize20ms24kHz];
|
||||
denominator -= y_old * y_old;
|
||||
denominator += y_new * y_new;
|
||||
denominator = std::max(0.f, denominator);
|
||||
}
|
||||
return best_inverted_lag;
|
||||
}
|
||||
@ -338,6 +327,7 @@ CandidatePitchPeriods ComputePitchPeriod12kHz(
|
||||
|
||||
int ComputePitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
CandidatePitchPeriods pitch_candidates) {
|
||||
// Compute the auto-correlation terms only for neighbors of the given pitch
|
||||
// candidates (similar to what is done in ComputePitchAutoCorrelation(), but
|
||||
@ -362,7 +352,7 @@ int ComputePitchPeriod48kHz(
|
||||
}
|
||||
// Find best pitch at 24 kHz.
|
||||
const int pitch_candidate_24kHz =
|
||||
FindBestPitchPeriods24kHz(auto_correlation, pitch_buffer);
|
||||
ComputePitchPeriod24kHz(pitch_buffer, auto_correlation, y_energy);
|
||||
// Pseudo-interpolation.
|
||||
return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidate_24kHz,
|
||||
auto_correlation);
|
||||
@ -370,6 +360,7 @@ int ComputePitchPeriod48kHz(
|
||||
|
||||
PitchInfo ComputeExtendedPitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
int initial_pitch_period_48kHz,
|
||||
PitchInfo last_pitch_48kHz) {
|
||||
RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
|
||||
@ -379,34 +370,30 @@ PitchInfo ComputeExtendedPitchPeriod48kHz(
|
||||
struct RefinedPitchCandidate {
|
||||
int period;
|
||||
float strength;
|
||||
// Additional strength data used for the final estimation of the strength.
|
||||
float xy; // Cross-correlation.
|
||||
float yy; // Auto-correlation.
|
||||
// Additional strength data used for the final pitch estimation.
|
||||
float xy; // Auto-correlation.
|
||||
float y_energy; // Energy of the sliding frame `y`.
|
||||
};
|
||||
|
||||
// Initialize.
|
||||
std::array<float, kRefineNumLags24kHz> yy_values;
|
||||
// TODO(bugs.webrtc.org/9076): Reuse values from FindBestPitchPeriods24kHz().
|
||||
ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, yy_values);
|
||||
const float xx = yy_values[0];
|
||||
const auto pitch_strength = [](float xy, float yy, float xx) {
|
||||
RTC_DCHECK_GE(xx * yy, 0.f);
|
||||
return xy / std::sqrt(1.f + xx * yy);
|
||||
const float x_energy = y_energy[0];
|
||||
const auto pitch_strength = [x_energy](float xy, float y_energy) {
|
||||
RTC_DCHECK_GE(x_energy * y_energy, 0.f);
|
||||
return xy / std::sqrt(1.f + x_energy * y_energy);
|
||||
};
|
||||
// Initial pitch candidate.
|
||||
|
||||
// Initialize the best pitch candidate with `initial_pitch_period_48kHz`.
|
||||
RefinedPitchCandidate best_pitch;
|
||||
best_pitch.period =
|
||||
std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1);
|
||||
best_pitch.xy =
|
||||
ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period, pitch_buffer);
|
||||
best_pitch.yy = yy_values[best_pitch.period];
|
||||
best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.yy, xx);
|
||||
|
||||
// 24 kHz version of the last estimated pitch and copy of the initial
|
||||
// estimation.
|
||||
best_pitch.y_energy = y_energy[best_pitch.period];
|
||||
best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.y_energy);
|
||||
// Keep a copy of the initial pitch candidate.
|
||||
const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength};
|
||||
// 24 kHz version of the last estimated pitch.
|
||||
const PitchInfo last_pitch{last_pitch_48kHz.period / 2,
|
||||
last_pitch_48kHz.strength};
|
||||
const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength};
|
||||
|
||||
// Find `max_period_divisor` such that the result of
|
||||
// `GetAlternativePitchPeriod(initial_pitch_period, 1, max_period_divisor)`
|
||||
@ -436,14 +423,14 @@ PitchInfo ComputeExtendedPitchPeriod48kHz(
|
||||
// Compute an auto-correlation score for the primary pitch candidate
|
||||
// |alternative_pitch.period| by also looking at its possible sub-harmonic
|
||||
// |dual_alternative_period|.
|
||||
float xy_primary_period = ComputeAutoCorrelation(
|
||||
const float xy_primary_period = ComputeAutoCorrelation(
|
||||
kMaxPitch24kHz - alternative_pitch.period, pitch_buffer);
|
||||
float xy_secondary_period = ComputeAutoCorrelation(
|
||||
const float xy_secondary_period = ComputeAutoCorrelation(
|
||||
kMaxPitch24kHz - dual_alternative_period, pitch_buffer);
|
||||
float xy = 0.5f * (xy_primary_period + xy_secondary_period);
|
||||
float yy = 0.5f * (yy_values[alternative_pitch.period] +
|
||||
yy_values[dual_alternative_period]);
|
||||
alternative_pitch.strength = pitch_strength(xy, yy, xx);
|
||||
const float xy = 0.5f * (xy_primary_period + xy_secondary_period);
|
||||
const float yy = 0.5f * (y_energy[alternative_pitch.period] +
|
||||
y_energy[dual_alternative_period]);
|
||||
alternative_pitch.strength = pitch_strength(xy, yy);
|
||||
|
||||
// Maybe update best period.
|
||||
if (IsAlternativePitchStrongerThanInitial(
|
||||
@ -455,10 +442,11 @@ PitchInfo ComputeExtendedPitchPeriod48kHz(
|
||||
|
||||
// Final pitch strength and period.
|
||||
best_pitch.xy = std::max(0.f, best_pitch.xy);
|
||||
RTC_DCHECK_LE(0.f, best_pitch.yy);
|
||||
float final_pitch_strength = (best_pitch.yy <= best_pitch.xy)
|
||||
? 1.f
|
||||
: best_pitch.xy / (best_pitch.yy + 1.f);
|
||||
RTC_DCHECK_LE(0.f, best_pitch.y_energy);
|
||||
float final_pitch_strength =
|
||||
(best_pitch.y_energy <= best_pitch.xy)
|
||||
? 1.f
|
||||
: best_pitch.xy / (best_pitch.y_energy + 1.f);
|
||||
final_pitch_strength = std::min(best_pitch.strength, final_pitch_strength);
|
||||
int final_pitch_period_48kHz = std::max(
|
||||
kMinPitch48kHz,
|
||||
|
@ -80,10 +80,12 @@ CandidatePitchPeriods ComputePitchPeriod12kHz(
|
||||
rtc::ArrayView<const float, kBufSize12kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kNumLags12kHz> auto_correlation);
|
||||
|
||||
// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer
|
||||
// and the pitch period candidates at 24 kHz (encoded as inverted lag).
|
||||
// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer,
|
||||
// the energies for the sliding frames `y` at 24 kHz and the pitch period
|
||||
// candidates at 24 kHz (encoded as inverted lag).
|
||||
int ComputePitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
CandidatePitchPeriods pitch_candidates_24kHz);
|
||||
|
||||
struct PitchInfo {
|
||||
@ -92,10 +94,12 @@ struct PitchInfo {
|
||||
};
|
||||
|
||||
// Computes the pitch period at 48 kHz searching in an extended pitch range
|
||||
// given a view on the 24 kHz pitch buffer, the initial 48 kHz estimation
|
||||
// (computed by `ComputePitchPeriod48kHz()`) and the last estimated pitch.
|
||||
// given a view on the 24 kHz pitch buffer, the energies for the sliding frames
|
||||
// `y` at 24 kHz, the initial 48 kHz estimation (computed by
|
||||
// `ComputePitchPeriod48kHz()`) and the last estimated pitch.
|
||||
PitchInfo ComputeExtendedPitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
int initial_pitch_period_48kHz,
|
||||
PitchInfo last_pitch_48kHz);
|
||||
|
||||
|
@ -63,12 +63,17 @@ TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) {
|
||||
// Checks that the refined pitch period is bit-exact given test input data.
|
||||
TEST(RnnVadTest, ComputePitchPeriod48kHzBitExactness) {
|
||||
PitchTestData test_data;
|
||||
std::vector<float> y_energy(kMaxPitch24kHz + 1);
|
||||
rtc::ArrayView<float, kMaxPitch24kHz + 1> y_energy_view(y_energy.data(),
|
||||
kMaxPitch24kHz + 1);
|
||||
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
|
||||
y_energy_view);
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(),
|
||||
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
|
||||
/*pitch_candidates=*/{280, 284}),
|
||||
560);
|
||||
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(),
|
||||
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
|
||||
/*pitch_candidates=*/{260, 284}),
|
||||
568);
|
||||
}
|
||||
@ -90,10 +95,15 @@ class ComputeExtendedPitchPeriod48kHzTest
|
||||
TEST_P(ComputeExtendedPitchPeriod48kHzTest,
|
||||
PeriodBitExactnessGainWithinTolerance) {
|
||||
PitchTestData test_data;
|
||||
std::vector<float> y_energy(kMaxPitch24kHz + 1);
|
||||
rtc::ArrayView<float, kMaxPitch24kHz + 1> y_energy_view(y_energy.data(),
|
||||
kMaxPitch24kHz + 1);
|
||||
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
|
||||
y_energy_view);
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
const auto computed_output = ComputeExtendedPitchPeriod48kHz(
|
||||
test_data.GetPitchBufView(), GetInitialPitchPeriod(),
|
||||
test_data.GetPitchBufView(), y_energy_view, GetInitialPitchPeriod(),
|
||||
{GetLastPitchPeriod(), GetLastPitchStrength()});
|
||||
EXPECT_EQ(GetExpectedPitchPeriod(), computed_output.period);
|
||||
EXPECT_NEAR(GetExpectedPitchStrength(), computed_output.strength, 1e-6f);
|
||||
|
Reference in New Issue
Block a user