RNN VAD: fix pitch gain type and change pitch period type

The pitch gain type in ComputePitchGainThreshold() is wrong
(size_t instead of float).
The pitch period is an unsigned integer type, but it is safer to
switch to a signed type and add checks on the sign.

Bug: webrtc:9076
Change-Id: If69d182071edab9750a320f0fbfac24aa8052ee0
Reviewed-on: https://webrtc-review.googlesource.com/c/117302
Reviewed-by: Alex Loiko <aleloi@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#26259}
This commit is contained in:
Alessio Bazzica
2019-01-14 13:54:57 +01:00
committed by Commit Bot
parent 49ea47b90e
commit c25fa89e9e
5 changed files with 55 additions and 46 deletions

View File

@ -18,8 +18,8 @@ namespace rnn_vad {
// strength of the pitch (the higher, the stronger).
struct PitchInfo {
PitchInfo() : period(0), gain(0.f) {}
PitchInfo(size_t p, float g) : period(p), gain(g) {}
size_t period;
PitchInfo(int p, float g) : period(p), gain(g) {}
int period;
float gain;
};

View File

@ -128,12 +128,12 @@ size_t PitchPseudoInterpolationInvLagAutoCorr(
// sn = mex({n * i for i in S} | {1})
// S = S | {Fraction(1, n), Fraction(sn, n)}
// print(sn, end=', ')
constexpr std::array<size_t, 14> kSubHarmonicMultipliers = {
constexpr std::array<int, 14> kSubHarmonicMultipliers = {
{3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}};
// Initial pitch period candidate thresholds for ComputePitchGainThreshold() for
// a sample rate of 24 kHz. Computed as [5*k*k for k in range(16)].
constexpr std::array<size_t, 14> kInitialPitchPeriodThresholds = {
constexpr std::array<int, 14> kInitialPitchPeriodThresholds = {
{20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
} // namespace
@ -147,31 +147,34 @@ void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
}
}
float ComputePitchGainThreshold(size_t candidate_pitch_period,
size_t pitch_period_ratio,
size_t initial_pitch_period,
float ComputePitchGainThreshold(int candidate_pitch_period,
int pitch_period_ratio,
int initial_pitch_period,
float initial_pitch_gain,
size_t prev_pitch_period,
size_t prev_pitch_gain) {
int prev_pitch_period,
float prev_pitch_gain) {
// Map arguments to more compact aliases.
const size_t& t1 = candidate_pitch_period;
const size_t& k = pitch_period_ratio;
const size_t& t0 = initial_pitch_period;
const int& t1 = candidate_pitch_period;
const int& k = pitch_period_ratio;
const int& t0 = initial_pitch_period;
const float& g0 = initial_pitch_gain;
const size_t& t_prev = prev_pitch_period;
const size_t& g_prev = prev_pitch_gain;
const int& t_prev = prev_pitch_period;
const float& g_prev = prev_pitch_gain;
// Validate input.
RTC_DCHECK_GE(t1, 0);
RTC_DCHECK_GE(k, 2);
RTC_DCHECK_GE(t0, 0);
RTC_DCHECK_GE(t_prev, 0);
// Compute a term that lowers the threshold when |t1| is close to the last
// estimated period |t_prev| - i.e., pitch tracking.
float lower_threshold_term = 0;
if (abs(static_cast<int>(t1) - static_cast<int>(t_prev)) <= 1) {
if (abs(t1 - t_prev) <= 1) {
// The candidate pitch period is within 1 sample from the previous one.
// Make the candidate at |t1| very easy to be accepted.
lower_threshold_term = g_prev;
} else if (abs(static_cast<int>(t1) - static_cast<int>(t_prev)) == 2 &&
} else if (abs(t1 - t_prev) == 2 &&
t0 > kInitialPitchPeriodThresholds[k - 2]) {
// The candidate pitch period is 2 samples far from the previous one and the
// period |t0| (from which |t1| has been derived) is greater than a
@ -182,9 +185,11 @@ float ComputePitchGainThreshold(size_t candidate_pitch_period,
// reduce the chance of false positives caused by a bias towards high
// frequencies (originating from short-term correlations).
float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term);
if (t1 < 3 * kMinPitch24kHz) { // High frequency.
if (static_cast<size_t>(t1) < 3 * kMinPitch24kHz) {
// High frequency.
threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term);
} else if (t1 < 2 * kMinPitch24kHz) { // Even higher frequency.
} else if (static_cast<size_t>(t1) < 2 * kMinPitch24kHz) {
// Even higher frequency.
threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term);
}
return threshold;
@ -350,16 +355,16 @@ size_t RefinePitchPeriod48kHz(
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
size_t initial_pitch_period_48kHz,
int initial_pitch_period_48kHz,
PitchInfo prev_pitch_48kHz) {
RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz);
// Stores information for a refined pitch candidate.
struct RefinedPitchCandidate {
RefinedPitchCandidate() {}
RefinedPitchCandidate(size_t period_24kHz, float gain, float xy, float yy)
RefinedPitchCandidate(int period_24kHz, float gain, float xy, float yy)
: period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {}
size_t period_24kHz;
int period_24kHz;
// Pitch strength information.
float gain;
// Additional pitch strength information used for the final estimation of
@ -380,8 +385,8 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
};
// Initial pitch candidate gain.
RefinedPitchCandidate best_pitch;
best_pitch.period_24kHz =
std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1);
best_pitch.period_24kHz = std::min(initial_pitch_period_48kHz / 2,
static_cast<int>(kMaxPitch24kHz - 1));
best_pitch.xy = ComputeAutoCorrelationCoeff(
pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz);
best_pitch.yy = yy_values[best_pitch.period_24kHz];
@ -392,24 +397,27 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
const float initial_pitch_gain = best_pitch.gain;
// Given the initial pitch estimation, check lower periods (i.e., harmonics).
const auto alternative_period = [](size_t period, size_t k,
size_t n) -> size_t {
RTC_DCHECK_LT(0, k);
const auto alternative_period = [](int period, int k, int n) -> int {
RTC_DCHECK_GT(k, 0);
return (2 * n * period + k) / (2 * k); // Same as round(n*period/k).
};
for (size_t k = 2; k < kSubHarmonicMultipliers.size() + 2; ++k) {
size_t candidate_pitch_period =
alternative_period(initial_pitch_period, k, 1);
if (candidate_pitch_period < kMinPitch24kHz)
for (int k = 2; k < static_cast<int>(kSubHarmonicMultipliers.size() + 2);
++k) {
int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1);
if (static_cast<size_t>(candidate_pitch_period) < kMinPitch24kHz) {
break;
}
// When looking at |candidate_pitch_period|, we also look at one of its
// sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look.
// |k| == 2 is a special case since |candidate_pitch_secondary_period| might
// be greater than the maximum pitch period.
size_t candidate_pitch_secondary_period = alternative_period(
int candidate_pitch_secondary_period = alternative_period(
initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]);
if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHz)
RTC_DCHECK_GT(candidate_pitch_secondary_period, 0);
if (k == 2 &&
candidate_pitch_secondary_period > static_cast<int>(kMaxPitch24kHz)) {
candidate_pitch_secondary_period = initial_pitch_period;
}
RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period)
<< "The lower pitch period and the additional sub-harmonic must not "
<< "coincide.";
@ -442,7 +450,7 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
? 1.f
: best_pitch.xy / (best_pitch.yy + 1.f);
final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain);
size_t final_pitch_period_48kHz = std::max(
int final_pitch_period_48kHz = std::max(
kMinPitch48kHz,
PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf));

View File

@ -41,12 +41,12 @@ void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
// Computes a gain threshold for a candidate pitch period given the initial and
// the previous pitch period and gain estimates and the pitch period ratio used
// to derive the candidate pitch period from the initial period.
float ComputePitchGainThreshold(size_t candidate_pitch_period,
size_t pitch_period_ratio,
size_t initial_pitch_period,
float ComputePitchGainThreshold(int candidate_pitch_period,
int pitch_period_ratio,
int initial_pitch_period,
float initial_pitch_gain,
size_t prev_pitch_period,
size_t prev_pitch_gain);
int prev_pitch_period,
float prev_pitch_gain);
// Computes the sum of squared samples for every sliding frame in the pitch
// buffer. |yy_values| indexes are lags.
@ -99,7 +99,7 @@ size_t RefinePitchPeriod48kHz(
// refined pitch estimation data at 48 kHz.
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
size_t initial_pitch_period_48kHz,
int initial_pitch_period_48kHz,
PitchInfo prev_pitch_48kHz);
} // namespace rnn_vad

View File

@ -24,8 +24,9 @@ namespace rnn_vad {
namespace test {
namespace {
constexpr std::array<size_t, 2> kTestPitchPeriods = {
3 * kMinPitch48kHz / 2, (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2,
constexpr std::array<int, 2> kTestPitchPeriods = {
3 * kMinPitch48kHz / 2,
(3 * kMinPitch48kHz + kMaxPitch48kHz) / 2,
};
constexpr std::array<float, 2> kTestPitchGains = {0.35f, 0.75f};
@ -197,14 +198,14 @@ TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
class CheckLowerPitchPeriodsAndComputePitchGainTest
: public testing::Test,
public ::testing::WithParamInterface<
std::tuple<size_t, size_t, float, size_t, float>> {};
std::tuple<int, int, float, int, float>> {};
TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) {
const auto params = GetParam();
const size_t initial_pitch_period = std::get<0>(params);
const size_t prev_pitch_period = std::get<1>(params);
const int initial_pitch_period = std::get<0>(params);
const int prev_pitch_period = std::get<1>(params);
const float prev_pitch_gain = std::get<2>(params);
const size_t expected_pitch_period = std::get<3>(params);
const int expected_pitch_period = std::get<3>(params);
const float expected_pitch_gain = std::get<4>(params);
TestData test_data;
{

View File

@ -39,7 +39,7 @@ TEST(RnnVadTest, PitchSearchBitExactness) {
lp_residual_reader.first->ReadValue(&expected_pitch_period);
lp_residual_reader.first->ReadValue(&expected_pitch_gain);
PitchInfo pitch_info = pitch_estimator.Estimate(lp_residual);
EXPECT_EQ(static_cast<size_t>(expected_pitch_period), pitch_info.period);
EXPECT_EQ(static_cast<int>(expected_pitch_period), pitch_info.period);
EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f);
}
}