diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index 237c80972d..7379d4195b 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -11,8 +11,6 @@ import("../../../../webrtc.gni") rtc_source_set("rnn_vad") { visibility = [ "../*" ] sources = [ - "auto_correlation.cc", - "auto_correlation.h", "common.h", "features_extraction.cc", "features_extraction.h", @@ -39,9 +37,9 @@ rtc_source_set("rnn_vad") { "..:biquad_filter", "../../../../api:array_view", "../../../../api:function_view", + "../../../../common_audio/", "../../../../rtc_base:checks", "../../../../rtc_base:rtc_base_approved", - "../../utility:pffft_wrapper", "//third_party/rnnoise:kiss_fft", "//third_party/rnnoise:rnn_vad", ] @@ -55,7 +53,6 @@ if (rtc_include_tests) { "test_utils.h", ] deps = [ - ":rnn_vad", "../../../../api:array_view", "../../../../api:scoped_refptr", "../../../../rtc_base:checks", @@ -89,7 +86,6 @@ if (rtc_include_tests) { rtc_source_set("unittests") { testonly = true sources = [ - "auto_correlation_unittest.cc", "features_extraction_unittest.cc", "fft_util_unittest.cc", "lp_residual_unittest.cc", diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc b/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc deleted file mode 100644 index d932c78063..0000000000 --- a/modules/audio_processing/agc2/rnn_vad/auto_correlation.cc +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h" - -#include - -#include "rtc_base/checks.h" - -namespace webrtc { -namespace rnn_vad { -namespace { - -constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT. -static_assert(1 << kAutoCorrelationFftOrder > - kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz, - ""); - -} // namespace - -AutoCorrelationCalculator::AutoCorrelationCalculator() - : fft_(1 << kAutoCorrelationFftOrder, Pffft::FftType::kReal), - tmp_(fft_.CreateBuffer()), - X_(fft_.CreateBuffer()), - H_(fft_.CreateBuffer()) {} - -AutoCorrelationCalculator::~AutoCorrelationCalculator() = default; - -// The auto-correlations coefficients are computed as follows: -// |.........|...........| <- pitch buffer -// [ x (fixed) ] -// [ y_0 ] -// [ y_{m-1} ] -// x and y are sub-array of equal length; x is never moved, whereas y slides. -// The cross-correlation between y_0 and x corresponds to the auto-correlation -// for the maximum pitch period. Hence, the first value in |auto_corr| has an -// inverted lag equal to 0 that corresponds to a lag equal to the maximum -// pitch period. -void AutoCorrelationCalculator::ComputeOnPitchBuffer( - rtc::ArrayView pitch_buf, - rtc::ArrayView auto_corr) { - RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz); - RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz); - constexpr size_t kFftFrameSize = 1 << kAutoCorrelationFftOrder; - constexpr size_t kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz; - static_assert(kConvolutionLength == kFrameSize20ms12kHz, - "Mismatch between pitch buffer size, frame size and maximum " - "pitch period."); - static_assert(kFftFrameSize > kNumInvertedLags12kHz + kConvolutionLength, - "The FFT length is not sufficiently big to avoid cyclic " - "convolution errors."); - auto tmp = tmp_->GetView(); - - // Compute the FFT for the reversed reference frame - i.e., - // pitch_buf[-kConvolutionLength:]. - std::reverse_copy(pitch_buf.end() - kConvolutionLength, pitch_buf.end(), - tmp.begin()); - std::fill(tmp.begin() + kConvolutionLength, tmp.end(), 0.f); - fft_.ForwardTransform(*tmp_, H_.get(), /*ordered=*/false); - - // Compute the FFT for the sliding frames chunk. The sliding frames are - // defined as pitch_buf[i:i+kConvolutionLength] where i in - // [0, kNumInvertedLags12kHz). The chunk includes all of them, hence it is - // defined as pitch_buf[:kNumInvertedLags12kHz+kConvolutionLength]. - std::copy(pitch_buf.begin(), - pitch_buf.begin() + kConvolutionLength + kNumInvertedLags12kHz, - tmp.begin()); - std::fill(tmp.begin() + kNumInvertedLags12kHz + kConvolutionLength, tmp.end(), - 0.f); - fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false); - - // Convolve in the frequency domain. - constexpr float kScalingFactor = 1.f / static_cast(kFftFrameSize); - std::fill(tmp.begin(), tmp.end(), 0.f); - fft_.FrequencyDomainConvolve(*X_, *H_, tmp_.get(), kScalingFactor); - fft_.BackwardTransform(*tmp_, tmp_.get(), /*ordered=*/false); - - // Extract the auto-correlation coefficients. - std::copy(tmp.begin() + kConvolutionLength - 1, - tmp.begin() + kConvolutionLength + kNumInvertedLags12kHz - 1, - auto_corr.begin()); -} - -} // namespace rnn_vad -} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation.h b/modules/audio_processing/agc2/rnn_vad/auto_correlation.h deleted file mode 100644 index de7f453bc7..0000000000 --- a/modules/audio_processing/agc2/rnn_vad/auto_correlation.h +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_ -#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_ - -#include - -#include "api/array_view.h" -#include "modules/audio_processing/agc2/rnn_vad/common.h" -#include "modules/audio_processing/utility/pffft_wrapper.h" - -namespace webrtc { -namespace rnn_vad { - -// Class to compute the auto correlation on the pitch buffer for a target pitch -// interval. -class AutoCorrelationCalculator { - public: - AutoCorrelationCalculator(); - AutoCorrelationCalculator(const AutoCorrelationCalculator&) = delete; - AutoCorrelationCalculator& operator=(const AutoCorrelationCalculator&) = - delete; - ~AutoCorrelationCalculator(); - - // Computes the auto-correlation coefficients for a target pitch interval. - // |auto_corr| indexes are inverted lags. - void ComputeOnPitchBuffer( - rtc::ArrayView pitch_buf, - rtc::ArrayView auto_corr); - - private: - Pffft fft_; - std::unique_ptr tmp_; - std::unique_ptr X_; - std::unique_ptr H_; -}; - -} // namespace rnn_vad -} // namespace webrtc - -#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_ diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc b/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc deleted file mode 100644 index a5e456a4de..0000000000 --- a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h" - -#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" -#include "modules/audio_processing/agc2/rnn_vad/test_utils.h" -#include "test/gtest.h" - -namespace webrtc { -namespace rnn_vad { -namespace test { - -TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) { - PitchTestData test_data; - std::array pitch_buf_decimated; - Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated); - std::array computed_output; - { - // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. - // FloatingPointExceptionObserver fpe_observer; - AutoCorrelationCalculator auto_corr_calculator; - auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated, - computed_output); - } - auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView(); - ExpectNearAbsolute({auto_corr_view.data(), auto_corr_view.size()}, - computed_output, 3e-3f); -} - -// Check that the auto correlation function computes the right thing for a -// simple use case. -TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) { - // Create constant signal with no pitch. - std::array pitch_buf_decimated; - std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f); - std::array computed_output; - { - // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. - // FloatingPointExceptionObserver fpe_observer; - AutoCorrelationCalculator auto_corr_calculator; - auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated, - computed_output); - } - // The expected output is constantly the length of the fixed 'x' - // array in ComputePitchAutoCorrelation. - std::array expected_output; - std::fill(expected_output.begin(), expected_output.end(), - kBufSize12kHz - kMaxPitch12kHz); - ExpectNearAbsolute(expected_output, computed_output, 4e-5f); -} - -} // namespace test -} // namespace rnn_vad -} // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/common.h b/modules/audio_processing/agc2/rnn_vad/common.h index 2f16cd41e9..b98438db68 100644 --- a/modules/audio_processing/agc2/rnn_vad/common.h +++ b/modules/audio_processing/agc2/rnn_vad/common.h @@ -20,21 +20,18 @@ constexpr size_t kSampleRate24kHz = 24000; constexpr size_t kFrameSize10ms24kHz = kSampleRate24kHz / 100; constexpr size_t kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2; -// Pitch buffer. +// Pitch analysis params. constexpr size_t kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s. constexpr size_t kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s. constexpr size_t kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz; static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even."); -// 24 kHz analysis. // Define a higher minimum pitch period for the initial search. This is used to // avoid searching for very short periods, for which a refinement step is // responsible. constexpr size_t kInitialMinPitch24kHz = 3 * kMinPitch24kHz; static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, ""); static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, ""); -static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, ""); -constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz; // 12 kHz analysis. constexpr size_t kSampleRate12kHz = 12000; @@ -43,10 +40,6 @@ constexpr size_t kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2; constexpr size_t kBufSize12kHz = kBufSize24kHz / 2; constexpr size_t kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2; constexpr size_t kMaxPitch12kHz = kMaxPitch24kHz / 2; -static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, ""); -// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|, -// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|]. -constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz; // 48 kHz constants. constexpr size_t kMinPitch48kHz = kMinPitch24kHz * 2; diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc index 1b3b459c5f..aa0b751d28 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.cc @@ -19,7 +19,8 @@ namespace webrtc { namespace rnn_vad { PitchEstimator::PitchEstimator() - : pitch_buf_decimated_(kBufSize12kHz), + : fft_(RealFourier::Create(kAutoCorrelationFftOrder)), + pitch_buf_decimated_(kBufSize12kHz), pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz), auto_corr_(kNumInvertedLags12kHz), auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) { @@ -33,16 +34,20 @@ PitchInfo PitchEstimator::Estimate( rtc::ArrayView pitch_buf) { // Perform the initial pitch search at 12 kHz. Decimate2x(pitch_buf, pitch_buf_decimated_view_); - auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_, - auto_corr_view_); + // Compute auto-correlation terms. + ComputePitchAutoCorrelation(pitch_buf_decimated_view_, kMaxPitch12kHz, + auto_corr_view_, fft_.get()); + + // Search for pitch at 12 kHz. std::array pitch_candidates_inv_lags = FindBestPitchPeriods( auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz); + // Refine the pitch period estimation. // The refinement is done using the pitch buffer that contains 24 kHz samples. // Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12 // to 24 kHz. - pitch_candidates_inv_lags[0] *= 2; - pitch_candidates_inv_lags[1] *= 2; + for (size_t i = 0; i < pitch_candidates_inv_lags.size(); ++i) + pitch_candidates_inv_lags[i] *= 2; size_t pitch_inv_lag_48kHz = RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inv_lags); // Look for stronger harmonics to find the final pitch period and its gain. diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search.h b/modules/audio_processing/agc2/rnn_vad/pitch_search.h index 74133d0738..59145353c1 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search.h @@ -15,7 +15,7 @@ #include #include "api/array_view.h" -#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h" +#include "common_audio/real_fourier.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" @@ -36,7 +36,7 @@ class PitchEstimator { private: PitchInfo last_pitch_48kHz_; - AutoCorrelationCalculator auto_corr_calculator_; + std::unique_ptr fft_; std::vector pitch_buf_decimated_; rtc::ArrayView pitch_buf_decimated_view_; std::vector auto_corr_; diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc index 0561c3715f..7c17dfb0bc 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.cc @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -212,6 +213,64 @@ void ComputeSlidingFrameSquareEnergies( } } +void ComputePitchAutoCorrelation( + rtc::ArrayView pitch_buf, + size_t max_pitch_period, + rtc::ArrayView auto_corr, + webrtc::RealFourier* fft) { + RTC_DCHECK_GT(max_pitch_period, auto_corr.size()); + RTC_DCHECK_LT(max_pitch_period, pitch_buf.size()); + RTC_DCHECK(fft); + + constexpr size_t time_domain_fft_length = 1 << kAutoCorrelationFftOrder; + constexpr size_t freq_domain_fft_length = time_domain_fft_length / 2 + 1; + + RTC_DCHECK_EQ(RealFourier::FftLength(fft->order()), time_domain_fft_length); + RTC_DCHECK_EQ(RealFourier::ComplexLength(fft->order()), + freq_domain_fft_length); + + // Cross-correlation of y_i=pitch_buf[i:i+convolution_length] and + // x=pitch_buf[-convolution_length:] is equivalent to convolution of + // y_i and reversed(x). New notation: h=reversed(x), x=y. + std::array h{}; + std::array x{}; + + const size_t convolution_length = kBufSize12kHz - max_pitch_period; + // Check that the FFT-length is big enough to avoid cyclic + // convolution errors. + RTC_DCHECK_GT(time_domain_fft_length, + kNumInvertedLags12kHz + convolution_length); + + // h[0:convolution_length] is reversed pitch_buf[-convolution_length:]. + std::reverse_copy(pitch_buf.end() - convolution_length, pitch_buf.end(), + h.begin()); + + // x is pitch_buf[:kNumInvertedLags12kHz + convolution_length]. + std::copy(pitch_buf.begin(), + pitch_buf.begin() + kNumInvertedLags12kHz + convolution_length, + x.begin()); + + // Shift to frequency domain. + std::array, freq_domain_fft_length> X{}; + std::array, freq_domain_fft_length> H{}; + fft->Forward(&x[0], &X[0]); + fft->Forward(&h[0], &H[0]); + + // Convolve in frequency domain. + for (size_t i = 0; i < X.size(); ++i) { + X[i] *= H[i]; + } + + // Shift back to time domain. + std::array x_conv_h; + fft->Inverse(&X[0], &x_conv_h[0]); + + // Collect the result. + std::copy(x_conv_h.begin() + convolution_length - 1, + x_conv_h.begin() + convolution_length + kNumInvertedLags12kHz - 1, + auto_corr.begin()); +} + std::array FindBestPitchPeriods( rtc::ArrayView auto_corr, rtc::ArrayView pitch_buf, diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h index 6ccd165010..aabf713fce 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h @@ -15,12 +15,25 @@ #include #include "api/array_view.h" +#include "common_audio/real_fourier.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" #include "modules/audio_processing/agc2/rnn_vad/pitch_info.h" namespace webrtc { namespace rnn_vad { +// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|, +// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags|]. +static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, ""); +static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, ""); +constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz; +constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz; +constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT. + +static_assert(1 << kAutoCorrelationFftOrder > + kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz, + ""); + // Performs 2x decimation without any anti-aliasing filter. void Decimate2x(rtc::ArrayView src, rtc::ArrayView dst); @@ -48,6 +61,25 @@ void ComputeSlidingFrameSquareEnergies( rtc::ArrayView pitch_buf, rtc::ArrayView yy_values); +// Computes the auto-correlation coefficients for a given pitch interval. +// |auto_corr| indexes are inverted lags. +// +// The auto-correlations coefficients are computed as follows: +// |.........|...........| <- pitch buffer +// [ x (fixed) ] +// [ y_0 ] +// [ y_{m-1} ] +// x and y are sub-array of equal length; x is never moved, whereas y slides. +// The cross-correlation between y_0 and x corresponds to the auto-correlation +// for the maximum pitch period. Hence, the first value in |auto_corr| has an +// inverted lag equal to 0 that corresponds to a lag equal to the maximum pitch +// period. +void ComputePitchAutoCorrelation( + rtc::ArrayView pitch_buf, + size_t max_pitch_period, + rtc::ArrayView auto_corr, + webrtc::RealFourier* fft); + // Given the auto-correlation coefficients stored according to // ComputePitchAutoCorrelation() (i.e., using inverted lags), returns the best // and the second best pitch periods. diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc index bd2ea24961..8ff6ac1c12 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -9,6 +9,7 @@ */ #include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h" +#include "common_audio/real_fourier.h" #include #include @@ -29,6 +30,34 @@ constexpr std::array kTestPitchPeriods = { }; constexpr std::array kTestPitchGains = {0.35f, 0.75f}; +constexpr size_t kNumPitchBufSquareEnergies = 385; +constexpr size_t kNumPitchBufAutoCorrCoeffs = 147; +constexpr size_t kTestDataSize = + kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs; + +class TestData { + public: + TestData() { + auto test_data_reader = CreatePitchSearchTestDataReader(); + test_data_reader->ReadChunk(test_data_); + } + rtc::ArrayView GetPitchBufView() { + return {test_data_.data(), kBufSize24kHz}; + } + rtc::ArrayView + GetPitchBufSquareEnergiesView() { + return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies}; + } + rtc::ArrayView + GetPitchBufAutoCorrCoeffsView() { + return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies, + kNumPitchBufAutoCorrCoeffs}; + } + + private: + std::array test_data_; +}; + } // namespace class ComputePitchGainThresholdTest @@ -78,7 +107,7 @@ INSTANTIATE_TEST_SUITE_P( std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f))); TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) { - PitchTestData test_data; + TestData test_data; std::array computed_output; { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. @@ -91,8 +120,51 @@ TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) { computed_output, 3e-2f); } +TEST(RnnVadTest, ComputePitchAutoCorrelationBitExactness) { + TestData test_data; + std::array pitch_buf_decimated; + Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated); + std::array computed_output; + { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + std::unique_ptr fft = + RealFourier::Create(kAutoCorrelationFftOrder); + ComputePitchAutoCorrelation(pitch_buf_decimated, kMaxPitch12kHz, + computed_output, fft.get()); + } + auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView(); + ExpectNearAbsolute({auto_corr_view.data(), auto_corr_view.size()}, + computed_output, 3e-3f); +} + +// Check that the auto correlation function computes the right thing for a +// simple use case. +TEST(RnnVadTest, ComputePitchAutoCorrelationConstantBuffer) { + // Create constant signal with no pitch. + std::array pitch_buf_decimated; + std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f); + + std::array computed_output; + { + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + std::unique_ptr fft = + RealFourier::Create(kAutoCorrelationFftOrder); + ComputePitchAutoCorrelation(pitch_buf_decimated, kMaxPitch12kHz, + computed_output, fft.get()); + } + + // The expected output is constantly the length of the fixed 'x' + // array in ComputePitchAutoCorrelation. + std::array expected_output; + std::fill(expected_output.begin(), expected_output.end(), + kBufSize12kHz - kMaxPitch12kHz); + ExpectNearAbsolute(expected_output, computed_output, 4e-5f); +} + TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) { - PitchTestData test_data; + TestData test_data; std::array pitch_buf_decimated; Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated); std::array pitch_candidates_inv_lags; @@ -109,7 +181,7 @@ TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) { } TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) { - PitchTestData test_data; + TestData test_data; std::array pitch_buf_decimated; Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated); size_t pitch_inv_lag; @@ -135,7 +207,7 @@ TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) { const float prev_pitch_gain = std::get<2>(params); const int expected_pitch_period = std::get<3>(params); const float expected_pitch_gain = std::get<4>(params); - PitchTestData test_data; + TestData test_data; { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc index 4dae8cdb3a..8decbd0ebb 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc @@ -111,28 +111,6 @@ ReaderPairType CreateVadProbsReader() { return {std::move(ptr), ptr->data_length()}; } -PitchTestData::PitchTestData() { - auto test_data_reader = CreatePitchSearchTestDataReader(); - test_data_reader->ReadChunk(test_data_); -} - -PitchTestData::~PitchTestData() = default; - -rtc::ArrayView PitchTestData::GetPitchBufView() { - return {test_data_.data(), kBufSize24kHz}; -} - -rtc::ArrayView -PitchTestData::GetPitchBufSquareEnergiesView() { - return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies}; -} - -rtc::ArrayView -PitchTestData::GetPitchBufAutoCorrCoeffsView() { - return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies, - kNumPitchBufAutoCorrCoeffs}; -} - } // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h index f9d7376d43..15be85ac00 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.h +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h @@ -12,7 +12,6 @@ #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_ #include -#include #include #include #include @@ -21,7 +20,6 @@ #include #include "api/array_view.h" -#include "modules/audio_processing/agc2/rnn_vad/common.h" #include "rtc_base/checks.h" namespace webrtc { @@ -120,27 +118,6 @@ CreateSilenceFlagsFeatureMatrixReader(); std::pair>, const size_t> CreateVadProbsReader(); -constexpr size_t kNumPitchBufAutoCorrCoeffs = 147; -constexpr size_t kNumPitchBufSquareEnergies = 385; -constexpr size_t kPitchTestDataSize = - kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs; - -// Class to retrieve a test pitch buffer content and the expected output for the -// analysis steps. -class PitchTestData { - public: - PitchTestData(); - ~PitchTestData(); - rtc::ArrayView GetPitchBufView(); - rtc::ArrayView - GetPitchBufSquareEnergiesView(); - rtc::ArrayView - GetPitchBufAutoCorrCoeffsView(); - - private: - std::array test_data_; -}; - } // namespace test } // namespace rnn_vad } // namespace webrtc