diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index 7379d4195b..237c80972d 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -11,6 +11,8 @@ import("../../../../webrtc.gni") rtc_source_set("rnn_vad") { visibility = [ "../*" ] sources = [ + "auto_correlation.cc", + "auto_correlation.h", "common.h", "features_extraction.cc", "features_extraction.h", @@ -37,9 +39,9 @@ rtc_source_set("rnn_vad") { "..:biquad_filter", "../../../../api:array_view", "../../../../api:function_view", - "../../../../common_audio/", "../../../../rtc_base:checks", "../../../../rtc_base:rtc_base_approved", + "../../utility:pffft_wrapper", "//third_party/rnnoise:kiss_fft", "//third_party/rnnoise:rnn_vad", ] @@ -53,6 +55,7 @@ if (rtc_include_tests) { "test_utils.h", ] deps = [ + ":rnn_vad", "../../../../api:array_view", "../../../../api:scoped_refptr", "../../../../rtc_base:checks", @@ -86,6 +89,7 @@ if (rtc_include_tests) { rtc_source_set("unittests") { testonly = true sources = [ + "auto_correlation_unittest.cc", "features_extraction_unittest.cc", "fft_util_unittest.cc", "lp_residual_unittest.cc", diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc b/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc new file mode 100644 index 0000000000..d932c78063 --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h" + +#include + +#include "rtc_base/checks.h" + +namespace webrtc { +namespace rnn_vad { +namespace { + +constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT. +static_assert(1 << kAutoCorrelationFftOrder > + kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz, + ""); + +} // namespace + +AutoCorrelationCalculator::AutoCorrelationCalculator() + : fft_(1 << kAutoCorrelationFftOrder, Pffft::FftType::kReal), + tmp_(fft_.CreateBuffer()), + X_(fft_.CreateBuffer()), + H_(fft_.CreateBuffer()) {} + +AutoCorrelationCalculator::~AutoCorrelationCalculator() = default; + +// The auto-correlations coefficients are computed as follows: +// |.........|...........| <- pitch buffer +// [ x (fixed) ] +// [ y_0 ] +// [ y_{m-1} ] +// x and y are sub-array of equal length; x is never moved, whereas y slides. +// The cross-correlation between y_0 and x corresponds to the auto-correlation +// for the maximum pitch period. Hence, the first value in |auto_corr| has an +// inverted lag equal to 0 that corresponds to a lag equal to the maximum +// pitch period. +void AutoCorrelationCalculator::ComputeOnPitchBuffer( + rtc::ArrayView pitch_buf, + rtc::ArrayView auto_corr) { + RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz); + RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz); + constexpr size_t kFftFrameSize = 1 << kAutoCorrelationFftOrder; + constexpr size_t kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz; + static_assert(kConvolutionLength == kFrameSize20ms12kHz, + "Mismatch between pitch buffer size, frame size and maximum " + "pitch period."); + static_assert(kFftFrameSize > kNumInvertedLags12kHz + kConvolutionLength, + "The FFT length is not sufficiently big to avoid cyclic " + "convolution errors."); + auto tmp = tmp_->GetView(); + + // Compute the FFT for the reversed reference frame - i.e., + // pitch_buf[-kConvolutionLength:]. + std::reverse_copy(pitch_buf.end() - kConvolutionLength, pitch_buf.end(), + tmp.begin()); + std::fill(tmp.begin() + kConvolutionLength, tmp.end(), 0.f); + fft_.ForwardTransform(*tmp_, H_.get(), /*ordered=*/false); + + // Compute the FFT for the sliding frames chunk. The sliding frames are + // defined as pitch_buf[i:i+kConvolutionLength] where i in + // [0, kNumInvertedLags12kHz). The chunk includes all of them, hence it is + // defined as pitch_buf[:kNumInvertedLags12kHz+kConvolutionLength]. + std::copy(pitch_buf.begin(), + pitch_buf.begin() + kConvolutionLength + kNumInvertedLags12kHz, + tmp.begin()); + std::fill(tmp.begin() + kNumInvertedLags12kHz + kConvolutionLength, tmp.end(), + 0.f); + fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false); + + // Convolve in the frequency domain. + constexpr float kScalingFactor = 1.f / static_cast(kFftFrameSize); + std::fill(tmp.begin(), tmp.end(), 0.f); + fft_.FrequencyDomainConvolve(*X_, *H_, tmp_.get(), kScalingFactor); + fft_.BackwardTransform(*tmp_, tmp_.get(), /*ordered=*/false); + + // Extract the auto-correlation coefficients. + std::copy(tmp.begin() + kConvolutionLength - 1, + tmp.begin() + kConvolutionLength + kNumInvertedLags12kHz - 1, + auto_corr.begin()); +} + +} // namespace rnn_vad +} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation.h b/modules/audio_processing/agc2/rnn_vad/auto_correlation.h new file mode 100644 index 0000000000..de7f453bc7 --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/auto_correlation.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_ +#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_ + +#include + +#include "api/array_view.h" +#include "modules/audio_processing/agc2/rnn_vad/common.h" +#include "modules/audio_processing/utility/pffft_wrapper.h" + +namespace webrtc { +namespace rnn_vad { + +// Class to compute the auto correlation on the pitch buffer for a target pitch +// interval. +class AutoCorrelationCalculator { + public: + AutoCorrelationCalculator(); + AutoCorrelationCalculator(const AutoCorrelationCalculator&) = delete; + AutoCorrelationCalculator& operator=(const AutoCorrelationCalculator&) = + delete; + ~AutoCorrelationCalculator(); + + // Computes the auto-correlation coefficients for a target pitch interval. + // |auto_corr| indexes are inverted lags. + void ComputeOnPitchBuffer( + rtc::ArrayView pitch_buf, + rtc::ArrayView auto_corr); + + private: + Pffft fft_; + std::unique_ptr tmp_; + std::unique_ptr X_; + std::unique_ptr H_; +}; + +} // namespace rnn_vad +} // namespace webrtc + +#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_ diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc b/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc new file mode 100644 index 0000000000..a5e456a4de --- /dev/null +++ b/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h" + +#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" +#include "modules/audio_processing/agc2/rnn_vad/test_utils.h" +#include "test/gtest.h" + +namespace webrtc { +namespace rnn_vad { +namespace test { + +TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) { + PitchTestData test_data; + std::array pitch_buf_decimated; + Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated); + std::array computed_output; + { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + AutoCorrelationCalculator auto_corr_calculator; + auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated, + computed_output); + } + auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView(); + ExpectNearAbsolute({auto_corr_view.data(), auto_corr_view.size()}, + computed_output, 3e-3f); +} + +// Check that the auto correlation function computes the right thing for a +// simple use case. +TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) { + // Create constant signal with no pitch. + std::array pitch_buf_decimated; + std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f); + std::array computed_output; + { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + AutoCorrelationCalculator auto_corr_calculator; + auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated, + computed_output); + } + // The expected output is constantly the length of the fixed 'x' + // array in ComputePitchAutoCorrelation. + std::array expected_output; + std::fill(expected_output.begin(), expected_output.end(), + kBufSize12kHz - kMaxPitch12kHz); + ExpectNearAbsolute(expected_output, computed_output, 4e-5f); +} + +} // namespace test +} // namespace rnn_vad +} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/common.h b/modules/audio_processing/agc2/rnn_vad/common.h index b98438db68..2f16cd41e9 100644 --- a/modules/audio_processing/agc2/rnn_vad/common.h +++ b/modules/audio_processing/agc2/rnn_vad/common.h @@ -20,18 +20,21 @@ constexpr size_t kSampleRate24kHz = 24000; constexpr size_t kFrameSize10ms24kHz = kSampleRate24kHz / 100; constexpr size_t kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2; -// Pitch analysis params. +// Pitch buffer. constexpr size_t kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s. constexpr size_t kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s. constexpr size_t kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz; static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even."); +// 24 kHz analysis. // Define a higher minimum pitch period for the initial search. This is used to // avoid searching for very short periods, for which a refinement step is // responsible. constexpr size_t kInitialMinPitch24kHz = 3 * kMinPitch24kHz; static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, ""); static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, ""); +static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, ""); +constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz; // 12 kHz analysis. constexpr size_t kSampleRate12kHz = 12000; @@ -40,6 +43,10 @@ constexpr size_t kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2; constexpr size_t kBufSize12kHz = kBufSize24kHz / 2; constexpr size_t kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2; constexpr size_t kMaxPitch12kHz = kMaxPitch24kHz / 2; +static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, ""); +// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|, +// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|]. +constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz; // 48 kHz constants. constexpr size_t kMinPitch48kHz = kMinPitch24kHz * 2; diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc index aa0b751d28..1b3b459c5f 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc @@ -19,8 +19,7 @@ namespace webrtc { namespace rnn_vad { PitchEstimator::PitchEstimator() - : fft_(RealFourier::Create(kAutoCorrelationFftOrder)), - pitch_buf_decimated_(kBufSize12kHz), + : pitch_buf_decimated_(kBufSize12kHz), pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz), auto_corr_(kNumInvertedLags12kHz), auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) { @@ -34,20 +33,16 @@ PitchInfo PitchEstimator::Estimate( rtc::ArrayView pitch_buf) { // Perform the initial pitch search at 12 kHz. Decimate2x(pitch_buf, pitch_buf_decimated_view_); - // Compute auto-correlation terms. - ComputePitchAutoCorrelation(pitch_buf_decimated_view_, kMaxPitch12kHz, - auto_corr_view_, fft_.get()); - - // Search for pitch at 12 kHz. + auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_, + auto_corr_view_); std::array pitch_candidates_inv_lags = FindBestPitchPeriods( auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz); - // Refine the pitch period estimation. // The refinement is done using the pitch buffer that contains 24 kHz samples. // Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12 // to 24 kHz. - for (size_t i = 0; i < pitch_candidates_inv_lags.size(); ++i) - pitch_candidates_inv_lags[i] *= 2; + pitch_candidates_inv_lags[0] *= 2; + pitch_candidates_inv_lags[1] *= 2; size_t pitch_inv_lag_48kHz = RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inv_lags); // Look for stronger harmonics to find the final pitch period and its gain. diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.h b/modules/audio_processing/agc2/rnn_vad/pitch_search.h index 59145353c1..74133d0738 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.h @@ -15,7 +15,7 @@ #include #include "api/array_view.h" -#include "common_audio/real_fourier.h" +#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" @@ -36,7 +36,7 @@ class PitchEstimator { private: PitchInfo last_pitch_48kHz_; - std::unique_ptr fft_; + AutoCorrelationCalculator auto_corr_calculator_; std::vector pitch_buf_decimated_; rtc::ArrayView pitch_buf_decimated_view_; std::vector auto_corr_; diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc index 7c17dfb0bc..0561c3715f 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc @@ -13,7 +13,6 @@ #include #include #include -#include #include #include @@ -213,64 +212,6 @@ void ComputeSlidingFrameSquareEnergies( } } -void ComputePitchAutoCorrelation( - rtc::ArrayView pitch_buf, - size_t max_pitch_period, - rtc::ArrayView auto_corr, - webrtc::RealFourier* fft) { - RTC_DCHECK_GT(max_pitch_period, auto_corr.size()); - RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); - RTC_DCHECK(fft); - - constexpr size_t time_domain_fft_length = 1 << kAutoCorrelationFftOrder; - constexpr size_t freq_domain_fft_length = time_domain_fft_length / 2 + 1; - - RTC_DCHECK_EQ(RealFourier::FftLength(fft->order()), time_domain_fft_length); - RTC_DCHECK_EQ(RealFourier::ComplexLength(fft->order()), - freq_domain_fft_length); - - // Cross-correlation of y_i=pitch_buf[i:i+convolution_length] and - // x=pitch_buf[-convolution_length:] is equivalent to convolution of - // y_i and reversed(x). New notation: h=reversed(x), x=y. - std::array h{}; - std::array x{}; - - const size_t convolution_length = kBufSize12kHz - max_pitch_period; - // Check that the FFT-length is big enough to avoid cyclic - // convolution errors. - RTC_DCHECK_GT(time_domain_fft_length, - kNumInvertedLags12kHz + convolution_length); - - // h[0:convolution_length] is reversed pitch_buf[-convolution_length:]. - std::reverse_copy(pitch_buf.end() - convolution_length, pitch_buf.end(), - h.begin()); - - // x is pitch_buf[:kNumInvertedLags12kHz + convolution_length]. - std::copy(pitch_buf.begin(), - pitch_buf.begin() + kNumInvertedLags12kHz + convolution_length, - x.begin()); - - // Shift to frequency domain. - std::array, freq_domain_fft_length> X{}; - std::array, freq_domain_fft_length> H{}; - fft->Forward(&x[0], &X[0]); - fft->Forward(&h[0], &H[0]); - - // Convolve in frequency domain. - for (size_t i = 0; i < X.size(); ++i) { - X[i] *= H[i]; - } - - // Shift back to time domain. - std::array x_conv_h; - fft->Inverse(&X[0], &x_conv_h[0]); - - // Collect the result. - std::copy(x_conv_h.begin() + convolution_length - 1, - x_conv_h.begin() + convolution_length + kNumInvertedLags12kHz - 1, - auto_corr.begin()); -} - std::array FindBestPitchPeriods( rtc::ArrayView auto_corr, rtc::ArrayView pitch_buf, diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h index aabf713fce..6ccd165010 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h @@ -15,25 +15,12 @@ #include #include "api/array_view.h" -#include "common_audio/real_fourier.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" namespace webrtc { namespace rnn_vad { -// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|, -// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags|]. -static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, ""); -static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, ""); -constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz; -constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz; -constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT. - -static_assert(1 << kAutoCorrelationFftOrder > - kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz, - ""); - // Performs 2x decimation without any anti-aliasing filter. void Decimate2x(rtc::ArrayView src, rtc::ArrayView dst); @@ -61,25 +48,6 @@ void ComputeSlidingFrameSquareEnergies( rtc::ArrayView pitch_buf, rtc::ArrayView yy_values); -// Computes the auto-correlation coefficients for a given pitch interval. -// |auto_corr| indexes are inverted lags. -// -// The auto-correlations coefficients are computed as follows: -// |.........|...........| <- pitch buffer -// [ x (fixed) ] -// [ y_0 ] -// [ y_{m-1} ] -// x and y are sub-array of equal length; x is never moved, whereas y slides. -// The cross-correlation between y_0 and x corresponds to the auto-correlation -// for the maximum pitch period. Hence, the first value in |auto_corr| has an -// inverted lag equal to 0 that corresponds to a lag equal to the maximum pitch -// period. -void ComputePitchAutoCorrelation( - rtc::ArrayView pitch_buf, - size_t max_pitch_period, - rtc::ArrayView auto_corr, - webrtc::RealFourier* fft); - // Given the auto-correlation coefficients stored according to // ComputePitchAutoCorrelation() (i.e., using inverted lags), returns the best // and the second best pitch periods. diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc index 8ff6ac1c12..bd2ea24961 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -9,7 +9,6 @@ */ #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" -#include "common_audio/real_fourier.h" #include #include @@ -30,34 +29,6 @@ constexpr std::array kTestPitchPeriods = { }; constexpr std::array kTestPitchGains = {0.35f, 0.75f}; -constexpr size_t kNumPitchBufSquareEnergies = 385; -constexpr size_t kNumPitchBufAutoCorrCoeffs = 147; -constexpr size_t kTestDataSize = - kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs; - -class TestData { - public: - TestData() { - auto test_data_reader = CreatePitchSearchTestDataReader(); - test_data_reader->ReadChunk(test_data_); - } - rtc::ArrayView GetPitchBufView() { - return {test_data_.data(), kBufSize24kHz}; - } - rtc::ArrayView - GetPitchBufSquareEnergiesView() { - return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies}; - } - rtc::ArrayView - GetPitchBufAutoCorrCoeffsView() { - return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies, - kNumPitchBufAutoCorrCoeffs}; - } - - private: - std::array test_data_; -}; - } // namespace class ComputePitchGainThresholdTest @@ -107,7 +78,7 @@ INSTANTIATE_TEST_SUITE_P( std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f))); TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) { - TestData test_data; + PitchTestData test_data; std::array computed_output; { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. @@ -120,51 +91,8 @@ TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) { computed_output, 3e-2f); } -TEST(RnnVadTest, ComputePitchAutoCorrelationBitExactness) { - TestData test_data; - std::array pitch_buf_decimated; - Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated); - std::array computed_output; - { - // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. - // FloatingPointExceptionObserver fpe_observer; - std::unique_ptr fft = - RealFourier::Create(kAutoCorrelationFftOrder); - ComputePitchAutoCorrelation(pitch_buf_decimated, kMaxPitch12kHz, - computed_output, fft.get()); - } - auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView(); - ExpectNearAbsolute({auto_corr_view.data(), auto_corr_view.size()}, - computed_output, 3e-3f); -} - -// Check that the auto correlation function computes the right thing for a -// simple use case. -TEST(RnnVadTest, ComputePitchAutoCorrelationConstantBuffer) { - // Create constant signal with no pitch. - std::array pitch_buf_decimated; - std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f); - - std::array computed_output; - { - // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. - // FloatingPointExceptionObserver fpe_observer; - std::unique_ptr fft = - RealFourier::Create(kAutoCorrelationFftOrder); - ComputePitchAutoCorrelation(pitch_buf_decimated, kMaxPitch12kHz, - computed_output, fft.get()); - } - - // The expected output is constantly the length of the fixed 'x' - // array in ComputePitchAutoCorrelation. - std::array expected_output; - std::fill(expected_output.begin(), expected_output.end(), - kBufSize12kHz - kMaxPitch12kHz); - ExpectNearAbsolute(expected_output, computed_output, 4e-5f); -} - TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) { - TestData test_data; + PitchTestData test_data; std::array pitch_buf_decimated; Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated); std::array pitch_candidates_inv_lags; @@ -181,7 +109,7 @@ TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) { } TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) { - TestData test_data; + PitchTestData test_data; std::array pitch_buf_decimated; Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated); size_t pitch_inv_lag; @@ -207,7 +135,7 @@ TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) { const float prev_pitch_gain = std::get<2>(params); const int expected_pitch_period = std::get<3>(params); const float expected_pitch_gain = std::get<4>(params); - TestData test_data; + PitchTestData test_data; { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc index 8decbd0ebb..4dae8cdb3a 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc @@ -111,6 +111,28 @@ ReaderPairType CreateVadProbsReader() { return {std::move(ptr), ptr->data_length()}; } +PitchTestData::PitchTestData() { + auto test_data_reader = CreatePitchSearchTestDataReader(); + test_data_reader->ReadChunk(test_data_); +} + +PitchTestData::~PitchTestData() = default; + +rtc::ArrayView PitchTestData::GetPitchBufView() { + return {test_data_.data(), kBufSize24kHz}; +} + +rtc::ArrayView +PitchTestData::GetPitchBufSquareEnergiesView() { + return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies}; +} + +rtc::ArrayView +PitchTestData::GetPitchBufAutoCorrCoeffsView() { + return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies, + kNumPitchBufAutoCorrCoeffs}; +} + } // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h index 15be85ac00..f9d7376d43 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.h +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h @@ -12,6 +12,7 @@ #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_ #include +#include #include #include #include @@ -20,6 +21,7 @@ #include #include "api/array_view.h" +#include "modules/audio_processing/agc2/rnn_vad/common.h" #include "rtc_base/checks.h" namespace webrtc { @@ -118,6 +120,27 @@ CreateSilenceFlagsFeatureMatrixReader(); std::pair>, const size_t> CreateVadProbsReader(); +constexpr size_t kNumPitchBufAutoCorrCoeffs = 147; +constexpr size_t kNumPitchBufSquareEnergies = 385; +constexpr size_t kPitchTestDataSize = + kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs; + +// Class to retrieve a test pitch buffer content and the expected output for the +// analysis steps. +class PitchTestData { + public: + PitchTestData(); + ~PitchTestData(); + rtc::ArrayView GetPitchBufView(); + rtc::ArrayView + GetPitchBufSquareEnergiesView(); + rtc::ArrayView + GetPitchBufAutoCorrCoeffsView(); + + private: + std::array test_data_; +}; + } // namespace test } // namespace rnn_vad } // namespace webrtc