Reland "RNN VAD: pitch search optimizations (part 1)"

This reverts commit 1b6b958a4aa574b7852fe62efe5d4f96ce085d8b.

Reason for revert: Bug fix

Original change's description:
> Revert "RNN VAD: pitch search optimizations (part 1)"
>
> This reverts commit 9da3e177fd5c2236cc15fea0ee8933e1dd0d8f6d.
>
> Reason for revert: bug in ComputePitchPeriod48kHz()
>
> Original change's description:
> > RNN VAD: pitch search optimizations (part 1)
> >
> > TL;DR this CL improves efficiency and includes several code
> > readability improvements mainly triggered by the comments to
> > patch set #10.
> >
> > Highlights:
> > - Split `FindBestPitchPeriods()` into 12 and 24 kHz versions
> >   to hard-code the input size and simplify the 24 kHz version
> > - Loop in `ComputePitchPeriod48kHz()` (new name for
> >   `RefinePitchPeriod48kHz()`) removed since the lags for which
> >   we need to compute the auto correlation are few
> > - `ComputePitchGainThreshold()` was only used in unit tests; it's been
> >   moved into the anon ns and the test removed
> >
> > This CL makes `ComputePitchPeriod48kHz()` about 10% faster (measured
> > with https://webrtc-review.googlesource.com/c/src/+/191320/4/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc).
> > The realtime factor has improved by about +14%.
> >
> > Benchmarked as follows:
> > ```
> > out/release/modules_unittests \
> >   --gtest_filter=*RnnVadTest.DISABLED_RnnVadPerformance* \
> >   --gtest_also_run_disabled_tests --logs
> > ```
> >
> > Results:
> >
> >       | baseline             | this CL
> > ------+----------------------+------------------------
> > run 1 | 24.0231 +/- 0.591016 | 23.568 +/- 0.990788
> >       | 370.06x              | 377.207x
> > ------+----------------------+------------------------
> > run 2 | 24.0485 +/- 0.957498 | 23.3714 +/- 0.857523
> >       | 369.67x              | 380.379x
> > ------+----------------------+------------------------
> > run 3 | 25.4091 +/- 2.6123   | 23.709 +/- 1.04477
> >       | 349.875x             | 374.963x
> >
> > Bug: webrtc:10480
> > Change-Id: I9a3e9164b2442114b928de506c92a547c273882f
> > Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/191320
> > Reviewed-by: Per Åhgren <peah@webrtc.org>
> > Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
> > Cr-Commit-Position: refs/heads/master@{#32568}
>
> TBR=alessiob@webrtc.org,peah@webrtc.org
>
> No-Presubmit: true
> No-Tree-Checks: true
> No-Try: true
> Bug: webrtc:10480
> Change-Id: I2a91f4f29566f872a7dfa220b31c6c625ed075db
> Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/192660
> Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
> Reviewed-by: Alessio Bazzica <alessiob@webrtc.org>
> Cr-Commit-Position: refs/heads/master@{#32581}

TBR=alessiob@webrtc.org,peah@webrtc.org

# Not skipping CQ checks because this is a reland.

Bug: webrtc:10480
Change-Id: I66e3e8d73ebc04a437c01a0396cd5613c42a8cf5
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/192780
Reviewed-by: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32585}
This commit is contained in:
Alessio Bazzica
2020-11-11 12:06:09 +01:00
committed by Commit Bot
parent 01a36f32e7
commit c36f8623c0
13 changed files with 487 additions and 450 deletions

View File

@ -83,7 +83,6 @@ rtc_library("rnn_vad_lp_residual") {
rtc_library("rnn_vad_pitch") {
sources = [
"pitch_info.h",
"pitch_search.cc",
"pitch_search.h",
"pitch_search_internal.cc",
@ -94,6 +93,7 @@ rtc_library("rnn_vad_pitch") {
":rnn_vad_common",
"../../../../api:array_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:gtest_prod",
"../../../../rtc_base:safe_compare",
"../../../../rtc_base:safe_conversions",
]

View File

@ -20,7 +20,7 @@ namespace {
constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
static_assert(1 << kAutoCorrelationFftOrder >
kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
"");
} // namespace
@ -45,7 +45,7 @@ AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;
// pitch period.
void AutoCorrelationCalculator::ComputeOnPitchBuffer(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
rtc::ArrayView<float, kNumLags12kHz> auto_corr) {
RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder;
@ -53,7 +53,7 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer(
static_assert(kConvolutionLength == kFrameSize20ms12kHz,
"Mismatch between pitch buffer size, frame size and maximum "
"pitch period.");
static_assert(kFftFrameSize > kNumInvertedLags12kHz + kConvolutionLength,
static_assert(kFftFrameSize > kNumLags12kHz + kConvolutionLength,
"The FFT length is not sufficiently big to avoid cyclic "
"convolution errors.");
auto tmp = tmp_->GetView();
@ -67,13 +67,12 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer(
// Compute the FFT for the sliding frames chunk. The sliding frames are
// defined as pitch_buf[i:i+kConvolutionLength] where i in
// [0, kNumInvertedLags12kHz). The chunk includes all of them, hence it is
// defined as pitch_buf[:kNumInvertedLags12kHz+kConvolutionLength].
// [0, kNumLags12kHz). The chunk includes all of them, hence it is
// defined as pitch_buf[:kNumLags12kHz+kConvolutionLength].
std::copy(pitch_buf.begin(),
pitch_buf.begin() + kConvolutionLength + kNumInvertedLags12kHz,
pitch_buf.begin() + kConvolutionLength + kNumLags12kHz,
tmp.begin());
std::fill(tmp.begin() + kNumInvertedLags12kHz + kConvolutionLength, tmp.end(),
0.f);
std::fill(tmp.begin() + kNumLags12kHz + kConvolutionLength, tmp.end(), 0.f);
fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false);
// Convolve in the frequency domain.
@ -84,7 +83,7 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer(
// Extract the auto-correlation coefficients.
std::copy(tmp.begin() + kConvolutionLength - 1,
tmp.begin() + kConvolutionLength + kNumInvertedLags12kHz - 1,
tmp.begin() + kConvolutionLength + kNumLags12kHz - 1,
auto_corr.begin());
}

View File

@ -34,7 +34,7 @@ class AutoCorrelationCalculator {
// |auto_corr| indexes are inverted lags.
void ComputeOnPitchBuffer(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr);
rtc::ArrayView<float, kNumLags12kHz> auto_corr);
private:
Pffft fft_;

View File

@ -36,7 +36,13 @@ constexpr int kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, "");
static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, "");
static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
constexpr int kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
// Number of (inverted) lags during the initial pitch search phase at 24 kHz.
constexpr int kInitialNumLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
// Number of (inverted) lags during the pitch search refinement phase at 24 kHz.
constexpr int kRefineNumLags24kHz = kMaxPitch24kHz + 1;
static_assert(
kRefineNumLags24kHz > kInitialNumLags24kHz,
"The refinement step must search the pitch in an extended pitch range.");
// 12 kHz analysis.
constexpr int kSampleRate12kHz = 12000;
@ -47,8 +53,8 @@ constexpr int kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
constexpr int kMaxPitch12kHz = kMaxPitch24kHz / 2;
static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|,
// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|].
constexpr int kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
// |kMaxPitch12kHz|] are in the range [0, |kNumLags12kHz|].
constexpr int kNumLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
// 48 kHz constants.
constexpr int kMinPitch48kHz = kMinPitch24kHz * 2;

View File

@ -67,13 +67,12 @@ bool FeaturesExtractor::CheckSilenceComputeFeatures(
ComputeLpResidual(lpc_coeffs, pitch_buf_24kHz_view_, lp_residual_view_);
// Estimate pitch on the LP-residual and write the normalized pitch period
// into the output vector (normalization based on training data stats).
pitch_info_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
feature_vector[kFeatureVectorSize - 2] =
0.01f * (pitch_info_48kHz_.period - 300);
pitch_period_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
feature_vector[kFeatureVectorSize - 2] = 0.01f * (pitch_period_48kHz_ - 300);
// Extract lagged frames (according to the estimated pitch period).
RTC_DCHECK_LE(pitch_info_48kHz_.period / 2, kMaxPitch24kHz);
RTC_DCHECK_LE(pitch_period_48kHz_ / 2, kMaxPitch24kHz);
auto lagged_frame = pitch_buf_24kHz_view_.subview(
kMaxPitch24kHz - pitch_info_48kHz_.period / 2, kFrameSize20ms24kHz);
kMaxPitch24kHz - pitch_period_48kHz_ / 2, kFrameSize20ms24kHz);
// Analyze reference and lagged frames checking if silence has been detected
// and write the feature vector.
return spectral_features_extractor_.CheckSilenceComputeFeatures(

View File

@ -16,7 +16,6 @@
#include "api/array_view.h"
#include "modules/audio_processing/agc2/biquad_filter.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"
#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"
@ -53,7 +52,7 @@ class FeaturesExtractor {
PitchEstimator pitch_estimator_;
rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame_view_;
SpectralFeaturesExtractor spectral_features_extractor_;
PitchInfo pitch_info_48kHz_;
int pitch_period_48kHz_;
};
} // namespace rnn_vad

View File

@ -1,29 +0,0 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_
namespace webrtc {
namespace rnn_vad {
// Stores pitch period and gain information. The pitch gain measures the
// strength of the pitch (the higher, the stronger).
struct PitchInfo {
PitchInfo() : period(0), gain(0.f) {}
PitchInfo(int p, float g) : period(p), gain(g) {}
int period;
float gain;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_INFO_H_

View File

@ -21,22 +21,22 @@ namespace rnn_vad {
PitchEstimator::PitchEstimator()
: pitch_buf_decimated_(kBufSize12kHz),
pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz),
auto_corr_(kNumInvertedLags12kHz),
auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) {
auto_corr_(kNumLags12kHz),
auto_corr_view_(auto_corr_.data(), kNumLags12kHz) {
RTC_DCHECK_EQ(kBufSize12kHz, pitch_buf_decimated_.size());
RTC_DCHECK_EQ(kNumInvertedLags12kHz, auto_corr_view_.size());
RTC_DCHECK_EQ(kNumLags12kHz, auto_corr_view_.size());
}
PitchEstimator::~PitchEstimator() = default;
PitchInfo PitchEstimator::Estimate(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
int PitchEstimator::Estimate(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
// Perform the initial pitch search at 12 kHz.
Decimate2x(pitch_buf, pitch_buf_decimated_view_);
Decimate2x(pitch_buffer, pitch_buf_decimated_view_);
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_,
auto_corr_view_);
CandidatePitchPeriods pitch_candidates_inverted_lags = FindBestPitchPeriods(
auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz);
CandidatePitchPeriods pitch_candidates_inverted_lags =
ComputePitchPeriod12kHz(pitch_buf_decimated_view_, auto_corr_view_);
// Refine the pitch period estimation.
// The refinement is done using the pitch buffer that contains 24 kHz samples.
// Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12
@ -44,12 +44,14 @@ PitchInfo PitchEstimator::Estimate(
pitch_candidates_inverted_lags.best *= 2;
pitch_candidates_inverted_lags.second_best *= 2;
const int pitch_inv_lag_48kHz =
RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inverted_lags);
ComputePitchPeriod48kHz(pitch_buffer, pitch_candidates_inverted_lags);
// Look for stronger harmonics to find the final pitch period and its gain.
RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz);
last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain(
pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_);
return last_pitch_48kHz_;
last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz(
pitch_buffer,
/*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_inv_lag_48kHz,
last_pitch_48kHz_);
return last_pitch_48kHz_.period;
}
} // namespace rnn_vad

View File

@ -17,8 +17,8 @@
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
#include "rtc_base/gtest_prod_util.h"
namespace webrtc {
namespace rnn_vad {
@ -30,17 +30,21 @@ class PitchEstimator {
PitchEstimator(const PitchEstimator&) = delete;
PitchEstimator& operator=(const PitchEstimator&) = delete;
~PitchEstimator();
// Estimates the pitch period and gain. Returns the pitch estimation data for
// 48 kHz.
PitchInfo Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buf);
// Returns the estimated pitch period at 48 kHz.
int Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer);
private:
PitchInfo last_pitch_48kHz_;
FRIEND_TEST_ALL_PREFIXES(RnnVadTest, PitchSearchWithinTolerance);
float GetLastPitchStrengthForTesting() const {
return last_pitch_48kHz_.strength;
}
PitchInfo last_pitch_48kHz_{};
AutoCorrelationCalculator auto_corr_calculator_;
std::vector<float> pitch_buf_decimated_;
rtc::ArrayView<float, kBufSize12kHz> pitch_buf_decimated_view_;
std::vector<float> auto_corr_;
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr_view_;
rtc::ArrayView<float, kNumLags12kHz> auto_corr_view_;
};
} // namespace rnn_vad

View File

@ -26,94 +26,88 @@ namespace webrtc {
namespace rnn_vad {
namespace {
// Converts a lag to an inverted lag (only for 24kHz).
int GetInvertedLag(int lag) {
RTC_DCHECK_LE(lag, kMaxPitch24kHz);
return kMaxPitch24kHz - lag;
}
float ComputeAutoCorrelationCoeff(rtc::ArrayView<const float> pitch_buf,
int inv_lag,
int max_pitch_period) {
RTC_DCHECK_LT(inv_lag, pitch_buf.size());
RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
RTC_DCHECK_LE(inv_lag, max_pitch_period);
float ComputeAutoCorrelation(
int inverted_lag,
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
RTC_DCHECK_LT(inverted_lag, kBufSize24kHz);
RTC_DCHECK_LT(inverted_lag, kRefineNumLags24kHz);
static_assert(kMaxPitch24kHz < kBufSize24kHz, "");
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
return std::inner_product(pitch_buf.begin() + max_pitch_period,
pitch_buf.end(), pitch_buf.begin() + inv_lag, 0.f);
return std::inner_product(pitch_buffer.begin() + kMaxPitch24kHz,
pitch_buffer.end(),
pitch_buffer.begin() + inverted_lag, 0.f);
}
// Given the auto-correlation coefficients for a lag and its neighbors, computes
// a pseudo-interpolation offset to be applied to the pitch period associated to
// the central auto-correlation coefficient |lag_auto_corr|. The output is a lag
// in {-1, 0, +1}.
// TODO(bugs.webrtc.org/9076): Consider removing pseudo-i since it
// is relevant only if the spectral analysis works at a sample rate that is
// twice as that of the pitch buffer (not so important instead for the estimated
// pitch period feature fed into the RNN).
int GetPitchPseudoInterpolationOffset(float prev_auto_corr,
float lag_auto_corr,
float next_auto_corr) {
const float& a = prev_auto_corr;
const float& b = lag_auto_corr;
const float& c = next_auto_corr;
int offset = 0;
if ((c - a) > 0.7f * (b - a)) {
offset = 1; // |c| is the largest auto-correlation coefficient.
} else if ((a - c) > 0.7f * (b - c)) {
offset = -1; // |a| is the largest auto-correlation coefficient.
// Given an auto-correlation coefficient `curr_auto_correlation` and its
// neighboring values `prev_auto_correlation` and `next_auto_correlation`
// computes a pseudo-interpolation offset to be applied to the pitch period
// associated to `curr`. The output is a lag in {-1, 0, +1}.
// TODO(bugs.webrtc.org/9076): Consider removing this method.
// `GetPitchPseudoInterpolationOffset()` it is relevant only if the spectral
// analysis works at a sample rate that is twice as that of the pitch buffer;
// In particular, it is not relevant for the estimated pitch period feature fed
// into the RNN.
int GetPitchPseudoInterpolationOffset(float prev_auto_correlation,
float curr_auto_correlation,
float next_auto_correlation) {
if ((next_auto_correlation - prev_auto_correlation) >
0.7f * (curr_auto_correlation - prev_auto_correlation)) {
return 1; // |next_auto_correlation| is the largest auto-correlation
// coefficient.
} else if ((prev_auto_correlation - next_auto_correlation) >
0.7f * (curr_auto_correlation - next_auto_correlation)) {
return -1; // |prev_auto_correlation| is the largest auto-correlation
// coefficient.
}
return offset;
return 0;
}
// Refines a pitch period |lag| encoded as lag with pseudo-interpolation. The
// output sample rate is twice as that of |lag|.
int PitchPseudoInterpolationLagPitchBuf(
int lag,
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
int offset = 0;
// Cannot apply pseudo-interpolation at the boundaries.
if (lag > 0 && lag < kMaxPitch24kHz) {
const int inverted_lag = kMaxPitch24kHz - lag;
offset = GetPitchPseudoInterpolationOffset(
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag - 1),
kMaxPitch24kHz),
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag),
kMaxPitch24kHz),
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag + 1),
kMaxPitch24kHz));
ComputeAutoCorrelation(inverted_lag + 1, pitch_buffer),
ComputeAutoCorrelation(inverted_lag, pitch_buffer),
ComputeAutoCorrelation(inverted_lag - 1, pitch_buffer));
}
return 2 * lag + offset;
}
// Refines a pitch period |inv_lag| encoded as inverted lag with
// Refines a pitch period |inverted_lag| encoded as inverted lag with
// pseudo-interpolation. The output sample rate is twice as that of
// |inv_lag|.
// |inverted_lag|.
int PitchPseudoInterpolationInvLagAutoCorr(
int inv_lag,
rtc::ArrayView<const float> auto_corr) {
int inverted_lag,
rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation) {
int offset = 0;
// Cannot apply pseudo-interpolation at the boundaries.
if (inv_lag > 0 && inv_lag < rtc::dchecked_cast<int>(auto_corr.size()) - 1) {
if (inverted_lag > 0 && inverted_lag < kInitialNumLags24kHz - 1) {
offset = GetPitchPseudoInterpolationOffset(
auto_corr[inv_lag + 1], auto_corr[inv_lag], auto_corr[inv_lag - 1]);
auto_correlation[inverted_lag + 1], auto_correlation[inverted_lag],
auto_correlation[inverted_lag - 1]);
}
// TODO(bugs.webrtc.org/9076): When retraining, check if |offset| below should
// be subtracted since |inv_lag| is an inverted lag but offset is a lag.
return 2 * inv_lag + offset;
// be subtracted since |inverted_lag| is an inverted lag but offset is a lag.
return 2 * inverted_lag + offset;
}
// Integer multipliers used in CheckLowerPitchPeriodsAndComputePitchGain() when
// Integer multipliers used in ComputeExtendedPitchPeriod48kHz() when
// looking for sub-harmonics.
// The values have been chosen to serve the following algorithm. Given the
// initial pitch period T, we examine whether one of its harmonics is the true
// fundamental frequency. We consider T/k with k in {2, ..., 15}. For each of
// these harmonics, in addition to the pitch gain of itself, we choose one
// these harmonics, in addition to the pitch strength of itself, we choose one
// multiple of its pitch period, n*T/k, to validate it (by averaging their pitch
// gains). The multiplier n is chosen so that n*T/k is used only one time over
// all k. When for example k = 4, we should also expect a peak at 3*T/4. When
// k = 8 instead we don't want to look at 2*T/8, since we have already checked
// T/4 before. Instead, we look at T*3/8.
// strengths). The multiplier n is chosen so that n*T/k is used only one time
// over all k. When for example k = 4, we should also expect a peak at 3*T/4.
// When k = 8 instead we don't want to look at 2*T/8, since we have already
// checked T/4 before. Instead, we look at T*3/8.
// The array can be generated in Python as follows:
// from fractions import Fraction
// # Smallest positive integer not in X.
@ -130,92 +124,171 @@ int PitchPseudoInterpolationInvLagAutoCorr(
constexpr std::array<int, 14> kSubHarmonicMultipliers = {
{3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}};
// Initial pitch period candidate thresholds for ComputePitchGainThreshold() for
// a sample rate of 24 kHz. Computed as [5*k*k for k in range(16)].
constexpr std::array<int, 14> kInitialPitchPeriodThresholds = {
{20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
// Pair of integer endpoints describing an interval. The code in this file that
// consumes a `Range` iterates from `min` to `max` with both endpoints included
// (closed interval). The member order matters: instances are created with
// positional aggregate initialization, i.e., `{min, max}`.
struct Range {
// Lower endpoint of the interval.
int min;
// Upper endpoint of the interval.
int max;
};
// Returns the interval of inverted lags that lie within 2 positions from
// `inverted_lag`. The lower endpoint is clipped to 0 and the upper endpoint is
// clipped to `kInitialNumLags24kHz - 1`.
Range CreateInvertedLagRange(int inverted_lag) {
  constexpr int kRadius = 2;
  // Clip each endpoint independently.
  const int lower_endpoint = std::max(inverted_lag - kRadius, 0);
  const int upper_endpoint =
      std::min(inverted_lag + kRadius, kInitialNumLags24kHz - 1);
  return {lower_endpoint, upper_endpoint};
}
// Fills `auto_correlation[i]` with the auto-correlation coefficient computed on
// `pitch_buffer` for every inverted lag `i` such that
// `inverted_lags.min <= i <= inverted_lags.max` (both endpoints included).
// The entries of `auto_correlation` outside that interval are not written.
void ComputeAutoCorrelation(
    Range inverted_lags,
    rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
    rtc::ArrayView<float, kInitialNumLags24kHz> auto_correlation) {
  const int first_inverted_lag = inverted_lags.min;
  const int last_inverted_lag = inverted_lags.max;
  // The interval must not be empty.
  RTC_DCHECK_LE(first_inverted_lag, last_inverted_lag);
  // Both endpoints must be valid indexes of `auto_correlation`.
  RTC_DCHECK_GE(first_inverted_lag, 0);
  RTC_DCHECK_LT(last_inverted_lag, auto_correlation.size());
  int current = first_inverted_lag;
  while (current <= last_inverted_lag) {
    auto_correlation[current] = ComputeAutoCorrelation(current, pitch_buffer);
    ++current;
  }
}
// Scans the inverted lags in [0, `kInitialNumLags24kHz`) and returns the one
// with the highest ratio between the squared auto-correlation coefficient and
// the energy of the corresponding sliding frame of `pitch_buffer`. Only lags
// with a strictly positive auto-correlation coefficient are candidates.
// Returns 0 if no candidate is selected.
int FindBestPitchPeriods24kHz(
    rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation,
    rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
  static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, "");
  static_assert(kMaxPitch24kHz < kBufSize24kHz, "");
  // Guarantees that the sliding energy update below reads within the buffer.
  static_assert(kInitialNumLags24kHz + kFrameSize20ms24kHz < kBufSize24kHz,
                "");
  // Sliding frame energy. It starts from 1 plus the sum of the squares of the
  // first `kFrameSize20ms24kHz + 1` samples.
  // TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
  float frame_energy = std::inner_product(
      pitch_buffer.begin(), pitch_buffer.begin() + kFrameSize20ms24kHz + 1,
      pitch_buffer.begin(), 1.f);
  // Best candidate found so far: inverted lag plus the numerator and the
  // denominator of its pitch strength.
  int winner = 0;
  float winner_numerator = -1.f;
  float winner_denominator = 0.f;
  for (int i = 0; i < kInitialNumLags24kHz; ++i) {
    const float correlation = auto_correlation[i];
    // Lags with non-positive correlation are never pitch candidates.
    if (correlation > 0.f) {
      const float squared_correlation = correlation * correlation;
      // Cross-multiply to compare the two ratios without dividing.
      const bool is_stronger = squared_correlation * winner_denominator >
                               winner_numerator * frame_energy;
      if (is_stronger) {
        winner = i;
        winner_numerator = squared_correlation;
        winner_denominator = frame_energy;
      }
    }
    // Slide the frame by one sample: drop the oldest sample, add the newest one
    // and clamp to zero to absorb negative values due to rounding.
    const float dropped_sample = pitch_buffer[i];
    const float added_sample = pitch_buffer[i + kFrameSize20ms24kHz];
    frame_energy -= dropped_sample * dropped_sample;
    frame_energy += added_sample * added_sample;
    frame_energy = std::max(0.f, frame_energy);
  }
  return winner;
}
// Scales `pitch_period` by the ratio `multiplier` / `divisor` and returns the
// result rounded to the nearest integer. `divisor` must be positive.
constexpr int GetAlternativePitchPeriod(int pitch_period,
                                        int multiplier,
                                        int divisor) {
  RTC_DCHECK_GT(divisor, 0);
  // Integer-only rounding: numerator and denominator are doubled so that adding
  // `divisor` corresponds to adding one half before the truncating division.
  // Same as `round(multiplier * pitch_period / divisor)`.
  const int doubled_numerator = 2 * multiplier * pitch_period;
  const int doubled_denominator = 2 * divisor;
  return (doubled_numerator + divisor) / doubled_denominator;
}
// Decides whether the pitch candidate `alternative` must replace `initial`.
// `alternative` is accepted when its strength exceeds a threshold that depends
// on the strength of `initial`, on how close the alternative period is to the
// period of the last estimated pitch `last` (pitch tracking) and on how short
// the alternative period is. `period_divisor` is the divisor used to derive the
// alternative period via `GetAlternativePitchPeriod()`; it must be at least 2.
bool IsAlternativePitchStrongerThanInitial(PitchInfo last,
                                           PitchInfo initial,
                                           PitchInfo alternative,
                                           int period_divisor) {
  // Thresholds on the initial pitch period (24 kHz sample rate), one for each
  // period divisor. The values are 5*k*k for k in [2, 15].
  constexpr std::array<int, 14> kInitialPitchPeriodThresholds = {
      {20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
  static_assert(
      kInitialPitchPeriodThresholds.size() == kSubHarmonicMultipliers.size(),
      "");
  RTC_DCHECK_GE(last.period, 0);
  RTC_DCHECK_GE(initial.period, 0);
  RTC_DCHECK_GE(alternative.period, 0);
  RTC_DCHECK_GE(period_divisor, 2);

  // Pitch tracking: the closer `alternative.period` is to `last.period`, the
  // larger the amount subtracted from the threshold.
  const int distance_from_last = std::abs(alternative.period - last.period);
  float tracking_discount = 0.f;
  if (distance_from_last <= 1) {
    // At most 1 sample away from the last period: very easy to accept.
    tracking_discount = last.strength;
  } else if (distance_from_last == 2 &&
             initial.period >
                 kInitialPitchPeriodThresholds[period_divisor - 2]) {
    // Exactly 2 samples away from the last period and derived from an initial
    // period greater than the threshold for this divisor: easy to accept.
    tracking_discount = 0.5f * last.strength;
  }

  // The threshold scales with the strength of the initial estimate and has a
  // lower bound. Both get stricter for short periods to reduce the chance of
  // false positives caused by a bias towards high frequencies (originating
  // from short-term correlations).
  float min_threshold = 0.3f;
  float initial_strength_scale = 0.7f;
  if (alternative.period < 3 * kMinPitch24kHz) {
    // High frequency.
    min_threshold = 0.4f;
    initial_strength_scale = 0.85f;
  } else if (alternative.period < 2 * kMinPitch24kHz) {
    // Even higher frequency.
    // NOTE(review): this branch cannot be taken since any period below
    // `2 * kMinPitch24kHz` is already handled by the branch above. The order
    // is intentionally left unchanged here because swapping the two checks
    // changes the output - confirm whether that is desired.
    min_threshold = 0.5f;
    initial_strength_scale = 0.9f;
  }
  const float threshold =
      std::max(min_threshold,
               initial_strength_scale * initial.strength - tracking_discount);
  return alternative.strength > threshold;
}
} // namespace
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
rtc::ArrayView<float, kBufSize12kHz> dst) {
// TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter.
static_assert(2 * dst.size() == src.size(), "");
for (int i = 0; rtc::SafeLt(i, dst.size()); ++i) {
static_assert(2 * kBufSize12kHz == kBufSize24kHz, "");
for (int i = 0; i < kBufSize12kHz; ++i) {
dst[i] = src[2 * i];
}
}
float ComputePitchGainThreshold(int candidate_pitch_period,
int pitch_period_ratio,
int initial_pitch_period,
float initial_pitch_gain,
int prev_pitch_period,
float prev_pitch_gain) {
// Map arguments to more compact aliases.
const int& t1 = candidate_pitch_period;
const int& k = pitch_period_ratio;
const int& t0 = initial_pitch_period;
const float& g0 = initial_pitch_gain;
const int& t_prev = prev_pitch_period;
const float& g_prev = prev_pitch_gain;
// Validate input.
RTC_DCHECK_GE(t1, 0);
RTC_DCHECK_GE(k, 2);
RTC_DCHECK_GE(t0, 0);
RTC_DCHECK_GE(t_prev, 0);
// Compute a term that lowers the threshold when |t1| is close to the last
// estimated period |t_prev| - i.e., pitch tracking.
float lower_threshold_term = 0;
if (abs(t1 - t_prev) <= 1) {
// The candidate pitch period is within 1 sample from the previous one.
// Make the candidate at |t1| very easy to be accepted.
lower_threshold_term = g_prev;
} else if (abs(t1 - t_prev) == 2 &&
t0 > kInitialPitchPeriodThresholds[k - 2]) {
// The candidate pitch period is 2 samples far from the previous one and the
// period |t0| (from which |t1| has been derived) is greater than a
// threshold. Make |t1| easy to be accepted.
lower_threshold_term = 0.5f * g_prev;
}
// Set the threshold based on the gain of the initial estimate |t0|. Also
// reduce the chance of false positives caused by a bias towards high
// frequencies (originating from short-term correlations).
float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term);
if (t1 < 3 * kMinPitch24kHz) {
// High frequency.
threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term);
} else if (t1 < 2 * kMinPitch24kHz) {
// Even higher frequency.
threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term);
}
return threshold;
}
void ComputeSlidingFrameSquareEnergies(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values) {
float yy =
ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHz, kMaxPitch24kHz);
void ComputeSlidingFrameSquareEnergies24kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<float, kRefineNumLags24kHz> yy_values) {
float yy = ComputeAutoCorrelation(kMaxPitch24kHz, pitch_buffer);
yy_values[0] = yy;
for (int i = 1; rtc::SafeLt(i, yy_values.size()); ++i) {
RTC_DCHECK_LE(i, kMaxPitch24kHz + kFrameSize20ms24kHz);
RTC_DCHECK_LE(i, kMaxPitch24kHz);
const float old_coeff = pitch_buf[kMaxPitch24kHz + kFrameSize20ms24kHz - i];
const float new_coeff = pitch_buf[kMaxPitch24kHz - i];
yy -= old_coeff * old_coeff;
yy += new_coeff * new_coeff;
static_assert(kMaxPitch24kHz - (kRefineNumLags24kHz - 1) >= 0, "");
static_assert(kMaxPitch24kHz - 1 + kFrameSize20ms24kHz < kBufSize24kHz, "");
for (int lag = 1; lag < kRefineNumLags24kHz; ++lag) {
const int inverted_lag = kMaxPitch24kHz - lag;
const float y_old = pitch_buffer[inverted_lag + kFrameSize20ms24kHz];
const float y_new = pitch_buffer[inverted_lag];
yy -= y_old * y_old;
yy += y_new * y_new;
yy = std::max(0.f, yy);
yy_values[i] = yy;
yy_values[lag] = yy;
}
}
CandidatePitchPeriods FindBestPitchPeriods(
rtc::ArrayView<const float> auto_corr,
rtc::ArrayView<const float> pitch_buf,
int max_pitch_period) {
CandidatePitchPeriods ComputePitchPeriod12kHz(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buffer,
rtc::ArrayView<const float, kNumLags12kHz> auto_correlation) {
static_assert(kMaxPitch12kHz > kNumLags12kHz, "");
static_assert(kMaxPitch12kHz < kBufSize12kHz, "");
// Stores a pitch candidate period and strength information.
struct PitchCandidate {
// Pitch period encoded as inverted lag.
@ -231,28 +304,22 @@ CandidatePitchPeriods FindBestPitchPeriods(
}
};
RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
const int frame_size =
rtc::dchecked_cast<int>(pitch_buf.size()) - max_pitch_period;
RTC_DCHECK_GT(frame_size, 0);
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
float yy =
std::inner_product(pitch_buf.begin(), pitch_buf.begin() + frame_size + 1,
pitch_buf.begin(), 1.f);
float denominator = std::inner_product(
pitch_buffer.begin(), pitch_buffer.begin() + kFrameSize20ms12kHz + 1,
pitch_buffer.begin(), 1.f);
// Search best and second best pitches by looking at the scaled
// auto-correlation.
PitchCandidate candidate;
PitchCandidate best;
PitchCandidate second_best;
second_best.period_inverted_lag = 1;
for (int inv_lag = 0; inv_lag < rtc::dchecked_cast<int>(auto_corr.size());
++inv_lag) {
for (int inverted_lag = 0; inverted_lag < kNumLags12kHz; ++inverted_lag) {
// A pitch candidate must have positive correlation.
if (auto_corr[inv_lag] > 0) {
candidate.period_inverted_lag = inv_lag;
candidate.strength_numerator = auto_corr[inv_lag] * auto_corr[inv_lag];
candidate.strength_denominator = yy;
if (auto_correlation[inverted_lag] > 0.f) {
PitchCandidate candidate{
inverted_lag,
auto_correlation[inverted_lag] * auto_correlation[inverted_lag],
denominator};
if (candidate.HasStrongerPitchThan(second_best)) {
if (candidate.HasStrongerPitchThan(best)) {
second_best = best;
@ -263,144 +330,148 @@ CandidatePitchPeriods FindBestPitchPeriods(
}
}
// Update the energy of the sliding frame `y` for the next inverted lag.
const float old_coeff = pitch_buf[inv_lag];
const float new_coeff = pitch_buf[inv_lag + frame_size];
yy -= old_coeff * old_coeff;
yy += new_coeff * new_coeff;
yy = std::max(0.f, yy);
const float y_old = pitch_buffer[inverted_lag];
const float y_new = pitch_buffer[inverted_lag + kFrameSize20ms12kHz];
denominator -= y_old * y_old;
denominator += y_new * y_new;
denominator = std::max(0.f, denominator);
}
return {best.period_inverted_lag, second_best.period_inverted_lag};
}
int RefinePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
CandidatePitchPeriods pitch_candidates_inverted_lags) {
int ComputePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
CandidatePitchPeriods pitch_candidates) {
// Compute the auto-correlation terms only for neighbors of the given pitch
// candidates (similar to what is done in ComputePitchAutoCorrelation(), but
// for a few lag values).
std::array<float, kNumInvertedLags24kHz> auto_correlation;
auto_correlation.fill(
0.f); // Zeros become ignored lags in FindBestPitchPeriods().
auto is_neighbor = [](int i, int j) {
return ((i > j) ? (i - j) : (j - i)) <= 2;
};
// TODO(https://crbug.com/webrtc/10480): Optimize by removing the loop.
for (int inverted_lag = 0; rtc::SafeLt(inverted_lag, auto_correlation.size());
++inverted_lag) {
if (is_neighbor(inverted_lag, pitch_candidates_inverted_lags.best) ||
is_neighbor(inverted_lag, pitch_candidates_inverted_lags.second_best))
auto_correlation[inverted_lag] =
ComputeAutoCorrelationCoeff(pitch_buf, inverted_lag, kMaxPitch24kHz);
std::array<float, kInitialNumLags24kHz> auto_correlation{};
// Create two inverted lag ranges so that `r1` precedes `r2`.
const bool swap_candidates =
pitch_candidates.best > pitch_candidates.second_best;
const Range r1 = CreateInvertedLagRange(
swap_candidates ? pitch_candidates.second_best : pitch_candidates.best);
const Range r2 = CreateInvertedLagRange(
swap_candidates ? pitch_candidates.best : pitch_candidates.second_best);
// Check valid ranges.
RTC_DCHECK_LE(r1.min, r1.max);
RTC_DCHECK_LE(r2.min, r2.max);
// Check `r1` precedes `r2`.
RTC_DCHECK_LE(r1.min, r2.min);
RTC_DCHECK_LE(r1.max, r2.max);
if (r1.max + 1 >= r2.min) {
// Overlapping or adjacent ranges.
ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation);
} else {
// Disjoint ranges.
ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation);
ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation);
}
// Find best pitch at 24 kHz.
const CandidatePitchPeriods pitch_candidates_24kHz =
FindBestPitchPeriods(auto_correlation, pitch_buf, kMaxPitch24kHz);
const int pitch_candidate_24kHz =
FindBestPitchPeriods24kHz(auto_correlation, pitch_buffer);
// Pseudo-interpolation.
return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidates_24kHz.best,
return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidate_24kHz,
auto_correlation);
}
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
PitchInfo ComputeExtendedPitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
int initial_pitch_period_48kHz,
PitchInfo prev_pitch_48kHz) {
PitchInfo last_pitch_48kHz) {
RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz);
// Stores information for a refined pitch candidate.
struct RefinedPitchCandidate {
RefinedPitchCandidate() {}
RefinedPitchCandidate(int period_24kHz, float gain, float xy, float yy)
: period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {}
int period_24kHz;
// Pitch strength information.
float gain;
// Additional pitch strength information used for the final estimation of
// pitch gain.
int period;
float strength;
// Additional strength data used for the final estimation of the strength.
float xy; // Cross-correlation.
float yy; // Auto-correlation.
};
// Initialize.
std::array<float, kMaxPitch24kHz + 1> yy_values;
ComputeSlidingFrameSquareEnergies(pitch_buf,
{yy_values.data(), yy_values.size()});
std::array<float, kRefineNumLags24kHz> yy_values;
// TODO(bugs.webrtc.org/9076): Reuse values from FindBestPitchPeriods24kHz().
ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, yy_values);
const float xx = yy_values[0];
// Helper lambdas.
const auto pitch_gain = [](float xy, float yy, float xx) {
RTC_DCHECK_LE(0.f, xx * yy);
const auto pitch_strength = [](float xy, float yy, float xx) {
RTC_DCHECK_GE(xx * yy, 0.f);
return xy / std::sqrt(1.f + xx * yy);
};
// Initial pitch candidate gain.
// Initial pitch candidate.
RefinedPitchCandidate best_pitch;
best_pitch.period_24kHz =
best_pitch.period =
std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1);
best_pitch.xy = ComputeAutoCorrelationCoeff(
pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz);
best_pitch.yy = yy_values[best_pitch.period_24kHz];
best_pitch.gain = pitch_gain(best_pitch.xy, best_pitch.yy, xx);
best_pitch.xy =
ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period, pitch_buffer);
best_pitch.yy = yy_values[best_pitch.period];
best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.yy, xx);
// Store the initial pitch period information.
const int initial_pitch_period = best_pitch.period_24kHz;
const float initial_pitch_gain = best_pitch.gain;
// 24 kHz version of the last estimated pitch and copy of the initial
// estimation.
const PitchInfo last_pitch{last_pitch_48kHz.period / 2,
last_pitch_48kHz.strength};
const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength};
// Given the initial pitch estimation, check lower periods (i.e., harmonics).
const auto alternative_period = [](int period, int k, int n) -> int {
RTC_DCHECK_GT(k, 0);
return (2 * n * period + k) / (2 * k); // Same as round(n*period/k).
};
// |max_k| such that alternative_period(initial_pitch_period, max_k, 1) equals
// kMinPitch24kHz.
const int max_k = (2 * initial_pitch_period) / (2 * kMinPitch24kHz - 1);
for (int k = 2; k <= max_k; ++k) {
int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1);
RTC_DCHECK_GE(candidate_pitch_period, kMinPitch24kHz);
// When looking at |candidate_pitch_period|, we also look at one of its
// Find `max_period_divisor` such that the result of
// `GetAlternativePitchPeriod(initial_pitch_period, 1, max_period_divisor)`
// equals `kMinPitch24kHz`.
const int max_period_divisor =
(2 * initial_pitch.period) / (2 * kMinPitch24kHz - 1);
for (int period_divisor = 2; period_divisor <= max_period_divisor;
++period_divisor) {
PitchInfo alternative_pitch;
alternative_pitch.period = GetAlternativePitchPeriod(
initial_pitch.period, /*multiplier=*/1, period_divisor);
RTC_DCHECK_GE(alternative_pitch.period, kMinPitch24kHz);
// When looking at |alternative_pitch.period|, we also look at one of its
// sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look.
// |k| == 2 is a special case since |candidate_pitch_secondary_period| might
// be greater than the maximum pitch period.
int candidate_pitch_secondary_period = alternative_period(
initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]);
RTC_DCHECK_GT(candidate_pitch_secondary_period, 0);
if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHz) {
candidate_pitch_secondary_period = initial_pitch_period;
// |period_divisor| == 2 is a special case since |dual_alternative_period|
// might be greater than the maximum pitch period.
int dual_alternative_period = GetAlternativePitchPeriod(
initial_pitch.period, kSubHarmonicMultipliers[period_divisor - 2],
period_divisor);
RTC_DCHECK_GT(dual_alternative_period, 0);
if (period_divisor == 2 && dual_alternative_period > kMaxPitch24kHz) {
dual_alternative_period = initial_pitch.period;
}
RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period)
RTC_DCHECK_NE(alternative_pitch.period, dual_alternative_period)
<< "The lower pitch period and the additional sub-harmonic must not "
"coincide.";
// Compute an auto-correlation score for the primary pitch candidate
// |candidate_pitch_period| by also looking at its possible sub-harmonic
// |candidate_pitch_secondary_period|.
float xy_primary_period = ComputeAutoCorrelationCoeff(
pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHz);
float xy_secondary_period = ComputeAutoCorrelationCoeff(
pitch_buf, GetInvertedLag(candidate_pitch_secondary_period),
kMaxPitch24kHz);
// |alternative_pitch.period| by also looking at its possible sub-harmonic
// |dual_alternative_period|.
float xy_primary_period = ComputeAutoCorrelation(
kMaxPitch24kHz - alternative_pitch.period, pitch_buffer);
float xy_secondary_period = ComputeAutoCorrelation(
kMaxPitch24kHz - dual_alternative_period, pitch_buffer);
float xy = 0.5f * (xy_primary_period + xy_secondary_period);
float yy = 0.5f * (yy_values[candidate_pitch_period] +
yy_values[candidate_pitch_secondary_period]);
float candidate_pitch_gain = pitch_gain(xy, yy, xx);
float yy = 0.5f * (yy_values[alternative_pitch.period] +
yy_values[dual_alternative_period]);
alternative_pitch.strength = pitch_strength(xy, yy, xx);
// Maybe update best period.
float threshold = ComputePitchGainThreshold(
candidate_pitch_period, k, initial_pitch_period, initial_pitch_gain,
prev_pitch_48kHz.period / 2, prev_pitch_48kHz.gain);
if (candidate_pitch_gain > threshold) {
best_pitch = {candidate_pitch_period, candidate_pitch_gain, xy, yy};
if (IsAlternativePitchStrongerThanInitial(
last_pitch, initial_pitch, alternative_pitch, period_divisor)) {
best_pitch = {alternative_pitch.period, alternative_pitch.strength, xy,
yy};
}
}
// Final pitch gain and period.
// Final pitch strength and period.
best_pitch.xy = std::max(0.f, best_pitch.xy);
RTC_DCHECK_LE(0.f, best_pitch.yy);
float final_pitch_gain = (best_pitch.yy <= best_pitch.xy)
? 1.f
: best_pitch.xy / (best_pitch.yy + 1.f);
final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain);
float final_pitch_strength = (best_pitch.yy <= best_pitch.xy)
? 1.f
: best_pitch.xy / (best_pitch.yy + 1.f);
final_pitch_strength = std::min(best_pitch.strength, final_pitch_strength);
int final_pitch_period_48kHz = std::max(
kMinPitch48kHz,
PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf));
PitchPseudoInterpolationLagPitchBuf(best_pitch.period, pitch_buffer));
return {final_pitch_period_48kHz, final_pitch_gain};
return {final_pitch_period_48kHz, final_pitch_strength};
}
} // namespace rnn_vad

View File

@ -18,7 +18,6 @@
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
namespace webrtc {
namespace rnn_vad {
@ -27,56 +26,78 @@ namespace rnn_vad {
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
rtc::ArrayView<float, kBufSize12kHz> dst);
// Computes a gain threshold for a candidate pitch period given the initial and
// the previous pitch period and gain estimates and the pitch period ratio used
// to derive the candidate pitch period from the initial period.
float ComputePitchGainThreshold(int candidate_pitch_period,
int pitch_period_ratio,
int initial_pitch_period,
float initial_pitch_gain,
int prev_pitch_period,
float prev_pitch_gain);
// Computes the sum of squared samples for every sliding frame in the pitch
// buffer. |yy_values| indexes are lags.
// Key concepts and keywords used below in this file.
//
// The pitch buffer is structured as depicted below:
// |.........|...........|
// a b
// The part on the left, named "a" contains the oldest samples, whereas "b" the
// most recent ones. The size of "a" corresponds to the maximum pitch period,
// that of "b" to the frame size (e.g., 16 ms and 20 ms respectively).
void ComputeSlidingFrameSquareEnergies(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values);
// The pitch estimation relies on a pitch buffer, which is an array-like data
// structure designed as follows:
//
// |....A....|.....B.....|
//
// The part on the left, named `A` contains the oldest samples, whereas `B`
// contains the most recent ones. The size of `A` corresponds to the maximum
// pitch period, that of `B` to the analysis frame size (e.g., 16 ms and 20 ms
// respectively).
//
// Pitch estimation is essentially based on the analysis of two 20 ms frames
// extracted from the pitch buffer. One frame, called `x`, is kept fixed and
// corresponds to `B` - i.e., the most recent 20 ms. The other frame, called
// `y`, is extracted from different parts of the buffer instead.
//
// The offset between `x` and `y` corresponds to a specific pitch period.
// For instance, if `y` is positioned at the beginning of the pitch buffer, then
// the cross-correlation between `x` and `y` can be used as an indication of the
// strength for the maximum pitch.
//
// Such an offset can be encoded in two ways:
// - As a lag, which is the index in the pitch buffer for the first item in `y`
// - As an inverted lag, which is the number of samples between the beginning
// of `x` and the end of `y`
//
// |---->| lag
// |....A....|.....B.....|
// |<--| inverted lag
// |.....y.....| `y` 20 ms frame
//
// The inverted lag has the advantage of being directly proportional to the
// corresponding pitch period.
// Top-2 pitch period candidates.
// Computes the sum of squared samples for every sliding frame `y` in the pitch
// buffer. The indexes of `yy_values` are lags.
void ComputeSlidingFrameSquareEnergies24kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<float, kRefineNumLags24kHz> yy_values);
// Top-2 pitch period candidates. Unit: number of samples - i.e., inverted lags.
struct CandidatePitchPeriods {
// Candidate with the strongest pitch.
int best;
// Runner-up candidate.
int second_best;
};
// Computes the candidate pitch periods given the auto-correlation coefficients
// stored according to ComputePitchAutoCorrelation() (i.e., using inverted
// lags). The returned periods are inverted lags.
CandidatePitchPeriods FindBestPitchPeriods(
rtc::ArrayView<const float> auto_corr,
rtc::ArrayView<const float> pitch_buf,
int max_pitch_period);
// Computes the candidate pitch periods at 12 kHz given a view on the 12 kHz
// pitch buffer and the auto-correlation values (having inverted lags as
// indexes).
CandidatePitchPeriods ComputePitchPeriod12kHz(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buffer,
rtc::ArrayView<const float, kNumLags12kHz> auto_correlation);
// Refines the pitch period estimation given the pitch buffer |pitch_buf| and
// the initial pitch period estimation |pitch_candidates_inverted_lags|.
// Returns an inverted lag at 48 kHz.
int RefinePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
CandidatePitchPeriods pitch_candidates_inverted_lags);
// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer
// and the pitch period candidates at 24 kHz (encoded as inverted lag).
int ComputePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
CandidatePitchPeriods pitch_candidates_24kHz);
// Refines the pitch period estimation and computes the pitch gain. Returns the
// refined pitch estimation data at 48 kHz.
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
// Pitch estimate: a period paired with its strength.
// NOTE(review): the sample rate `period` refers to depends on the call site -
// both 24 kHz and 48 kHz periods are stored in this struct.
struct PitchInfo {
int period;
float strength;
};
// Computes the pitch period at 48 kHz searching in an extended pitch range
// given a view on the 24 kHz pitch buffer, the initial 48 kHz estimation
// (computed by `ComputePitchPeriod48kHz()`) and the last estimated pitch.
PitchInfo ComputeExtendedPitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
int initial_pitch_period_48kHz,
PitchInfo prev_pitch_48kHz);
PitchInfo last_pitch_48kHz);
} // namespace rnn_vad
} // namespace webrtc

