RNN VAD: fix pitch gain type and change pitch period type
The pitch gain type in ComputePitchGainThreshold() is wrong (size_t instead of float). The pitch period is an unsigned integer type, but it is safer to switch to a signed type and add checks on the sign. Bug: webrtc:9076 Change-Id: If69d182071edab9750a320f0fbfac24aa8052ee0 Reviewed-on: https://webrtc-review.googlesource.com/c/117302 Reviewed-by: Alex Loiko <aleloi@webrtc.org> Commit-Queue: Alessio Bazzica <alessiob@webrtc.org> Cr-Commit-Position: refs/heads/master@{#26259}
This commit is contained in:

committed by
Commit Bot

parent
49ea47b90e
commit
c25fa89e9e
@ -18,8 +18,8 @@ namespace rnn_vad {
|
|||||||
// strength of the pitch (the higher, the stronger).
|
// strength of the pitch (the higher, the stronger).
|
||||||
struct PitchInfo {
|
struct PitchInfo {
|
||||||
PitchInfo() : period(0), gain(0.f) {}
|
PitchInfo() : period(0), gain(0.f) {}
|
||||||
PitchInfo(size_t p, float g) : period(p), gain(g) {}
|
PitchInfo(int p, float g) : period(p), gain(g) {}
|
||||||
size_t period;
|
int period;
|
||||||
float gain;
|
float gain;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -128,12 +128,12 @@ size_t PitchPseudoInterpolationInvLagAutoCorr(
|
|||||||
// sn = mex({n * i for i in S} | {1})
|
// sn = mex({n * i for i in S} | {1})
|
||||||
// S = S | {Fraction(1, n), Fraction(sn, n)}
|
// S = S | {Fraction(1, n), Fraction(sn, n)}
|
||||||
// print(sn, end=', ')
|
// print(sn, end=', ')
|
||||||
constexpr std::array<size_t, 14> kSubHarmonicMultipliers = {
|
constexpr std::array<int, 14> kSubHarmonicMultipliers = {
|
||||||
{3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}};
|
{3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}};
|
||||||
|
|
||||||
// Initial pitch period candidate thresholds for ComputePitchGainThreshold() for
|
// Initial pitch period candidate thresholds for ComputePitchGainThreshold() for
|
||||||
// a sample rate of 24 kHz. Computed as [5*k*k for k in range(16)].
|
// a sample rate of 24 kHz. Computed as [5*k*k for k in range(16)].
|
||||||
constexpr std::array<size_t, 14> kInitialPitchPeriodThresholds = {
|
constexpr std::array<int, 14> kInitialPitchPeriodThresholds = {
|
||||||
{20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
|
{20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -147,31 +147,34 @@ void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
float ComputePitchGainThreshold(size_t candidate_pitch_period,
|
float ComputePitchGainThreshold(int candidate_pitch_period,
|
||||||
size_t pitch_period_ratio,
|
int pitch_period_ratio,
|
||||||
size_t initial_pitch_period,
|
int initial_pitch_period,
|
||||||
float initial_pitch_gain,
|
float initial_pitch_gain,
|
||||||
size_t prev_pitch_period,
|
int prev_pitch_period,
|
||||||
size_t prev_pitch_gain) {
|
float prev_pitch_gain) {
|
||||||
// Map arguments to more compact aliases.
|
// Map arguments to more compact aliases.
|
||||||
const size_t& t1 = candidate_pitch_period;
|
const int& t1 = candidate_pitch_period;
|
||||||
const size_t& k = pitch_period_ratio;
|
const int& k = pitch_period_ratio;
|
||||||
const size_t& t0 = initial_pitch_period;
|
const int& t0 = initial_pitch_period;
|
||||||
const float& g0 = initial_pitch_gain;
|
const float& g0 = initial_pitch_gain;
|
||||||
const size_t& t_prev = prev_pitch_period;
|
const int& t_prev = prev_pitch_period;
|
||||||
const size_t& g_prev = prev_pitch_gain;
|
const float& g_prev = prev_pitch_gain;
|
||||||
|
|
||||||
// Validate input.
|
// Validate input.
|
||||||
|
RTC_DCHECK_GE(t1, 0);
|
||||||
RTC_DCHECK_GE(k, 2);
|
RTC_DCHECK_GE(k, 2);
|
||||||
|
RTC_DCHECK_GE(t0, 0);
|
||||||
|
RTC_DCHECK_GE(t_prev, 0);
|
||||||
|
|
||||||
// Compute a term that lowers the threshold when |t1| is close to the last
|
// Compute a term that lowers the threshold when |t1| is close to the last
|
||||||
// estimated period |t_prev| - i.e., pitch tracking.
|
// estimated period |t_prev| - i.e., pitch tracking.
|
||||||
float lower_threshold_term = 0;
|
float lower_threshold_term = 0;
|
||||||
if (abs(static_cast<int>(t1) - static_cast<int>(t_prev)) <= 1) {
|
if (abs(t1 - t_prev) <= 1) {
|
||||||
// The candidate pitch period is within 1 sample from the previous one.
|
// The candidate pitch period is within 1 sample from the previous one.
|
||||||
// Make the candidate at |t1| very easy to be accepted.
|
// Make the candidate at |t1| very easy to be accepted.
|
||||||
lower_threshold_term = g_prev;
|
lower_threshold_term = g_prev;
|
||||||
} else if (abs(static_cast<int>(t1) - static_cast<int>(t_prev)) == 2 &&
|
} else if (abs(t1 - t_prev) == 2 &&
|
||||||
t0 > kInitialPitchPeriodThresholds[k - 2]) {
|
t0 > kInitialPitchPeriodThresholds[k - 2]) {
|
||||||
// The candidate pitch period is 2 samples far from the previous one and the
|
// The candidate pitch period is 2 samples far from the previous one and the
|
||||||
// period |t0| (from which |t1| has been derived) is greater than a
|
// period |t0| (from which |t1| has been derived) is greater than a
|
||||||
@ -182,9 +185,11 @@ float ComputePitchGainThreshold(size_t candidate_pitch_period,
|
|||||||
// reduce the chance of false positives caused by a bias towards high
|
// reduce the chance of false positives caused by a bias towards high
|
||||||
// frequencies (originating from short-term correlations).
|
// frequencies (originating from short-term correlations).
|
||||||
float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term);
|
float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term);
|
||||||
if (t1 < 3 * kMinPitch24kHz) { // High frequency.
|
if (static_cast<size_t>(t1) < 3 * kMinPitch24kHz) {
|
||||||
|
// High frequency.
|
||||||
threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term);
|
threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term);
|
||||||
} else if (t1 < 2 * kMinPitch24kHz) { // Even higher frequency.
|
} else if (static_cast<size_t>(t1) < 2 * kMinPitch24kHz) {
|
||||||
|
// Even higher frequency.
|
||||||
threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term);
|
threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term);
|
||||||
}
|
}
|
||||||
return threshold;
|
return threshold;
|
||||||
@ -350,16 +355,16 @@ size_t RefinePitchPeriod48kHz(
|
|||||||
|
|
||||||
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
|
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
|
||||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||||
size_t initial_pitch_period_48kHz,
|
int initial_pitch_period_48kHz,
|
||||||
PitchInfo prev_pitch_48kHz) {
|
PitchInfo prev_pitch_48kHz) {
|
||||||
RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
|
RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
|
||||||
RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz);
|
RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz);
|
||||||
// Stores information for a refined pitch candidate.
|
// Stores information for a refined pitch candidate.
|
||||||
struct RefinedPitchCandidate {
|
struct RefinedPitchCandidate {
|
||||||
RefinedPitchCandidate() {}
|
RefinedPitchCandidate() {}
|
||||||
RefinedPitchCandidate(size_t period_24kHz, float gain, float xy, float yy)
|
RefinedPitchCandidate(int period_24kHz, float gain, float xy, float yy)
|
||||||
: period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {}
|
: period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {}
|
||||||
size_t period_24kHz;
|
int period_24kHz;
|
||||||
// Pitch strength information.
|
// Pitch strength information.
|
||||||
float gain;
|
float gain;
|
||||||
// Additional pitch strength information used for the final estimation of
|
// Additional pitch strength information used for the final estimation of
|
||||||
@ -380,8 +385,8 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
|
|||||||
};
|
};
|
||||||
// Initial pitch candidate gain.
|
// Initial pitch candidate gain.
|
||||||
RefinedPitchCandidate best_pitch;
|
RefinedPitchCandidate best_pitch;
|
||||||
best_pitch.period_24kHz =
|
best_pitch.period_24kHz = std::min(initial_pitch_period_48kHz / 2,
|
||||||
std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1);
|
static_cast<int>(kMaxPitch24kHz - 1));
|
||||||
best_pitch.xy = ComputeAutoCorrelationCoeff(
|
best_pitch.xy = ComputeAutoCorrelationCoeff(
|
||||||
pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz);
|
pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz);
|
||||||
best_pitch.yy = yy_values[best_pitch.period_24kHz];
|
best_pitch.yy = yy_values[best_pitch.period_24kHz];
|
||||||
@ -392,24 +397,27 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
|
|||||||
const float initial_pitch_gain = best_pitch.gain;
|
const float initial_pitch_gain = best_pitch.gain;
|
||||||
|
|
||||||
// Given the initial pitch estimation, check lower periods (i.e., harmonics).
|
// Given the initial pitch estimation, check lower periods (i.e., harmonics).
|
||||||
const auto alternative_period = [](size_t period, size_t k,
|
const auto alternative_period = [](int period, int k, int n) -> int {
|
||||||
size_t n) -> size_t {
|
RTC_DCHECK_GT(k, 0);
|
||||||
RTC_DCHECK_LT(0, k);
|
|
||||||
return (2 * n * period + k) / (2 * k); // Same as round(n*period/k).
|
return (2 * n * period + k) / (2 * k); // Same as round(n*period/k).
|
||||||
};
|
};
|
||||||
for (size_t k = 2; k < kSubHarmonicMultipliers.size() + 2; ++k) {
|
for (int k = 2; k < static_cast<int>(kSubHarmonicMultipliers.size() + 2);
|
||||||
size_t candidate_pitch_period =
|
++k) {
|
||||||
alternative_period(initial_pitch_period, k, 1);
|
int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1);
|
||||||
if (candidate_pitch_period < kMinPitch24kHz)
|
if (static_cast<size_t>(candidate_pitch_period) < kMinPitch24kHz) {
|
||||||
break;
|
break;
|
||||||
|
}
|
||||||
// When looking at |candidate_pitch_period|, we also look at one of its
|
// When looking at |candidate_pitch_period|, we also look at one of its
|
||||||
// sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look.
|
// sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look.
|
||||||
// |k| == 2 is a special case since |candidate_pitch_secondary_period| might
|
// |k| == 2 is a special case since |candidate_pitch_secondary_period| might
|
||||||
// be greater than the maximum pitch period.
|
// be greater than the maximum pitch period.
|
||||||
size_t candidate_pitch_secondary_period = alternative_period(
|
int candidate_pitch_secondary_period = alternative_period(
|
||||||
initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]);
|
initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]);
|
||||||
if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHz)
|
RTC_DCHECK_GT(candidate_pitch_secondary_period, 0);
|
||||||
|
if (k == 2 &&
|
||||||
|
candidate_pitch_secondary_period > static_cast<int>(kMaxPitch24kHz)) {
|
||||||
candidate_pitch_secondary_period = initial_pitch_period;
|
candidate_pitch_secondary_period = initial_pitch_period;
|
||||||
|
}
|
||||||
RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period)
|
RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period)
|
||||||
<< "The lower pitch period and the additional sub-harmonic must not "
|
<< "The lower pitch period and the additional sub-harmonic must not "
|
||||||
<< "coincide.";
|
<< "coincide.";
|
||||||
@ -442,7 +450,7 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
|
|||||||
? 1.f
|
? 1.f
|
||||||
: best_pitch.xy / (best_pitch.yy + 1.f);
|
: best_pitch.xy / (best_pitch.yy + 1.f);
|
||||||
final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain);
|
final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain);
|
||||||
size_t final_pitch_period_48kHz = std::max(
|
int final_pitch_period_48kHz = std::max(
|
||||||
kMinPitch48kHz,
|
kMinPitch48kHz,
|
||||||
PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf));
|
PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf));
|
||||||
|
|
||||||
|
@ -41,12 +41,12 @@ void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
|
|||||||
// Computes a gain threshold for a candidate pitch period given the initial and
|
// Computes a gain threshold for a candidate pitch period given the initial and
|
||||||
// the previous pitch period and gain estimates and the pitch period ratio used
|
// the previous pitch period and gain estimates and the pitch period ratio used
|
||||||
// to derive the candidate pitch period from the initial period.
|
// to derive the candidate pitch period from the initial period.
|
||||||
float ComputePitchGainThreshold(size_t candidate_pitch_period,
|
float ComputePitchGainThreshold(int candidate_pitch_period,
|
||||||
size_t pitch_period_ratio,
|
int pitch_period_ratio,
|
||||||
size_t initial_pitch_period,
|
int initial_pitch_period,
|
||||||
float initial_pitch_gain,
|
float initial_pitch_gain,
|
||||||
size_t prev_pitch_period,
|
int prev_pitch_period,
|
||||||
size_t prev_pitch_gain);
|
float prev_pitch_gain);
|
||||||
|
|
||||||
// Computes the sum of squared samples for every sliding frame in the pitch
|
// Computes the sum of squared samples for every sliding frame in the pitch
|
||||||
// buffer. |yy_values| indexes are lags.
|
// buffer. |yy_values| indexes are lags.
|
||||||
@ -99,7 +99,7 @@ size_t RefinePitchPeriod48kHz(
|
|||||||
// refined pitch estimation data at 48 kHz.
|
// refined pitch estimation data at 48 kHz.
|
||||||
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
|
PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
|
||||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||||
size_t initial_pitch_period_48kHz,
|
int initial_pitch_period_48kHz,
|
||||||
PitchInfo prev_pitch_48kHz);
|
PitchInfo prev_pitch_48kHz);
|
||||||
|
|
||||||
} // namespace rnn_vad
|
} // namespace rnn_vad
|
||||||
|
@ -24,8 +24,9 @@ namespace rnn_vad {
|
|||||||
namespace test {
|
namespace test {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr std::array<size_t, 2> kTestPitchPeriods = {
|
constexpr std::array<int, 2> kTestPitchPeriods = {
|
||||||
3 * kMinPitch48kHz / 2, (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2,
|
3 * kMinPitch48kHz / 2,
|
||||||
|
(3 * kMinPitch48kHz + kMaxPitch48kHz) / 2,
|
||||||
};
|
};
|
||||||
constexpr std::array<float, 2> kTestPitchGains = {0.35f, 0.75f};
|
constexpr std::array<float, 2> kTestPitchGains = {0.35f, 0.75f};
|
||||||
|
|
||||||
@ -197,14 +198,14 @@ TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
|
|||||||
class CheckLowerPitchPeriodsAndComputePitchGainTest
|
class CheckLowerPitchPeriodsAndComputePitchGainTest
|
||||||
: public testing::Test,
|
: public testing::Test,
|
||||||
public ::testing::WithParamInterface<
|
public ::testing::WithParamInterface<
|
||||||
std::tuple<size_t, size_t, float, size_t, float>> {};
|
std::tuple<int, int, float, int, float>> {};
|
||||||
|
|
||||||
TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) {
|
TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) {
|
||||||
const auto params = GetParam();
|
const auto params = GetParam();
|
||||||
const size_t initial_pitch_period = std::get<0>(params);
|
const int initial_pitch_period = std::get<0>(params);
|
||||||
const size_t prev_pitch_period = std::get<1>(params);
|
const int prev_pitch_period = std::get<1>(params);
|
||||||
const float prev_pitch_gain = std::get<2>(params);
|
const float prev_pitch_gain = std::get<2>(params);
|
||||||
const size_t expected_pitch_period = std::get<3>(params);
|
const int expected_pitch_period = std::get<3>(params);
|
||||||
const float expected_pitch_gain = std::get<4>(params);
|
const float expected_pitch_gain = std::get<4>(params);
|
||||||
TestData test_data;
|
TestData test_data;
|
||||||
{
|
{
|
||||||
|
@ -39,7 +39,7 @@ TEST(RnnVadTest, PitchSearchBitExactness) {
|
|||||||
lp_residual_reader.first->ReadValue(&expected_pitch_period);
|
lp_residual_reader.first->ReadValue(&expected_pitch_period);
|
||||||
lp_residual_reader.first->ReadValue(&expected_pitch_gain);
|
lp_residual_reader.first->ReadValue(&expected_pitch_gain);
|
||||||
PitchInfo pitch_info = pitch_estimator.Estimate(lp_residual);
|
PitchInfo pitch_info = pitch_estimator.Estimate(lp_residual);
|
||||||
EXPECT_EQ(static_cast<size_t>(expected_pitch_period), pitch_info.period);
|
EXPECT_EQ(static_cast<int>(expected_pitch_period), pitch_info.period);
|
||||||
EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f);
|
EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user