RNN VAD: fix pitch gain type and change pitch period type

The pitch gain type in ComputePitchGainThreshold() is wrong
(size_t instead of float).
The pitch period is an unsigned integer type, but it is safer to
switch to a signed type and add checks on the sign.

Bug: webrtc:9076
Change-Id: If69d182071edab9750a320f0fbfac24aa8052ee0
Reviewed-on: https://webrtc-review.googlesource.com/c/117302
Reviewed-by: Alex Loiko <aleloi@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#26259}
This commit is contained in:
Alessio Bazzica
2019-01-14 13:54:57 +01:00
committed by Commit Bot
parent 49ea47b90e
commit c25fa89e9e
5 changed files with 55 additions and 46 deletions

View File

@ -18,8 +18,8 @@ namespace rnn_vad {
// strength of the pitch (the higher, the stronger). // strength of the pitch (the higher, the stronger).
struct PitchInfo { struct PitchInfo {
PitchInfo() : period(0), gain(0.f) {} PitchInfo() : period(0), gain(0.f) {}
PitchInfo(size_t p, float g) : period(p), gain(g) {} PitchInfo(int p, float g) : period(p), gain(g) {}
size_t period; int period;
float gain; float gain;
}; };

View File

@ -128,12 +128,12 @@ size_t PitchPseudoInterpolationInvLagAutoCorr(
// sn = mex({n * i for i in S} | {1}) // sn = mex({n * i for i in S} | {1})
// S = S | {Fraction(1, n), Fraction(sn, n)} // S = S | {Fraction(1, n), Fraction(sn, n)}
// print(sn, end=', ') // print(sn, end=', ')
constexpr std::array<size_t, 14> kSubHarmonicMultipliers = { constexpr std::array<int, 14> kSubHarmonicMultipliers = {
{3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}}; {3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}};
// Initial pitch period candidate thresholds for ComputePitchGainThreshold() for // Initial pitch period candidate thresholds for ComputePitchGainThreshold() for
// a sample rate of 24 kHz. Computed as [5*k*k for k in range(16)]. // a sample rate of 24 kHz. Computed as [5*k*k for k in range(16)].
constexpr std::array<size_t, 14> kInitialPitchPeriodThresholds = { constexpr std::array<int, 14> kInitialPitchPeriodThresholds = {
{20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}}; {20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
} // namespace } // namespace
@ -147,31 +147,34 @@ void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
} }
} }
float ComputePitchGainThreshold(size_t candidate_pitch_period, float ComputePitchGainThreshold(int candidate_pitch_period,
size_t pitch_period_ratio, int pitch_period_ratio,
size_t initial_pitch_period, int initial_pitch_period,
float initial_pitch_gain, float initial_pitch_gain,
size_t prev_pitch_period, int prev_pitch_period,
size_t prev_pitch_gain) { float prev_pitch_gain) {
// Map arguments to more compact aliases. // Map arguments to more compact aliases.
const size_t& t1 = candidate_pitch_period; const int& t1 = candidate_pitch_period;
const size_t& k = pitch_period_ratio; const int& k = pitch_period_ratio;
const size_t& t0 = initial_pitch_period; const int& t0 = initial_pitch_period;
const float& g0 = initial_pitch_gain; const float& g0 = initial_pitch_gain;
const size_t& t_prev = prev_pitch_period; const int& t_prev = prev_pitch_period;
const size_t& g_prev = prev_pitch_gain; const float& g_prev = prev_pitch_gain;
// Validate input. // Validate input.
RTC_DCHECK_GE(t1, 0);
RTC_DCHECK_GE(k, 2); RTC_DCHECK_GE(k, 2);
RTC_DCHECK_GE(t0, 0);
RTC_DCHECK_GE(t_prev, 0);
// Compute a term that lowers the threshold when |t1| is close to the last // Compute a term that lowers the threshold when |t1| is close to the last
// estimated period |t_prev| - i.e., pitch tracking. // estimated period |t_prev| - i.e., pitch tracking.
float lower_threshold_term = 0; float lower_threshold_term = 0;
if (abs(static_cast<int>(t1) - static_cast<int>(t_prev)) <= 1) { if (abs(t1 - t_prev) <= 1) {
// The candidate pitch period is within 1 sample from the previous one. // The candidate pitch period is within 1 sample from the previous one.
// Make the candidate at |t1| very easy to be accepted. // Make the candidate at |t1| very easy to be accepted.
lower_threshold_term = g_prev; lower_threshold_term = g_prev;
} else if (abs(static_cast<int>(t1) - static_cast<int>(t_prev)) == 2 && } else if (abs(t1 - t_prev) == 2 &&
t0 > kInitialPitchPeriodThresholds[k - 2]) { t0 > kInitialPitchPeriodThresholds[k - 2]) {
// The candidate pitch period is 2 samples far from the previous one and the // The candidate pitch period is 2 samples far from the previous one and the
// period |t0| (from which |t1| has been derived) is greater than a // period |t0| (from which |t1| has been derived) is greater than a
@ -182,9 +185,11 @@ float ComputePitchGainThreshold(size_t candidate_pitch_period,
// reduce the chance of false positives caused by a bias towards high // reduce the chance of false positives caused by a bias towards high
// frequencies (originating from short-term correlations). // frequencies (originating from short-term correlations).
float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term); float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term);
if (t1 < 3 * kMinPitch24kHz) { // High frequency. if (static_cast<size_t>(t1) < 3 * kMinPitch24kHz) {
// High frequency.
threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term); threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term);
} else if (t1 < 2 * kMinPitch24kHz) { // Even higher frequency. } else if (static_cast<size_t>(t1) < 2 * kMinPitch24kHz) {
// Even higher frequency.
threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term); threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term);
} }
return threshold; return threshold;
@ -350,16 +355,16 @@ size_t RefinePitchPeriod48kHz(
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf, rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
size_t initial_pitch_period_48kHz, int initial_pitch_period_48kHz,
PitchInfo prev_pitch_48kHz) { PitchInfo prev_pitch_48kHz) {
RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz); RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz); RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz);
// Stores information for a refined pitch candidate. // Stores information for a refined pitch candidate.
struct RefinedPitchCandidate { struct RefinedPitchCandidate {
RefinedPitchCandidate() {} RefinedPitchCandidate() {}
RefinedPitchCandidate(size_t period_24kHz, float gain, float xy, float yy) RefinedPitchCandidate(int period_24kHz, float gain, float xy, float yy)
: period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {} : period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {}
size_t period_24kHz; int period_24kHz;
// Pitch strength information. // Pitch strength information.
float gain; float gain;
// Additional pitch strength information used for the final estimation of // Additional pitch strength information used for the final estimation of
@ -380,8 +385,8 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
}; };
// Initial pitch candidate gain. // Initial pitch candidate gain.
RefinedPitchCandidate best_pitch; RefinedPitchCandidate best_pitch;
best_pitch.period_24kHz = best_pitch.period_24kHz = std::min(initial_pitch_period_48kHz / 2,
std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1); static_cast<int>(kMaxPitch24kHz - 1));
best_pitch.xy = ComputeAutoCorrelationCoeff( best_pitch.xy = ComputeAutoCorrelationCoeff(
pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz); pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz);
best_pitch.yy = yy_values[best_pitch.period_24kHz]; best_pitch.yy = yy_values[best_pitch.period_24kHz];
@ -392,24 +397,27 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
const float initial_pitch_gain = best_pitch.gain; const float initial_pitch_gain = best_pitch.gain;
// Given the initial pitch estimation, check lower periods (i.e., harmonics). // Given the initial pitch estimation, check lower periods (i.e., harmonics).
const auto alternative_period = [](size_t period, size_t k, const auto alternative_period = [](int period, int k, int n) -> int {
size_t n) -> size_t { RTC_DCHECK_GT(k, 0);
RTC_DCHECK_LT(0, k);
return (2 * n * period + k) / (2 * k); // Same as round(n*period/k). return (2 * n * period + k) / (2 * k); // Same as round(n*period/k).
}; };
for (size_t k = 2; k < kSubHarmonicMultipliers.size() + 2; ++k) { for (int k = 2; k < static_cast<int>(kSubHarmonicMultipliers.size() + 2);
size_t candidate_pitch_period = ++k) {
alternative_period(initial_pitch_period, k, 1); int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1);
if (candidate_pitch_period < kMinPitch24kHz) if (static_cast<size_t>(candidate_pitch_period) < kMinPitch24kHz) {
break; break;
}
// When looking at |candidate_pitch_period|, we also look at one of its // When looking at |candidate_pitch_period|, we also look at one of its
// sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look. // sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look.
// |k| == 2 is a special case since |candidate_pitch_secondary_period| might // |k| == 2 is a special case since |candidate_pitch_secondary_period| might
// be greater than the maximum pitch period. // be greater than the maximum pitch period.
size_t candidate_pitch_secondary_period = alternative_period( int candidate_pitch_secondary_period = alternative_period(
initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]); initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]);
if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHz) RTC_DCHECK_GT(candidate_pitch_secondary_period, 0);
if (k == 2 &&
candidate_pitch_secondary_period > static_cast<int>(kMaxPitch24kHz)) {
candidate_pitch_secondary_period = initial_pitch_period; candidate_pitch_secondary_period = initial_pitch_period;
}
RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period) RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period)
<< "The lower pitch period and the additional sub-harmonic must not " << "The lower pitch period and the additional sub-harmonic must not "
<< "coincide."; << "coincide.";
@ -442,7 +450,7 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
? 1.f ? 1.f
: best_pitch.xy / (best_pitch.yy + 1.f); : best_pitch.xy / (best_pitch.yy + 1.f);
final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain); final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain);
size_t final_pitch_period_48kHz = std::max( int final_pitch_period_48kHz = std::max(
kMinPitch48kHz, kMinPitch48kHz,
PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf)); PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf));

