diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_info.h b/modules/audio_processing/agc2/rnn_vad/pitch_info.h index f0998d1fad..c9fdd182b0 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_info.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_info.h @@ -18,8 +18,8 @@ namespace rnn_vad { // strength of the pitch (the higher, the stronger). struct PitchInfo { PitchInfo() : period(0), gain(0.f) {} - PitchInfo(size_t p, float g) : period(p), gain(g) {} - size_t period; + PitchInfo(int p, float g) : period(p), gain(g) {} + int period; float gain; }; diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc index 32ee8c00df..7c17dfb0bc 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc @@ -128,12 +128,12 @@ size_t PitchPseudoInterpolationInvLagAutoCorr( // sn = mex({n * i for i in S} | {1}) // S = S | {Fraction(1, n), Fraction(sn, n)} // print(sn, end=', ') -constexpr std::array kSubHarmonicMultipliers = { +constexpr std::array kSubHarmonicMultipliers = { {3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}}; // Initial pitch period candidate thresholds for ComputePitchGainThreshold() for // a sample rate of 24 kHz. Computed as [5*k*k for k in range(16)]. -constexpr std::array kInitialPitchPeriodThresholds = { +constexpr std::array kInitialPitchPeriodThresholds = { {20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}}; } // namespace @@ -147,31 +147,34 @@ void Decimate2x(rtc::ArrayView src, } } -float ComputePitchGainThreshold(size_t candidate_pitch_period, - size_t pitch_period_ratio, - size_t initial_pitch_period, +float ComputePitchGainThreshold(int candidate_pitch_period, + int pitch_period_ratio, + int initial_pitch_period, float initial_pitch_gain, - size_t prev_pitch_period, - size_t prev_pitch_gain) { + int prev_pitch_period, + float prev_pitch_gain) { // Map arguments to more compact aliases. - const size_t& t1 = candidate_pitch_period; - const size_t& k = pitch_period_ratio; - const size_t& t0 = initial_pitch_period; + const int& t1 = candidate_pitch_period; + const int& k = pitch_period_ratio; + const int& t0 = initial_pitch_period; const float& g0 = initial_pitch_gain; - const size_t& t_prev = prev_pitch_period; - const size_t& g_prev = prev_pitch_gain; + const int& t_prev = prev_pitch_period; + const float& g_prev = prev_pitch_gain; // Validate input. + RTC_DCHECK_GE(t1, 0); RTC_DCHECK_GE(k, 2); + RTC_DCHECK_GE(t0, 0); + RTC_DCHECK_GE(t_prev, 0); // Compute a term that lowers the threshold when |t1| is close to the last // estimated period |t_prev| - i.e., pitch tracking. float lower_threshold_term = 0; - if (abs(static_cast(t1) - static_cast(t_prev)) <= 1) { + if (abs(t1 - t_prev) <= 1) { // The candidate pitch period is within 1 sample from the previous one. // Make the candidate at |t1| very easy to be accepted. lower_threshold_term = g_prev; - } else if (abs(static_cast(t1) - static_cast(t_prev)) == 2 && + } else if (abs(t1 - t_prev) == 2 && t0 > kInitialPitchPeriodThresholds[k - 2]) { // The candidate pitch period is 2 samples far from the previous one and the // period |t0| (from which |t1| has been derived) is greater than a @@ -182,9 +185,11 @@ float ComputePitchGainThreshold(size_t candidate_pitch_period, // reduce the chance of false positives caused by a bias towards high // frequencies (originating from short-term correlations). float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term); - if (t1 < 3 * kMinPitch24kHz) { // High frequency. + if (static_cast(t1) < 3 * kMinPitch24kHz) { + // High frequency. threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term); - } else if (t1 < 2 * kMinPitch24kHz) { // Even higher frequency. + } else if (static_cast(t1) < 2 * kMinPitch24kHz) { + // Even higher frequency. threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term); } return threshold; @@ -350,16 +355,16 @@ size_t RefinePitchPeriod48kHz( PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( rtc::ArrayView pitch_buf, - size_t initial_pitch_period_48kHz, + int initial_pitch_period_48kHz, PitchInfo prev_pitch_48kHz) { RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz); RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz); // Stores information for a refined pitch candidate. struct RefinedPitchCandidate { RefinedPitchCandidate() {} - RefinedPitchCandidate(size_t period_24kHz, float gain, float xy, float yy) + RefinedPitchCandidate(int period_24kHz, float gain, float xy, float yy) : period_24kHz(period_24kHz), gain(gain), xy(xy), yy(yy) {} - size_t period_24kHz; + int period_24kHz; // Pitch strength information. float gain; // Additional pitch strength information used for the final estimation of @@ -380,8 +385,8 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( }; // Initial pitch candidate gain. RefinedPitchCandidate best_pitch; - best_pitch.period_24kHz = - std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1); + best_pitch.period_24kHz = std::min(initial_pitch_period_48kHz / 2, + static_cast(kMaxPitch24kHz - 1)); best_pitch.xy = ComputeAutoCorrelationCoeff( pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz); best_pitch.yy = yy_values[best_pitch.period_24kHz]; @@ -392,24 +397,27 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( const float initial_pitch_gain = best_pitch.gain; // Given the initial pitch estimation, check lower periods (i.e., harmonics). - const auto alternative_period = [](size_t period, size_t k, - size_t n) -> size_t { - RTC_DCHECK_LT(0, k); + const auto alternative_period = [](int period, int k, int n) -> int { + RTC_DCHECK_GT(k, 0); return (2 * n * period + k) / (2 * k); // Same as round(n*period/k). }; - for (size_t k = 2; k < kSubHarmonicMultipliers.size() + 2; ++k) { - size_t candidate_pitch_period = - alternative_period(initial_pitch_period, k, 1); - if (candidate_pitch_period < kMinPitch24kHz) + for (int k = 2; k < static_cast(kSubHarmonicMultipliers.size() + 2); + ++k) { + int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1); + if (static_cast(candidate_pitch_period) < kMinPitch24kHz) { break; + } // When looking at |candidate_pitch_period|, we also look at one of its // sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look. // |k| == 2 is a special case since |candidate_pitch_secondary_period| might // be greater than the maximum pitch period. - size_t candidate_pitch_secondary_period = alternative_period( + int candidate_pitch_secondary_period = alternative_period( initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]); - if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHz) + RTC_DCHECK_GT(candidate_pitch_secondary_period, 0); + if (k == 2 && + candidate_pitch_secondary_period > static_cast(kMaxPitch24kHz)) { candidate_pitch_secondary_period = initial_pitch_period; + } RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period) << "The lower pitch period and the additional sub-harmonic must not " << "coincide."; @@ -442,7 +450,7 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( ? 1.f : best_pitch.xy / (best_pitch.yy + 1.f); final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain); - size_t final_pitch_period_48kHz = std::max( + int final_pitch_period_48kHz = std::max( kMinPitch48kHz, PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf)); diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h index bb747bb03e..aabf713fce 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h @@ -41,12 +41,12 @@ void Decimate2x(rtc::ArrayView src, // Computes a gain threshold for a candidate pitch period given the initial and // the previous pitch period and gain estimates and the pitch period ratio used // to derive the candidate pitch period from the initial period. -float ComputePitchGainThreshold(size_t candidate_pitch_period, - size_t pitch_period_ratio, - size_t initial_pitch_period, +float ComputePitchGainThreshold(int candidate_pitch_period, + int pitch_period_ratio, + int initial_pitch_period, float initial_pitch_gain, - size_t prev_pitch_period, - size_t prev_pitch_gain); + int prev_pitch_period, + float prev_pitch_gain); // Computes the sum of squared samples for every sliding frame in the pitch // buffer. |yy_values| indexes are lags. @@ -99,7 +99,7 @@ size_t RefinePitchPeriod48kHz( // refined pitch estimation data at 48 kHz. PitchInfo CheckLowerPitchPeriodsAndComputePitchGain( rtc::ArrayView pitch_buf, - size_t initial_pitch_period_48kHz, + int initial_pitch_period_48kHz, PitchInfo prev_pitch_48kHz); } // namespace rnn_vad diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc index 82b48101ff..033ea3e77f 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -24,8 +24,9 @@ namespace rnn_vad { namespace test { namespace { -constexpr std::array kTestPitchPeriods = { - 3 * kMinPitch48kHz / 2, (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2, +constexpr std::array kTestPitchPeriods = { + 3 * kMinPitch48kHz / 2, + (3 * kMinPitch48kHz + kMaxPitch48kHz) / 2, }; constexpr std::array kTestPitchGains = {0.35f, 0.75f}; @@ -197,14 +198,14 @@ TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) { class CheckLowerPitchPeriodsAndComputePitchGainTest : public testing::Test, public ::testing::WithParamInterface< - std::tuple> {}; + std::tuple> {}; TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) { const auto params = GetParam(); - const size_t initial_pitch_period = std::get<0>(params); - const size_t prev_pitch_period = std::get<1>(params); + const int initial_pitch_period = std::get<0>(params); + const int prev_pitch_period = std::get<1>(params); const float prev_pitch_gain = std::get<2>(params); - const size_t expected_pitch_period = std::get<3>(params); + const int expected_pitch_period = std::get<3>(params); const float expected_pitch_gain = std::get<4>(params); TestData test_data; { diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc index 4c56238dd9..eac332edbf 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc @@ -39,7 +39,7 @@ TEST(RnnVadTest, PitchSearchBitExactness) { lp_residual_reader.first->ReadValue(&expected_pitch_period); lp_residual_reader.first->ReadValue(&expected_pitch_gain); PitchInfo pitch_info = pitch_estimator.Estimate(lp_residual); - EXPECT_EQ(static_cast(expected_pitch_period), pitch_info.period); + EXPECT_EQ(static_cast(expected_pitch_period), pitch_info.period); EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f); } }