diff --git a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc b/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc index ef3748d7cf..76001ed7b7 100644 --- a/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/auto_correlation_unittest.cc @@ -17,15 +17,15 @@ namespace webrtc { namespace rnn_vad { -namespace test { +namespace { // Checks that the auto correlation function produces output within tolerance // given test input data. TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) { PitchTestData test_data; std::array pitch_buf_decimated; - Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated); - std::array computed_output; + Decimate2x(test_data.PitchBuffer24kHzView(), pitch_buf_decimated); + std::array computed_output; { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; @@ -33,7 +33,7 @@ TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) { auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated, computed_output); } - auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView(); + auto auto_corr_view = test_data.AutoCorrelation12kHzView(); ExpectNearAbsolute({auto_corr_view.data(), auto_corr_view.size()}, computed_output, 3e-3f); } @@ -44,7 +44,7 @@ TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) { // Create constant signal with no pitch. std::array pitch_buf_decimated; std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f); - std::array computed_output; + std::array computed_output; { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; @@ -55,12 +55,12 @@ TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) { // The expected output is a vector filled with the same expected // auto-correlation value. The latter equals the length of a 20 ms frame. constexpr int kFrameSize20ms12kHz = kFrameSize20ms24kHz / 2; - std::array expected_output; + std::array expected_output; std::fill(expected_output.begin(), expected_output.end(), static_cast(kFrameSize20ms12kHz)); ExpectNearAbsolute(expected_output, computed_output, 4e-5f); } -} // namespace test +} // namespace } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc b/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc index 0da971e3da..98da39e38a 100644 --- a/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/features_extraction_unittest.cc @@ -14,7 +14,6 @@ #include #include "modules/audio_processing/agc2/cpu_features.h" -#include "modules/audio_processing/agc2/rnn_vad/test_utils.h" #include "rtc_base/numerics/safe_compare.h" #include "rtc_base/numerics/safe_conversions.h" // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. @@ -23,7 +22,6 @@ namespace webrtc { namespace rnn_vad { -namespace test { namespace { constexpr int ceil(int n, int m) { @@ -52,7 +50,7 @@ void CreatePureTone(float amplitude, float freq_hz, rtc::ArrayView dst) { // Feeds |features_extractor| with |samples| splitting it in 10 ms frames. // For every frame, the output is written into |feature_vector|. Returns true // if silence is detected in the last frame. -bool FeedTestData(FeaturesExtractor* features_extractor, +bool FeedTestData(FeaturesExtractor& features_extractor, rtc::ArrayView samples, rtc::ArrayView feature_vector) { // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. @@ -60,15 +58,13 @@ bool FeedTestData(FeaturesExtractor* features_extractor, bool is_silence = true; const int num_frames = samples.size() / kFrameSize10ms24kHz; for (int i = 0; i < num_frames; ++i) { - is_silence = features_extractor->CheckSilenceComputeFeatures( + is_silence = features_extractor.CheckSilenceComputeFeatures( {samples.data() + i * kFrameSize10ms24kHz, kFrameSize10ms24kHz}, feature_vector); } return is_silence; } -} // namespace - // Extracts the features for two pure tones and verifies that the pitch field // values reflect the known tone frequencies. TEST(RnnVadTest, FeatureExtractionLowHighPitch) { @@ -91,17 +87,17 @@ TEST(RnnVadTest, FeatureExtractionLowHighPitch) { constexpr int pitch_feature_index = kFeatureVectorSize - 2; // Low frequency tone - i.e., high period. CreatePureTone(amplitude, low_pitch_hz, samples); - ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view)); + ASSERT_FALSE(FeedTestData(features_extractor, samples, feature_vector_view)); float high_pitch_period = feature_vector_view[pitch_feature_index]; // High frequency tone - i.e., low period. features_extractor.Reset(); CreatePureTone(amplitude, high_pitch_hz, samples); - ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view)); + ASSERT_FALSE(FeedTestData(features_extractor, samples, feature_vector_view)); float low_pitch_period = feature_vector_view[pitch_feature_index]; // Check. EXPECT_LT(low_pitch_period, high_pitch_period); } -} // namespace test +} // namespace } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual.h b/modules/audio_processing/agc2/rnn_vad/lp_residual.h index 2e54dd93d8..380d9f608b 100644 --- a/modules/audio_processing/agc2/rnn_vad/lp_residual.h +++ b/modules/audio_processing/agc2/rnn_vad/lp_residual.h @@ -18,7 +18,7 @@ namespace webrtc { namespace rnn_vad { -// LPC inverse filter length. +// Linear predictive coding (LPC) inverse filter length. constexpr int kNumLpcCoefficients = 5; // Given a frame |x|, computes a post-processed version of LPC coefficients diff --git a/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc index 177977688e..7b3a4a3f65 100644 --- a/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/lp_residual_unittest.cc @@ -22,7 +22,7 @@ namespace webrtc { namespace rnn_vad { -namespace test { +namespace { // Checks that the LP residual can be computed on an empty frame. TEST(RnnVadTest, LpResidualOfEmptyFrame) { @@ -33,55 +33,48 @@ TEST(RnnVadTest, LpResidualOfEmptyFrame) { std::array empty_frame; empty_frame.fill(0.f); // Compute inverse filter coefficients. - std::array lpc_coeffs; - ComputeAndPostProcessLpcCoefficients(empty_frame, lpc_coeffs); + std::array lpc; + ComputeAndPostProcessLpcCoefficients(empty_frame, lpc); // Compute LP residual. std::array lp_residual; - ComputeLpResidual(lpc_coeffs, empty_frame, lp_residual); + ComputeLpResidual(lpc, empty_frame, lp_residual); } // Checks that the computed LP residual is bit-exact given test input data. TEST(RnnVadTest, LpResidualPipelineBitExactness) { // Input and expected output readers. - auto pitch_buf_24kHz_reader = CreatePitchBuffer24kHzReader(); - auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader(); + ChunksFileReader pitch_buffer_reader = CreatePitchBuffer24kHzReader(); + ChunksFileReader lp_pitch_reader = CreateLpResidualAndPitchInfoReader(); // Buffers. - std::vector pitch_buf_data(kBufSize24kHz); - std::array lpc_coeffs; + std::vector pitch_buffer_24kHz(kBufSize24kHz); + std::array lpc; std::vector computed_lp_residual(kBufSize24kHz); std::vector expected_lp_residual(kBufSize24kHz); // Test length. const int num_frames = - std::min(pitch_buf_24kHz_reader.second, 300); // Max 3 s. - ASSERT_GE(lp_residual_reader.second, num_frames); + std::min(pitch_buffer_reader.num_chunks, 300); // Max 3 s. + ASSERT_GE(lp_pitch_reader.num_chunks, num_frames); - { - // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. - // FloatingPointExceptionObserver fpe_observer; - for (int i = 0; i < num_frames; ++i) { - // Read input. - ASSERT_TRUE(pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data)); - // Read expected output (ignore pitch gain and period). - ASSERT_TRUE(lp_residual_reader.first->ReadChunk(expected_lp_residual)); - float unused; - ASSERT_TRUE(lp_residual_reader.first->ReadValue(&unused)); - ASSERT_TRUE(lp_residual_reader.first->ReadValue(&unused)); - - // Check every 200 ms. - if (i % 20 != 0) { - continue; - } - - SCOPED_TRACE(i); - ComputeAndPostProcessLpcCoefficients(pitch_buf_data, lpc_coeffs); - ComputeLpResidual(lpc_coeffs, pitch_buf_data, computed_lp_residual); + // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. + // FloatingPointExceptionObserver fpe_observer; + for (int i = 0; i < num_frames; ++i) { + SCOPED_TRACE(i); + // Read input. + ASSERT_TRUE(pitch_buffer_reader.reader->ReadChunk(pitch_buffer_24kHz)); + // Read expected output (ignore pitch gain and period). + ASSERT_TRUE(lp_pitch_reader.reader->ReadChunk(expected_lp_residual)); + lp_pitch_reader.reader->SeekForward(2); // Pitch period and strength. + // Check every 200 ms. + if (i % 20 == 0) { + ComputeAndPostProcessLpcCoefficients(pitch_buffer_24kHz, lpc); + ComputeLpResidual(lpc, pitch_buffer_24kHz, computed_lp_residual); ExpectNearAbsolute(expected_lp_residual, computed_lp_residual, kFloatMin); } } } -} // namespace test +} // namespace } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc index a4a4df12dc..8c336af90f 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_internal_unittest.cc @@ -22,7 +22,6 @@ namespace webrtc { namespace rnn_vad { -namespace test { namespace { constexpr int kTestPitchPeriodsLow = 3 * kMinPitch48kHz / 2; @@ -63,12 +62,12 @@ TEST(RnnVadTest, ComputeSlidingFrameSquareEnergies24kHzWithinTolerance) { const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures(); PitchTestData test_data; - std::array computed_output; + std::array computed_output; // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; - ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(), + ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(), computed_output, cpu_features); - auto square_energies_view = test_data.GetPitchBufSquareEnergiesView(); + auto square_energies_view = test_data.SquareEnergies24kHzView(); ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()}, computed_output, 1e-3f); } @@ -79,13 +78,12 @@ TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) { PitchTestData test_data; std::array pitch_buf_decimated; - Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated); + Decimate2x(test_data.PitchBuffer24kHzView(), pitch_buf_decimated); CandidatePitchPeriods pitch_candidates; // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; - auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView(); - pitch_candidates = ComputePitchPeriod12kHz(pitch_buf_decimated, - auto_corr_view, cpu_features); + pitch_candidates = ComputePitchPeriod12kHz( + pitch_buf_decimated, test_data.AutoCorrelation12kHzView(), cpu_features); EXPECT_EQ(pitch_candidates.best, 140); EXPECT_EQ(pitch_candidates.second_best, 142); } @@ -98,16 +96,16 @@ TEST(RnnVadTest, ComputePitchPeriod48kHzBitExactness) { std::vector y_energy(kRefineNumLags24kHz); rtc::ArrayView y_energy_view(y_energy.data(), kRefineNumLags24kHz); - ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(), + ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(), y_energy_view, cpu_features); // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; EXPECT_EQ( - ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view, + ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view, /*pitch_candidates=*/{280, 284}, cpu_features), 560); EXPECT_EQ( - ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view, + ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view, /*pitch_candidates=*/{260, 284}, cpu_features), 568); } @@ -132,12 +130,12 @@ TEST_P(PitchCandidatesParametrization, std::vector y_energy(kRefineNumLags24kHz); rtc::ArrayView y_energy_view(y_energy.data(), kRefineNumLags24kHz); - ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(), + ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(), y_energy_view, params.cpu_features); EXPECT_EQ( - ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view, + ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view, params.pitch_candidates, params.cpu_features), - ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view, + ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view, swapped_pitch_candidates, params.cpu_features)); } @@ -179,13 +177,13 @@ TEST_P(ExtendedPitchPeriodSearchParametrizaion, std::vector y_energy(kRefineNumLags24kHz); rtc::ArrayView y_energy_view(y_energy.data(), kRefineNumLags24kHz); - ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(), + ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(), y_energy_view, params.cpu_features); // TODO(bugs.webrtc.org/8948): Add when the issue is fixed. // FloatingPointExceptionObserver fpe_observer; const auto computed_output = ComputeExtendedPitchPeriod48kHz( - test_data.GetPitchBufView(), y_energy_view, params.initial_pitch_period, - params.last_pitch, params.cpu_features); + test_data.PitchBuffer24kHzView(), y_energy_view, + params.initial_pitch_period, params.last_pitch, params.cpu_features); EXPECT_EQ(params.expected_pitch.period, computed_output.period); EXPECT_NEAR(params.expected_pitch.strength, computed_output.strength, 1e-6f); } @@ -219,6 +217,5 @@ INSTANTIATE_TEST_SUITE_P( PrintTestIndexAndCpuFeatures); } // namespace -} // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc index fe9be5dbba..79b44b995c 100644 --- a/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/pitch_search_unittest.cc @@ -26,8 +26,8 @@ namespace rnn_vad { // Checks that the computed pitch period is bit-exact and that the computed // pitch gain is within tolerance given test input data. TEST(RnnVadTest, PitchSearchWithinTolerance) { - auto lp_residual_reader = test::CreateLpResidualAndPitchPeriodGainReader(); - const int num_frames = std::min(lp_residual_reader.second, 300); // Max 3 s. + ChunksFileReader reader = CreateLpResidualAndPitchInfoReader(); + const int num_frames = std::min(reader.num_chunks, 300); // Max 3 s. std::vector lp_residual(kBufSize24kHz); float expected_pitch_period, expected_pitch_strength; const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures(); @@ -37,9 +37,9 @@ TEST(RnnVadTest, PitchSearchWithinTolerance) { // FloatingPointExceptionObserver fpe_observer; for (int i = 0; i < num_frames; ++i) { SCOPED_TRACE(i); - lp_residual_reader.first->ReadChunk(lp_residual); - lp_residual_reader.first->ReadValue(&expected_pitch_period); - lp_residual_reader.first->ReadValue(&expected_pitch_strength); + ASSERT_TRUE(reader.reader->ReadChunk(lp_residual)); + ASSERT_TRUE(reader.reader->ReadValue(expected_pitch_period)); + ASSERT_TRUE(reader.reader->ReadValue(expected_pitch_strength)); int pitch_period = pitch_estimator.Estimate({lp_residual.data(), kBufSize24kHz}); EXPECT_EQ(expected_pitch_period, pitch_period); diff --git a/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc index 8b061a968f..d11d4eac3e 100644 --- a/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/ring_buffer_unittest.cc @@ -14,7 +14,6 @@ namespace webrtc { namespace rnn_vad { -namespace test { namespace { // Compare the elements of two given array views. @@ -64,8 +63,6 @@ void TestRingBuffer() { } } -} // namespace - // Check that for different delays, different views are returned. TEST(RnnVadTest, RingBufferArrayViews) { constexpr int s = 3; @@ -110,6 +107,6 @@ TEST(RnnVadTest, RingBufferFloating) { TestRingBuffer(); } -} // namespace test +} // namespace } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc index 1094832df8..c586ed291f 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_fc_unittest.cc @@ -24,7 +24,6 @@ namespace webrtc { namespace rnn_vad { -namespace test { namespace { using ::rnnoise::kInputDenseBias; @@ -104,6 +103,5 @@ INSTANTIATE_TEST_SUITE_P( }); } // namespace -} // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc index 54e1cf538a..4e8b524d6f 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_gru_unittest.cc @@ -21,7 +21,6 @@ namespace webrtc { namespace rnn_vad { -namespace test { namespace { void TestGatedRecurrentLayer( @@ -135,6 +134,5 @@ TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) { } } // namespace -} // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc index 1c314d17ce..4c5409a14e 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc @@ -17,7 +17,6 @@ namespace webrtc { namespace rnn_vad { -namespace test { namespace { constexpr std::array kFeatures = { @@ -67,6 +66,5 @@ TEST(RnnVadTest, CheckRnnVadSilence) { } } // namespace -} // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc index 81553b4789..7eb699c39f 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_vad_unittest.cc @@ -9,6 +9,7 @@ */ #include +#include #include #include @@ -26,7 +27,6 @@ namespace webrtc { namespace rnn_vad { -namespace test { namespace { constexpr int kFrameSize10ms48kHz = 480; @@ -49,8 +49,6 @@ void DumpPerfStats(int num_samples, // constant below to true in order to write new expected output binary files. constexpr bool kWriteComputedOutputToFile = false; -} // namespace - // Avoids that one forgets to set |kWriteComputedOutputToFile| back to false // when the expected output files are re-exported. TEST(RnnVadTest, CheckWriteComputedOutputIsFalse) { @@ -71,12 +69,11 @@ TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) { RnnVad rnn_vad(cpu_features); // Init input samples and expected output readers. - auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz); - auto expected_vad_prob_reader = CreateVadProbsReader(); + std::unique_ptr samples_reader = CreatePcmSamplesReader(); + std::unique_ptr expected_vad_prob_reader = CreateVadProbsReader(); - // Input length. - const int num_frames = samples_reader.second; - ASSERT_GE(expected_vad_prob_reader.second, num_frames); + // Input length. The last incomplete frame is ignored. + const int num_frames = samples_reader->size() / kFrameSize10ms48kHz; // Init buffers. std::vector samples_48k(kFrameSize10ms48kHz); @@ -86,12 +83,12 @@ TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) { std::vector expected_vad_prob(num_frames); // Read expected output. - ASSERT_TRUE(expected_vad_prob_reader.first->ReadChunk(expected_vad_prob)); + ASSERT_TRUE(expected_vad_prob_reader->ReadChunk(expected_vad_prob)); // Compute VAD probabilities on the downsampled input. float cumulative_error = 0.f; for (int i = 0; i < num_frames; ++i) { - samples_reader.first->ReadChunk(samples_48k); + ASSERT_TRUE(samples_reader->ReadChunk(samples_48k)); decimator.Resample(samples_48k.data(), samples_48k.size(), samples_24k.data(), samples_24k.size()); bool is_silence = features_extractor.CheckSilenceComputeFeatures( @@ -106,7 +103,7 @@ TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) { EXPECT_LT(cumulative_error / num_frames, 1e-4f); if (kWriteComputedOutputToFile) { - BinaryFileWriter vad_prob_writer("new_vad_prob.dat"); + FileWriter vad_prob_writer("new_vad_prob.dat"); vad_prob_writer.WriteChunk(computed_vad_prob); } } @@ -118,15 +115,16 @@ TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) { // - on android: run the this unit test adding "--logcat-output-file". TEST_P(RnnVadProbabilityParametrization, DISABLED_RnnVadPerformance) { // PCM samples reader and buffers. - auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz); - const int num_frames = samples_reader.second; + std::unique_ptr samples_reader = CreatePcmSamplesReader(); + // The last incomplete frame is ignored. + const int num_frames = samples_reader->size() / kFrameSize10ms48kHz; std::array samples; // Pre-fetch and decimate samples. PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz); std::vector prefetched_decimated_samples; prefetched_decimated_samples.resize(num_frames * kFrameSize10ms24kHz); for (int i = 0; i < num_frames; ++i) { - samples_reader.first->ReadChunk(samples); + ASSERT_TRUE(samples_reader->ReadChunk(samples)); decimator.Resample(samples.data(), samples.size(), &prefetched_decimated_samples[i * kFrameSize10ms24kHz], kFrameSize10ms24kHz); @@ -151,7 +149,6 @@ TEST_P(RnnVadProbabilityParametrization, DISABLED_RnnVadPerformance) { rnn_vad.ComputeVadProbability(feature_vector, is_silence); } perf_timer.StopTimer(); - samples_reader.first->SeekBeginning(); } DumpPerfStats(num_frames * kFrameSize10ms24kHz, kSampleRate24kHz, perf_timer.GetDurationAverage(), @@ -180,6 +177,6 @@ INSTANTIATE_TEST_SUITE_P( return info.param.ToString(); }); -} // namespace test +} // namespace } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc index 125f1b821c..f577571b09 100644 --- a/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/sequence_buffer_unittest.cc @@ -17,7 +17,6 @@ namespace webrtc { namespace rnn_vad { -namespace test { namespace { template @@ -60,8 +59,6 @@ void TestSequenceBufferPushOp() { } } -} // namespace - TEST(RnnVadTest, SequenceBufferGetters) { constexpr int buffer_size = 8; constexpr int chunk_size = 8; @@ -100,6 +97,6 @@ TEST(RnnVadTest, SequenceBufferPushOpsFloating) { TestSequenceBufferPushOp(); // Non-integer ratio. } -} // namespace test +} // namespace } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc b/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc index 461047d004..11a44a57da 100644 --- a/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/spectral_features_internal_unittest.cc @@ -26,7 +26,6 @@ namespace webrtc { namespace rnn_vad { -namespace test { namespace { // Generates the values for the array named |kOpusBandWeights24kHz20ms| in the @@ -49,8 +48,6 @@ std::vector ComputeTriangularFiltersWeights() { return weights; } -} // namespace - // Checks that the values returned by GetOpusScaleNumBins24kHz20ms() match the // Opus scale frequency boundaries. TEST(RnnVadTest, TestOpusScaleBoundaries) { @@ -158,6 +155,6 @@ TEST(RnnVadTest, ComputeDctWithinTolerance) { } } -} // namespace test +} // namespace } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc b/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc index fa376f2a0a..9f41e96e5e 100644 --- a/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/spectral_features_unittest.cc @@ -21,7 +21,6 @@ namespace webrtc { namespace rnn_vad { -namespace test { namespace { constexpr int kTestFeatureVectorSize = kNumBands + 3 * kNumLowerBands + 1; @@ -66,8 +65,6 @@ float* GetCepstralVariability( constexpr float kInitialFeatureVal = -9999.f; -} // namespace - // Checks that silence is detected when the input signal is 0 and that the // feature vector is written only if the input signal is not tagged as silence. TEST(RnnVadTest, SpectralFeaturesWithAndWithoutSilence) { @@ -159,6 +156,6 @@ TEST(RnnVadTest, CepstralFeaturesConstantAverageZeroDerivative) { feature_vector_last[kNumBands + 3 * kNumLowerBands]); } -} // namespace test +} // namespace } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc b/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc index c1da8d181b..6f61c87104 100644 --- a/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer_unittest.cc @@ -15,7 +15,6 @@ namespace webrtc { namespace rnn_vad { -namespace test { namespace { template @@ -44,8 +43,6 @@ bool CheckPairsWithValueExist( return false; } -} // namespace - // Test that shows how to combine RingBuffer and SymmetricMatrixBuffer to // efficiently compute pair-wise scores. This test verifies that the evolution // of a SymmetricMatrixBuffer instance follows that of RingBuffer. @@ -105,6 +102,6 @@ TEST(RnnVadTest, SymmetricMatrixBufferUseCase) { } } -} // namespace test +} // namespace } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc index 75de1099f2..3db6774450 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc @@ -11,7 +11,10 @@ #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" #include +#include #include +#include +#include #include "rtc_base/checks.h" #include "rtc_base/numerics/safe_compare.h" @@ -20,11 +23,46 @@ namespace webrtc { namespace rnn_vad { -namespace test { namespace { -using ReaderPairType = - std::pair>, const int>; +// File reader for binary files that contain a sequence of values with +// arithmetic type `T`. The values of type `T` that are read are cast to float. +template +class FloatFileReader : public FileReader { + public: + static_assert(std::is_arithmetic::value, ""); + FloatFileReader(const std::string& filename) + : is_(filename, std::ios::binary | std::ios::ate), + size_(is_.tellg() / sizeof(T)) { + RTC_CHECK(is_); + SeekBeginning(); + } + FloatFileReader(const FloatFileReader&) = delete; + FloatFileReader& operator=(const FloatFileReader&) = delete; + ~FloatFileReader() = default; + + int size() const override { return size_; } + bool ReadChunk(rtc::ArrayView dst) override { + const std::streamsize bytes_to_read = dst.size() * sizeof(T); + if (std::is_same::value) { + is_.read(reinterpret_cast(dst.data()), bytes_to_read); + } else { + buffer_.resize(dst.size()); + is_.read(reinterpret_cast(buffer_.data()), bytes_to_read); + std::transform(buffer_.begin(), buffer_.end(), dst.begin(), + [](const T& v) -> float { return static_cast(v); }); + } + return is_.gcount() == bytes_to_read; + } + bool ReadValue(float& dst) override { return ReadChunk({&dst, 1}); } + void SeekForward(int hop) override { is_.seekg(hop * sizeof(T), is_.cur); } + void SeekBeginning() override { is_.seekg(0, is_.beg); } + + private: + std::ifstream is_; + const int size_; + std::vector buffer_; +}; } // namespace @@ -49,66 +87,49 @@ void ExpectNearAbsolute(rtc::ArrayView expected, } } -std::pair>, const int> -CreatePcmSamplesReader(const int frame_length) { - auto ptr = std::make_unique>( - test::ResourcePath("audio_processing/agc2/rnn_vad/samples", "pcm"), - frame_length); - // The last incomplete frame is ignored. - return {std::move(ptr), ptr->data_length() / frame_length}; +std::unique_ptr CreatePcmSamplesReader() { + return std::make_unique>( + /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/samples", + "pcm")); } -ReaderPairType CreatePitchBuffer24kHzReader() { - constexpr int cols = 864; - auto ptr = std::make_unique>( - ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), cols); - return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), cols)}; +ChunksFileReader CreatePitchBuffer24kHzReader() { + auto reader = std::make_unique>( + /*filename=*/test::ResourcePath( + "audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat")); + const int num_chunks = rtc::CheckedDivExact(reader->size(), kBufSize24kHz); + return {/*chunk_size=*/kBufSize24kHz, num_chunks, std::move(reader)}; } -ReaderPairType CreateLpResidualAndPitchPeriodGainReader() { - constexpr int num_lp_residual_coeffs = 864; - auto ptr = std::make_unique>( - ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"), - num_lp_residual_coeffs); - return {std::move(ptr), - rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs)}; +ChunksFileReader CreateLpResidualAndPitchInfoReader() { + constexpr int kPitchInfoSize = 2; // Pitch period and strength. + constexpr int kChunkSize = kBufSize24kHz + kPitchInfoSize; + auto reader = std::make_unique>( + /*filename=*/test::ResourcePath( + "audio_processing/agc2/rnn_vad/pitch_lp_res", "dat")); + const int num_chunks = rtc::CheckedDivExact(reader->size(), kChunkSize); + return {kChunkSize, num_chunks, std::move(reader)}; } -ReaderPairType CreateVadProbsReader() { - auto ptr = std::make_unique>( - test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob", "dat")); - return {std::move(ptr), ptr->data_length()}; +std::unique_ptr CreateVadProbsReader() { + return std::make_unique>( + /*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob", + "dat")); } PitchTestData::PitchTestData() { - BinaryFileReader test_data_reader( - ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"), - 1396); - test_data_reader.ReadChunk(test_data_); + FloatFileReader reader( + /*filename=*/ResourcePath( + "audio_processing/agc2/rnn_vad/pitch_search_int", "dat")); + reader.ReadChunk(pitch_buffer_24k_); + reader.ReadChunk(square_energies_24k_); + reader.ReadChunk(auto_correlation_12k_); // Reverse the order of the squared energy values. // Required after the WebRTC CL 191703 which switched to forward computation. - std::reverse(test_data_.begin() + kBufSize24kHz, - test_data_.begin() + kBufSize24kHz + kNumPitchBufSquareEnergies); + std::reverse(square_energies_24k_.begin(), square_energies_24k_.end()); } PitchTestData::~PitchTestData() = default; -rtc::ArrayView PitchTestData::GetPitchBufView() - const { - return {test_data_.data(), kBufSize24kHz}; -} - -rtc::ArrayView -PitchTestData::GetPitchBufSquareEnergiesView() const { - return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies}; -} - -rtc::ArrayView -PitchTestData::GetPitchBufAutoCorrCoeffsView() const { - return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies, - kNumPitchBufAutoCorrCoeffs}; -} - -} // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h index 3d1ad259db..86af5e0076 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.h +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h @@ -11,15 +11,10 @@ #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_ #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_ -#include #include #include -#include #include #include -#include -#include -#include #include "api/array_view.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" @@ -28,7 +23,6 @@ namespace webrtc { namespace rnn_vad { -namespace test { constexpr float kFloatMin = std::numeric_limits::min(); @@ -43,98 +37,48 @@ void ExpectNearAbsolute(rtc::ArrayView expected, rtc::ArrayView computed, float tolerance); -// Reader for binary files consisting of an arbitrary long sequence of elements -// having type T. It is possible to read and cast to another type D at once. -template -class BinaryFileReader { +// File reader interface. +class FileReader { public: - BinaryFileReader(const std::string& file_path, int chunk_size = 0) - : is_(file_path, std::ios::binary | std::ios::ate), - data_length_(is_.tellg() / sizeof(T)), - chunk_size_(chunk_size) { - RTC_CHECK(is_); - SeekBeginning(); - buf_.resize(chunk_size_); - } - BinaryFileReader(const BinaryFileReader&) = delete; - BinaryFileReader& operator=(const BinaryFileReader&) = delete; - ~BinaryFileReader() = default; - int data_length() const { return data_length_; } - bool ReadValue(D* dst) { - if (std::is_same::value) { - is_.read(reinterpret_cast(dst), sizeof(T)); - } else { - T v; - is_.read(reinterpret_cast(&v), sizeof(T)); - *dst = static_cast(v); - } - return is_.gcount() == sizeof(T); - } - // If |chunk_size| was specified in the ctor, it will check that the size of - // |dst| equals |chunk_size|. - bool ReadChunk(rtc::ArrayView dst) { - RTC_DCHECK((chunk_size_ == 0) || rtc::SafeEq(chunk_size_, dst.size())); - const std::streamsize bytes_to_read = dst.size() * sizeof(T); - if (std::is_same::value) { - is_.read(reinterpret_cast(dst.data()), bytes_to_read); - } else { - is_.read(reinterpret_cast(buf_.data()), bytes_to_read); - std::transform(buf_.begin(), buf_.end(), dst.begin(), - [](const T& v) -> D { return static_cast(v); }); - } - return is_.gcount() == bytes_to_read; - } - void SeekForward(int items) { is_.seekg(items * sizeof(T), is_.cur); } - void SeekBeginning() { is_.seekg(0, is_.beg); } - - private: - std::ifstream is_; - const int data_length_; - const int chunk_size_; - std::vector buf_; + virtual ~FileReader() = default; + // Number of values in the file. + virtual int size() const = 0; + // Reads `dst.size()` float values into `dst`, advances the internal file + // position according to the number of read bytes and returns true if the + // values are correctly read. If the number of remaining bytes in the file is + // not sufficient to read `dst.size()` float values, `dst` is partially + // modified and false is returned. + virtual bool ReadChunk(rtc::ArrayView dst) = 0; + // Reads a single float value, advances the internal file position according + // to the number of read bytes and returns true if the value is correctly + // read. If the number of remaining bytes in the file is not sufficient to + // read one float, `dst` is not modified and false is returned. + virtual bool ReadValue(float& dst) = 0; + // Advances the internal file position by `hop` float values. + virtual void SeekForward(int hop) = 0; + // Resets the internal file position to BOF. + virtual void SeekBeginning() = 0; }; -// Writer for binary files. -template -class BinaryFileWriter { - public: - explicit BinaryFileWriter(const std::string& file_path) - : os_(file_path, std::ios::binary) {} - BinaryFileWriter(const BinaryFileWriter&) = delete; - BinaryFileWriter& operator=(const BinaryFileWriter&) = delete; - ~BinaryFileWriter() = default; - static_assert(std::is_arithmetic::value, ""); - void WriteChunk(rtc::ArrayView value) { - const std::streamsize bytes_to_write = value.size() * sizeof(T); - os_.write(reinterpret_cast(value.data()), bytes_to_write); - } - - private: - std::ofstream os_; +// File reader for files that contain `num_chunks` chunks with size equal to +// `chunk_size`. +struct ChunksFileReader { + const int chunk_size; + const int num_chunks; + std::unique_ptr reader; }; -// Factories for resource file readers. -// The functions below return a pair where the first item is a reader unique -// pointer and the second the number of chunks that can be read from the file. -// Creates a reader for the PCM samples that casts from S16 to float and reads -// chunks with length |frame_length|. -std::pair>, const int> -CreatePcmSamplesReader(const int frame_length); -// Creates a reader for the pitch buffer content at 24 kHz. -std::pair>, const int> -CreatePitchBuffer24kHzReader(); -// Creates a reader for the the LP residual coefficients and the pitch period -// and gain values. -std::pair>, const int> -CreateLpResidualAndPitchPeriodGainReader(); -// Creates a reader for the VAD probabilities. -std::pair>, const int> -CreateVadProbsReader(); +// Creates a reader for the PCM S16 samples file. +std::unique_ptr CreatePcmSamplesReader(); -constexpr int kNumPitchBufAutoCorrCoeffs = 147; -constexpr int kNumPitchBufSquareEnergies = 385; -constexpr int kPitchTestDataSize = - kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs; +// Creates a reader for the 24 kHz pitch buffer test data. +ChunksFileReader CreatePitchBuffer24kHzReader(); + +// Creates a reader for the LP residual and pitch information test data. +ChunksFileReader CreateLpResidualAndPitchInfoReader(); + +// Creates a reader for the VAD probabilities test data. +std::unique_ptr CreateVadProbsReader(); // Class to retrieve a test pitch buffer content and the expected output for the // analysis steps. @@ -142,17 +86,40 @@ class PitchTestData { public: PitchTestData(); ~PitchTestData(); - rtc::ArrayView GetPitchBufView() const; - rtc::ArrayView - GetPitchBufSquareEnergiesView() const; - rtc::ArrayView - GetPitchBufAutoCorrCoeffsView() const; + rtc::ArrayView PitchBuffer24kHzView() const { + return pitch_buffer_24k_; + } + rtc::ArrayView SquareEnergies24kHzView() + const { + return square_energies_24k_; + } + rtc::ArrayView AutoCorrelation12kHzView() const { + return auto_correlation_12k_; + } private: - std::array test_data_; + std::array pitch_buffer_24k_; + std::array square_energies_24k_; + std::array auto_correlation_12k_; +}; + +// Writer for binary files. +class FileWriter { + public: + explicit FileWriter(const std::string& file_path) + : os_(file_path, std::ios::binary) {} + FileWriter(const FileWriter&) = delete; + FileWriter& operator=(const FileWriter&) = delete; + ~FileWriter() = default; + void WriteChunk(rtc::ArrayView value) { + const std::streamsize bytes_to_write = value.size() * sizeof(float); + os_.write(reinterpret_cast(value.data()), bytes_to_write); + } + + private: + std::ofstream os_; }; -} // namespace test } // namespace rnn_vad } // namespace webrtc