View File

@ -41,12 +41,12 @@ void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
// Computes a gain threshold for a candidate pitch period given the initial and // Computes a gain threshold for a candidate pitch period given the initial and
// the previous pitch period and gain estimates and the pitch period ratio used // the previous pitch period and gain estimates and the pitch period ratio used
// to derive the candidate pitch period from the initial period. // to derive the candidate pitch period from the initial period.
float ComputePitchGainThreshold(size_t candidate_pitch_period, float ComputePitchGainThreshold(int candidate_pitch_period,
size_t pitch_period_ratio, int pitch_period_ratio,
size_t initial_pitch_period, int initial_pitch_period,
float initial_pitch_gain, float initial_pitch_gain,
size_t prev_pitch_period, int prev_pitch_period,
size_t prev_pitch_gain); float prev_pitch_gain);
// Computes the sum of squared samples for every sliding frame in the pitch // Computes the sum of squared samples for every sliding frame in the pitch
// buffer. |yy_values| indexes are lags. // buffer. |yy_values| indexes are lags.
@ -99,7 +99,7 @@ size_t RefinePitchPeriod48kHz(
// refined pitch estimation data at 48 kHz. // refined pitch estimation data at 48 kHz.
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf, rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
size_t initial_pitch_period_48kHz, int initial_pitch_period_48kHz,
PitchInfo prev_pitch_48kHz); PitchInfo prev_pitch_48kHz);
} // namespace rnn_vad } // namespace rnn_vad