View File

@ -31,138 +31,105 @@ constexpr float kTestPitchGainsHigh = 0.75f;
} // namespace
class ComputePitchGainThresholdTest
: public ::testing::Test,
public ::testing::WithParamInterface<std::tuple<
/*candidate_pitch_period=*/int,
/*pitch_period_ratio=*/int,
/*initial_pitch_period=*/int,
/*initial_pitch_gain=*/float,
/*prev_pitch_period=*/int,
/*prev_pitch_gain=*/float,
/*threshold=*/float>> {};
// Checks that the computed pitch gain threshold is within tolerance given test
// input data.
TEST_P(ComputePitchGainThresholdTest, WithinTolerance) {
  // Unpack the parameter tuple in the field order declared by the fixture.
  int candidate_pitch_period;
  int pitch_period_ratio;
  int initial_pitch_period;
  float initial_pitch_gain;
  int prev_pitch_period;
  float prev_pitch_gain;
  float threshold;
  std::tie(candidate_pitch_period, pitch_period_ratio, initial_pitch_period,
           initial_pitch_gain, prev_pitch_period, prev_pitch_gain, threshold) =
      GetParam();
  {
    // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
    // FloatingPointExceptionObserver fpe_observer;
    EXPECT_NEAR(
        threshold,
        ComputePitchGainThreshold(candidate_pitch_period, pitch_period_ratio,
                                  initial_pitch_period, initial_pitch_gain,
                                  prev_pitch_period, prev_pitch_gain),
        5e-7f);
  }
}
// Test vectors for `ComputePitchGainThresholdTest`. Each tuple lists, in order:
// candidate pitch period, pitch period ratio, initial pitch period, initial
// pitch gain, previous pitch period, previous pitch gain, expected threshold.
INSTANTIATE_TEST_SUITE_P(
RnnVadTest,
ComputePitchGainThresholdTest,
::testing::Values(
std::make_tuple(31, 7, 219, 0.45649201f, 199, 0.604747f, 0.40000001f),
std::make_tuple(113,
2,
226,
0.20967799f,
219,
0.40392199f,
0.30000001f),
std::make_tuple(63, 2, 126, 0.210788f, 364, 0.098519f, 0.40000001f),
std::make_tuple(30, 5, 152, 0.82356697f, 149, 0.55535901f, 0.700032f),
std::make_tuple(76, 2, 151, 0.79522997f, 151, 0.82356697f, 0.675946f),
std::make_tuple(31, 5, 153, 0.85069299f, 150, 0.79073799f, 0.72308898f),
std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f)));
// Checks that the frame-wise sliding square energy function produces output
// within tolerance given test input data.
TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesWithinTolerance) {
TEST(RnnVadTest, ComputeSlidingFrameSquareEnergies24kHzWithinTolerance) {
PitchTestData test_data;
std::array<float, kNumPitchBufSquareEnergies> computed_output;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
ComputeSlidingFrameSquareEnergies(test_data.GetPitchBufView(),
computed_output);
}
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
computed_output);
auto square_energies_view = test_data.GetPitchBufSquareEnergiesView();
ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()},
computed_output, 3e-2f);
}
// Checks that the estimated pitch period is bit-exact given test input data.
TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) {
PitchTestData test_data;
std::array<float, kBufSize12kHz> pitch_buf_decimated;
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
CandidatePitchPeriods pitch_candidates;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
pitch_candidates = FindBestPitchPeriods(auto_corr_view, pitch_buf_decimated,
kMaxPitch12kHz);
}
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
pitch_candidates =
ComputePitchPeriod12kHz(pitch_buf_decimated, auto_corr_view);
EXPECT_EQ(pitch_candidates.best, 140);
EXPECT_EQ(pitch_candidates.second_best, 142);
}
// Checks that the refined pitch period is bit-exact given test input data.
TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
TEST(RnnVadTest, ComputePitchPeriod48kHzBitExactness) {
PitchTestData test_data;
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
EXPECT_EQ(RefinePitchPeriod48kHz(test_data.GetPitchBufView(),
/*pitch_candidates=*/{280, 284}),
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(),
/*pitch_candidates=*/{280, 284}),
560);
EXPECT_EQ(RefinePitchPeriod48kHz(test_data.GetPitchBufView(),
/*pitch_candidates=*/{260, 284}),
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(),
/*pitch_candidates=*/{260, 284}),
568);
}
class CheckLowerPitchPeriodsAndComputePitchGainTest
: public ::testing::Test,
public ::testing::WithParamInterface<std::tuple<
/*initial_pitch_period=*/int,
/*prev_pitch_period=*/int,
/*prev_pitch_gain=*/float,
/*expected_pitch_period=*/int,
/*expected_pitch_gain=*/float>> {};
// Fixture parametrized with a pair of pitch period candidates. It exposes the
// pair both as given and with its two entries exchanged.
class PitchCandidatesParametrization
    : public ::testing::TestWithParam<CandidatePitchPeriods> {
 protected:
  // Returns the candidates exactly as provided by the test parameter.
  CandidatePitchPeriods GetPitchCandidates() const { return GetParam(); }

  // Returns the candidates with `best` and `second_best` exchanged.
  CandidatePitchPeriods GetSwappedPitchCandidates() const {
    const CandidatePitchPeriods& original = GetParam();
    CandidatePitchPeriods swapped;
    swapped.best = original.second_best;
    swapped.second_best = original.best;
    return swapped;
  }
};
// Checks that the result of `ComputePitchPeriod48kHz()` does not depend on the
// order of the input pitch candidates.
TEST_P(PitchCandidatesParametrization,
ComputePitchPeriod48kHzOrderDoesNotMatter) {
PitchTestData test_data;
// Same pitch buffer for both calls; only the order of the two candidates
// differs between the left-hand and the right-hand side.
EXPECT_EQ(ComputePitchPeriod48kHz(test_data.GetPitchBufView(),
GetPitchCandidates()),
ComputePitchPeriod48kHz(test_data.GetPitchBufView(),
GetSwappedPitchCandidates()));
}
// Candidate pairs under test: one starting at inverted lag 0, two mid-range
// pairs, and the last two values below `kInitialNumLags24kHz`.
INSTANTIATE_TEST_SUITE_P(RnnVadTest,
PitchCandidatesParametrization,
::testing::Values(CandidatePitchPeriods{0, 2},
CandidatePitchPeriods{260, 284},
CandidatePitchPeriods{280, 284},
CandidatePitchPeriods{
kInitialNumLags24kHz - 2,
kInitialNumLags24kHz - 1}));
// Fixture parametrized with a 5-field tuple. The accessors below give each
// field a name: initial pitch period, last pitch period, last pitch strength,
// expected pitch period and expected pitch strength (in tuple order).
class ExtendedPitchPeriodSearchParametrizaion
    : public ::testing::TestWithParam<std::tuple<int, int, float, int, float>> {
 protected:
  // Field 0.
  int GetInitialPitchPeriod() const {
    const auto& params = GetParam();
    return std::get<0>(params);
  }
  // Field 1.
  int GetLastPitchPeriod() const {
    const auto& params = GetParam();
    return std::get<1>(params);
  }
  // Field 2.
  float GetLastPitchStrength() const {
    const auto& params = GetParam();
    return std::get<2>(params);
  }
  // Field 3.
  int GetExpectedPitchPeriod() const {
    const auto& params = GetParam();
    return std::get<3>(params);
  }
  // Field 4.
  float GetExpectedPitchStrength() const {
    const auto& params = GetParam();
    return std::get<4>(params);
  }
};
// Checks that the computed pitch period is bit-exact and that the computed
// pitch gain is within tolerance given test input data.
TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest,
// pitch strength is within tolerance given test input data.
TEST_P(ExtendedPitchPeriodSearchParametrizaion,
PeriodBitExactnessGainWithinTolerance) {
const auto params = GetParam();
const int initial_pitch_period = std::get<0>(params);
const int prev_pitch_period = std::get<1>(params);
const float prev_pitch_gain = std::get<2>(params);
const int expected_pitch_period = std::get<3>(params);
const float expected_pitch_gain = std::get<4>(params);
PitchTestData test_data;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
const auto computed_output = CheckLowerPitchPeriodsAndComputePitchGain(
test_data.GetPitchBufView(), initial_pitch_period,
{prev_pitch_period, prev_pitch_gain});
EXPECT_EQ(expected_pitch_period, computed_output.period);
EXPECT_NEAR(expected_pitch_gain, computed_output.gain, 1e-6f);
}
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
const auto computed_output = ComputeExtendedPitchPeriod48kHz(
test_data.GetPitchBufView(), GetInitialPitchPeriod(),
{GetLastPitchPeriod(), GetLastPitchStrength()});
EXPECT_EQ(GetExpectedPitchPeriod(), computed_output.period);
EXPECT_NEAR(GetExpectedPitchStrength(), computed_output.strength, 1e-6f);
}
INSTANTIATE_TEST_SUITE_P(
RnnVadTest,
CheckLowerPitchPeriodsAndComputePitchGainTest,
ExtendedPitchPeriodSearchParametrizaion,
::testing::Values(std::make_tuple(kTestPitchPeriodsLow,
kTestPitchPeriodsLow,
kTestPitchGainsLow,

View File

@ -13,7 +13,6 @@
#include <algorithm>
#include <vector>
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
@ -22,15 +21,14 @@
namespace webrtc {
namespace rnn_vad {
namespace test {
// Checks that the computed pitch period is bit-exact and that the computed
// pitch gain is within tolerance given test input data.
TEST(RnnVadTest, PitchSearchWithinTolerance) {
auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader();
auto lp_residual_reader = test::CreateLpResidualAndPitchPeriodGainReader();
const int num_frames = std::min(lp_residual_reader.second, 300); // Max 3 s.
std::vector<float> lp_residual(kBufSize24kHz);
float expected_pitch_period, expected_pitch_gain;
float expected_pitch_period, expected_pitch_strength;
PitchEstimator pitch_estimator;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
@ -39,15 +37,15 @@ TEST(RnnVadTest, PitchSearchWithinTolerance) {
SCOPED_TRACE(i);
lp_residual_reader.first->ReadChunk(lp_residual);
lp_residual_reader.first->ReadValue(&expected_pitch_period);
lp_residual_reader.first->ReadValue(&expected_pitch_gain);
PitchInfo pitch_info =
lp_residual_reader.first->ReadValue(&expected_pitch_strength);
int pitch_period =
pitch_estimator.Estimate({lp_residual.data(), kBufSize24kHz});
EXPECT_EQ(expected_pitch_period, pitch_info.period);
EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f);
EXPECT_EQ(expected_pitch_period, pitch_period);
EXPECT_NEAR(expected_pitch_strength,
pitch_estimator.GetLastPitchStrengthForTesting(), 1e-5f);
}
}
}
} // namespace test
} // namespace rnn_vad
} // namespace webrtc