View File

@ -24,8 +24,9 @@ namespace rnn_vad {
namespace test { namespace test {
namespace { namespace {
constexpr std::array<size_t, 2> kTestPitchPeriods = { constexpr std::array<int, 2> kTestPitchPeriods = {
3 * kMinPitch48kHz / 2, (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2, 3 * kMinPitch48kHz / 2,
(3 * kMinPitch48kHz + kMaxPitch48kHz) / 2,
}; };
constexpr std::array<float, 2> kTestPitchGains = {0.35f, 0.75f}; constexpr std::array<float, 2> kTestPitchGains = {0.35f, 0.75f};
@ -197,14 +198,14 @@ TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
class CheckLowerPitchPeriodsAndComputePitchGainTest class CheckLowerPitchPeriodsAndComputePitchGainTest
: public testing::Test, : public testing::Test,
public ::testing::WithParamInterface< public ::testing::WithParamInterface<
std::tuple<size_t, size_t, float, size_t, float>> {}; std::tuple<int, int, float, int, float>> {};
TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) { TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) {
const auto params = GetParam(); const auto params = GetParam();
const size_t initial_pitch_period = std::get<0>(params); const int initial_pitch_period = std::get<0>(params);
const size_t prev_pitch_period = std::get<1>(params); const int prev_pitch_period = std::get<1>(params);
const float prev_pitch_gain = std::get<2>(params); const float prev_pitch_gain = std::get<2>(params);
const size_t expected_pitch_period = std::get<3>(params); const int expected_pitch_period = std::get<3>(params);
const float expected_pitch_gain = std::get<4>(params); const float expected_pitch_gain = std::get<4>(params);
TestData test_data; TestData test_data;
{ {

View File

@ -39,7 +39,7 @@ TEST(RnnVadTest, PitchSearchBitExactness) {
lp_residual_reader.first->ReadValue(&expected_pitch_period); lp_residual_reader.first->ReadValue(&expected_pitch_period);
lp_residual_reader.first->ReadValue(&expected_pitch_gain); lp_residual_reader.first->ReadValue(&expected_pitch_gain);
PitchInfo pitch_info = pitch_estimator.Estimate(lp_residual); PitchInfo pitch_info = pitch_estimator.Estimate(lp_residual);
EXPECT_EQ(static_cast<size_t>(expected_pitch_period), pitch_info.period); EXPECT_EQ(static_cast<int>(expected_pitch_period), pitch_info.period);
EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f); EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f);
} }
} }