RNN VAD: unit test code clean-up

- test_utils.h/.cc simplified
- webrtc::rnn_vad::test -> webrtc::rnn_vad
- all unit test code inside the anonymous namespace
- names improved

Bug: webrtc:10480
Change-Id: I0a0f056f9728bb8a1b93006b95d7ed5bf5bd4adb
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/196509
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Sam Zackrisson <saza@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32789}
This commit is contained in:
Alessio Bazzica
2020-12-07 17:02:22 +01:00
committed by Commit Bot
parent bc7e5ac1c9
commit bb1a28de3c
17 changed files with 211 additions and 261 deletions

View File

@ -17,15 +17,15 @@
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
// Checks that the auto correlation function produces output within tolerance
// given test input data.
TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) {
PitchTestData test_data;
std::array<float, kBufSize12kHz> pitch_buf_decimated;
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
std::array<float, kNumPitchBufAutoCorrCoeffs> computed_output;
Decimate2x(test_data.PitchBuffer24kHzView(), pitch_buf_decimated);
std::array<float, kNumLags12kHz> computed_output;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
@ -33,7 +33,7 @@ TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) {
auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated,
computed_output);
}
auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
auto auto_corr_view = test_data.AutoCorrelation12kHzView();
ExpectNearAbsolute({auto_corr_view.data(), auto_corr_view.size()},
computed_output, 3e-3f);
}
@ -44,7 +44,7 @@ TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) {
// Create constant signal with no pitch.
std::array<float, kBufSize12kHz> pitch_buf_decimated;
std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f);
std::array<float, kNumPitchBufAutoCorrCoeffs> computed_output;
std::array<float, kNumLags12kHz> computed_output;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
@ -55,12 +55,12 @@ TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) {
// The expected output is a vector filled with the same expected
// auto-correlation value. The latter equals the length of a 20 ms frame.
constexpr int kFrameSize20ms12kHz = kFrameSize20ms24kHz / 2;
std::array<float, kNumPitchBufAutoCorrCoeffs> expected_output;
std::array<float, kNumLags12kHz> expected_output;
std::fill(expected_output.begin(), expected_output.end(),
static_cast<float>(kFrameSize20ms12kHz));
ExpectNearAbsolute(expected_output, computed_output, 4e-5f);
}
} // namespace test
} // namespace
} // namespace rnn_vad
} // namespace webrtc

View File

@ -14,7 +14,6 @@
#include <vector>
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include "rtc_base/numerics/safe_compare.h"
#include "rtc_base/numerics/safe_conversions.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
@ -23,7 +22,6 @@
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
constexpr int ceil(int n, int m) {
@ -52,7 +50,7 @@ void CreatePureTone(float amplitude, float freq_hz, rtc::ArrayView<float> dst) {
// Feeds |features_extractor| with |samples| splitting it in 10 ms frames.
// For every frame, the output is written into |feature_vector|. Returns true
// if silence is detected in the last frame.
bool FeedTestData(FeaturesExtractor* features_extractor,
bool FeedTestData(FeaturesExtractor& features_extractor,
rtc::ArrayView<const float> samples,
rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
@ -60,15 +58,13 @@ bool FeedTestData(FeaturesExtractor* features_extractor,
bool is_silence = true;
const int num_frames = samples.size() / kFrameSize10ms24kHz;
for (int i = 0; i < num_frames; ++i) {
is_silence = features_extractor->CheckSilenceComputeFeatures(
is_silence = features_extractor.CheckSilenceComputeFeatures(
{samples.data() + i * kFrameSize10ms24kHz, kFrameSize10ms24kHz},
feature_vector);
}
return is_silence;
}
} // namespace
// Extracts the features for two pure tones and verifies that the pitch field
// values reflect the known tone frequencies.
TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
@ -91,17 +87,17 @@ TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
constexpr int pitch_feature_index = kFeatureVectorSize - 2;
// Low frequency tone - i.e., high period.
CreatePureTone(amplitude, low_pitch_hz, samples);
ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view));
ASSERT_FALSE(FeedTestData(features_extractor, samples, feature_vector_view));
float high_pitch_period = feature_vector_view[pitch_feature_index];
// High frequency tone - i.e., low period.
features_extractor.Reset();
CreatePureTone(amplitude, high_pitch_hz, samples);
ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view));
ASSERT_FALSE(FeedTestData(features_extractor, samples, feature_vector_view));
float low_pitch_period = feature_vector_view[pitch_feature_index];
// Check.
EXPECT_LT(low_pitch_period, high_pitch_period);
}
} // namespace test
} // namespace
} // namespace rnn_vad
} // namespace webrtc

View File

@ -18,7 +18,7 @@
namespace webrtc {
namespace rnn_vad {
// LPC inverse filter length.
// Linear predictive coding (LPC) inverse filter length.
constexpr int kNumLpcCoefficients = 5;
// Given a frame |x|, computes a post-processed version of LPC coefficients

View File

@ -22,7 +22,7 @@
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
// Checks that the LP residual can be computed on an empty frame.
TEST(RnnVadTest, LpResidualOfEmptyFrame) {
@ -33,55 +33,48 @@ TEST(RnnVadTest, LpResidualOfEmptyFrame) {
std::array<float, kFrameSize10ms24kHz> empty_frame;
empty_frame.fill(0.f);
// Compute inverse filter coefficients.
std::array<float, kNumLpcCoefficients> lpc_coeffs;
ComputeAndPostProcessLpcCoefficients(empty_frame, lpc_coeffs);
std::array<float, kNumLpcCoefficients> lpc;
ComputeAndPostProcessLpcCoefficients(empty_frame, lpc);
// Compute LP residual.
std::array<float, kFrameSize10ms24kHz> lp_residual;
ComputeLpResidual(lpc_coeffs, empty_frame, lp_residual);
ComputeLpResidual(lpc, empty_frame, lp_residual);
}
// Checks that the computed LP residual is bit-exact given test input data.
TEST(RnnVadTest, LpResidualPipelineBitExactness) {
// Input and expected output readers.
auto pitch_buf_24kHz_reader = CreatePitchBuffer24kHzReader();
auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader();
ChunksFileReader pitch_buffer_reader = CreatePitchBuffer24kHzReader();
ChunksFileReader lp_pitch_reader = CreateLpResidualAndPitchInfoReader();
// Buffers.
std::vector<float> pitch_buf_data(kBufSize24kHz);
std::array<float, kNumLpcCoefficients> lpc_coeffs;
std::vector<float> pitch_buffer_24kHz(kBufSize24kHz);
std::array<float, kNumLpcCoefficients> lpc;
std::vector<float> computed_lp_residual(kBufSize24kHz);
std::vector<float> expected_lp_residual(kBufSize24kHz);
// Test length.
const int num_frames =
std::min(pitch_buf_24kHz_reader.second, 300); // Max 3 s.
ASSERT_GE(lp_residual_reader.second, num_frames);
std::min(pitch_buffer_reader.num_chunks, 300); // Max 3 s.
ASSERT_GE(lp_pitch_reader.num_chunks, num_frames);
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
for (int i = 0; i < num_frames; ++i) {
// Read input.
ASSERT_TRUE(pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data));
// Read expected output (ignore pitch gain and period).
ASSERT_TRUE(lp_residual_reader.first->ReadChunk(expected_lp_residual));
float unused;
ASSERT_TRUE(lp_residual_reader.first->ReadValue(&unused));
ASSERT_TRUE(lp_residual_reader.first->ReadValue(&unused));
// Check every 200 ms.
if (i % 20 != 0) {
continue;
}
SCOPED_TRACE(i);
ComputeAndPostProcessLpcCoefficients(pitch_buf_data, lpc_coeffs);
ComputeLpResidual(lpc_coeffs, pitch_buf_data, computed_lp_residual);
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
for (int i = 0; i < num_frames; ++i) {
SCOPED_TRACE(i);
// Read input.
ASSERT_TRUE(pitch_buffer_reader.reader->ReadChunk(pitch_buffer_24kHz));
// Read expected output (ignore pitch gain and period).
ASSERT_TRUE(lp_pitch_reader.reader->ReadChunk(expected_lp_residual));
lp_pitch_reader.reader->SeekForward(2); // Pitch period and strength.
// Check every 200 ms.
if (i % 20 == 0) {
ComputeAndPostProcessLpcCoefficients(pitch_buffer_24kHz, lpc);
ComputeLpResidual(lpc, pitch_buffer_24kHz, computed_lp_residual);
ExpectNearAbsolute(expected_lp_residual, computed_lp_residual, kFloatMin);
}
}
}
} // namespace test
} // namespace
} // namespace rnn_vad
} // namespace webrtc

View File

@ -22,7 +22,6 @@
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
constexpr int kTestPitchPeriodsLow = 3 * kMinPitch48kHz / 2;
@ -63,12 +62,12 @@ TEST(RnnVadTest, ComputeSlidingFrameSquareEnergies24kHzWithinTolerance) {
const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
PitchTestData test_data;
std::array<float, kNumPitchBufSquareEnergies> computed_output;
std::array<float, kRefineNumLags24kHz> computed_output;
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
computed_output, cpu_features);
auto square_energies_view = test_data.GetPitchBufSquareEnergiesView();
auto square_energies_view = test_data.SquareEnergies24kHzView();
ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()},
computed_output, 1e-3f);
}
@ -79,13 +78,12 @@ TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) {
PitchTestData test_data;
std::array<float, kBufSize12kHz> pitch_buf_decimated;
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
Decimate2x(test_data.PitchBuffer24kHzView(), pitch_buf_decimated);
CandidatePitchPeriods pitch_candidates;
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
pitch_candidates = ComputePitchPeriod12kHz(pitch_buf_decimated,
auto_corr_view, cpu_features);
pitch_candidates = ComputePitchPeriod12kHz(
pitch_buf_decimated, test_data.AutoCorrelation12kHzView(), cpu_features);
EXPECT_EQ(pitch_candidates.best, 140);
EXPECT_EQ(pitch_candidates.second_best, 142);
}
@ -98,16 +96,16 @@ TEST(RnnVadTest, ComputePitchPeriod48kHzBitExactness) {
std::vector<float> y_energy(kRefineNumLags24kHz);
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
kRefineNumLags24kHz);
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
y_energy_view, cpu_features);
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
EXPECT_EQ(
ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
/*pitch_candidates=*/{280, 284}, cpu_features),
560);
EXPECT_EQ(
ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
/*pitch_candidates=*/{260, 284}, cpu_features),
568);
}
@ -132,12 +130,12 @@ TEST_P(PitchCandidatesParametrization,
std::vector<float> y_energy(kRefineNumLags24kHz);
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
kRefineNumLags24kHz);
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
y_energy_view, params.cpu_features);
EXPECT_EQ(
ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
params.pitch_candidates, params.cpu_features),
ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
swapped_pitch_candidates, params.cpu_features));
}
@ -179,13 +177,13 @@ TEST_P(ExtendedPitchPeriodSearchParametrizaion,
std::vector<float> y_energy(kRefineNumLags24kHz);
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
kRefineNumLags24kHz);
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
y_energy_view, params.cpu_features);
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
const auto computed_output = ComputeExtendedPitchPeriod48kHz(
test_data.GetPitchBufView(), y_energy_view, params.initial_pitch_period,
params.last_pitch, params.cpu_features);
test_data.PitchBuffer24kHzView(), y_energy_view,
params.initial_pitch_period, params.last_pitch, params.cpu_features);
EXPECT_EQ(params.expected_pitch.period, computed_output.period);
EXPECT_NEAR(params.expected_pitch.strength, computed_output.strength, 1e-6f);
}
@ -219,6 +217,5 @@ INSTANTIATE_TEST_SUITE_P(
PrintTestIndexAndCpuFeatures<ExtendedPitchPeriodSearchParameters>);
} // namespace
} // namespace test
} // namespace rnn_vad
} // namespace webrtc

View File

@ -26,8 +26,8 @@ namespace rnn_vad {
// Checks that the computed pitch period is bit-exact and that the computed
// pitch gain is within tolerance given test input data.
TEST(RnnVadTest, PitchSearchWithinTolerance) {
auto lp_residual_reader = test::CreateLpResidualAndPitchPeriodGainReader();
const int num_frames = std::min(lp_residual_reader.second, 300); // Max 3 s.
ChunksFileReader reader = CreateLpResidualAndPitchInfoReader();
const int num_frames = std::min(reader.num_chunks, 300); // Max 3 s.
std::vector<float> lp_residual(kBufSize24kHz);
float expected_pitch_period, expected_pitch_strength;
const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
@ -37,9 +37,9 @@ TEST(RnnVadTest, PitchSearchWithinTolerance) {
// FloatingPointExceptionObserver fpe_observer;
for (int i = 0; i < num_frames; ++i) {
SCOPED_TRACE(i);
lp_residual_reader.first->ReadChunk(lp_residual);
lp_residual_reader.first->ReadValue(&expected_pitch_period);
lp_residual_reader.first->ReadValue(&expected_pitch_strength);
ASSERT_TRUE(reader.reader->ReadChunk(lp_residual));
ASSERT_TRUE(reader.reader->ReadValue(expected_pitch_period));
ASSERT_TRUE(reader.reader->ReadValue(expected_pitch_strength));
int pitch_period =
pitch_estimator.Estimate({lp_residual.data(), kBufSize24kHz});
EXPECT_EQ(expected_pitch_period, pitch_period);

View File

@ -14,7 +14,6 @@
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
// Compare the elements of two given array views.
@ -64,8 +63,6 @@ void TestRingBuffer() {
}
}
} // namespace
// Check that for different delays, different views are returned.
TEST(RnnVadTest, RingBufferArrayViews) {
constexpr int s = 3;
@ -110,6 +107,6 @@ TEST(RnnVadTest, RingBufferFloating) {
TestRingBuffer<float, 5, 5>();
}
} // namespace test
} // namespace
} // namespace rnn_vad
} // namespace webrtc

View File

@ -24,7 +24,6 @@
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
using ::rnnoise::kInputDenseBias;
@ -104,6 +103,5 @@ INSTANTIATE_TEST_SUITE_P(
});
} // namespace
} // namespace test
} // namespace rnn_vad
} // namespace webrtc

View File

@ -21,7 +21,6 @@
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
void TestGatedRecurrentLayer(
@ -135,6 +134,5 @@ TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) {
}
} // namespace
} // namespace test
} // namespace rnn_vad
} // namespace webrtc

View File

@ -17,7 +17,6 @@
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
constexpr std::array<float, kFeatureVectorSize> kFeatures = {
@ -67,6 +66,5 @@ TEST(RnnVadTest, CheckRnnVadSilence) {
}
} // namespace
} // namespace test
} // namespace rnn_vad
} // namespace webrtc

View File

@ -9,6 +9,7 @@
*/
#include <array>
#include <memory>
#include <string>
#include <vector>
@ -26,7 +27,6 @@
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
constexpr int kFrameSize10ms48kHz = 480;
@ -49,8 +49,6 @@ void DumpPerfStats(int num_samples,
// constant below to true in order to write new expected output binary files.
constexpr bool kWriteComputedOutputToFile = false;
} // namespace
// Avoids that one forgets to set |kWriteComputedOutputToFile| back to false
// when the expected output files are re-exported.
TEST(RnnVadTest, CheckWriteComputedOutputIsFalse) {
@ -71,12 +69,11 @@ TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) {
RnnVad rnn_vad(cpu_features);
// Init input samples and expected output readers.
auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
auto expected_vad_prob_reader = CreateVadProbsReader();
std::unique_ptr<FileReader> samples_reader = CreatePcmSamplesReader();
std::unique_ptr<FileReader> expected_vad_prob_reader = CreateVadProbsReader();
// Input length.
const int num_frames = samples_reader.second;
ASSERT_GE(expected_vad_prob_reader.second, num_frames);
// Input length. The last incomplete frame is ignored.
const int num_frames = samples_reader->size() / kFrameSize10ms48kHz;
// Init buffers.
std::vector<float> samples_48k(kFrameSize10ms48kHz);
@ -86,12 +83,12 @@ TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) {
std::vector<float> expected_vad_prob(num_frames);
// Read expected output.
ASSERT_TRUE(expected_vad_prob_reader.first->ReadChunk(expected_vad_prob));
ASSERT_TRUE(expected_vad_prob_reader->ReadChunk(expected_vad_prob));
// Compute VAD probabilities on the downsampled input.
float cumulative_error = 0.f;
for (int i = 0; i < num_frames; ++i) {
samples_reader.first->ReadChunk(samples_48k);
ASSERT_TRUE(samples_reader->ReadChunk(samples_48k));
decimator.Resample(samples_48k.data(), samples_48k.size(),
samples_24k.data(), samples_24k.size());
bool is_silence = features_extractor.CheckSilenceComputeFeatures(
@ -106,7 +103,7 @@ TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) {
EXPECT_LT(cumulative_error / num_frames, 1e-4f);
if (kWriteComputedOutputToFile) {
BinaryFileWriter<float> vad_prob_writer("new_vad_prob.dat");
FileWriter vad_prob_writer("new_vad_prob.dat");
vad_prob_writer.WriteChunk(computed_vad_prob);
}
}
@ -118,15 +115,16 @@ TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) {
// - on android: run this unit test adding "--logcat-output-file".
TEST_P(RnnVadProbabilityParametrization, DISABLED_RnnVadPerformance) {
// PCM samples reader and buffers.
auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
const int num_frames = samples_reader.second;
std::unique_ptr<FileReader> samples_reader = CreatePcmSamplesReader();
// The last incomplete frame is ignored.
const int num_frames = samples_reader->size() / kFrameSize10ms48kHz;
std::array<float, kFrameSize10ms48kHz> samples;
// Pre-fetch and decimate samples.
PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
std::vector<float> prefetched_decimated_samples;
prefetched_decimated_samples.resize(num_frames * kFrameSize10ms24kHz);
for (int i = 0; i < num_frames; ++i) {
samples_reader.first->ReadChunk(samples);
ASSERT_TRUE(samples_reader->ReadChunk(samples));
decimator.Resample(samples.data(), samples.size(),
&prefetched_decimated_samples[i * kFrameSize10ms24kHz],
kFrameSize10ms24kHz);
@ -151,7 +149,6 @@ TEST_P(RnnVadProbabilityParametrization, DISABLED_RnnVadPerformance) {
rnn_vad.ComputeVadProbability(feature_vector, is_silence);
}
perf_timer.StopTimer();
samples_reader.first->SeekBeginning();
}
DumpPerfStats(num_frames * kFrameSize10ms24kHz, kSampleRate24kHz,
perf_timer.GetDurationAverage(),
@ -180,6 +177,6 @@ INSTANTIATE_TEST_SUITE_P(
return info.param.ToString();
});
} // namespace test
} // namespace
} // namespace rnn_vad
} // namespace webrtc

View File

@ -17,7 +17,6 @@
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
template <typename T, int S, int N>
@ -60,8 +59,6 @@ void TestSequenceBufferPushOp() {
}
}
} // namespace
TEST(RnnVadTest, SequenceBufferGetters) {
constexpr int buffer_size = 8;
constexpr int chunk_size = 8;
@ -100,6 +97,6 @@ TEST(RnnVadTest, SequenceBufferPushOpsFloating) {
TestSequenceBufferPushOp<float, 23, 7>(); // Non-integer ratio.
}
} // namespace test
} // namespace
} // namespace rnn_vad
} // namespace webrtc

View File

@ -26,7 +26,6 @@
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
// Generates the values for the array named |kOpusBandWeights24kHz20ms| in the
@ -49,8 +48,6 @@ std::vector<float> ComputeTriangularFiltersWeights() {
return weights;
}
} // namespace
// Checks that the values returned by GetOpusScaleNumBins24kHz20ms() match the
// Opus scale frequency boundaries.
TEST(RnnVadTest, TestOpusScaleBoundaries) {
@ -158,6 +155,6 @@ TEST(RnnVadTest, ComputeDctWithinTolerance) {
}
}
} // namespace test
} // namespace
} // namespace rnn_vad
} // namespace webrtc

View File

@ -21,7 +21,6 @@
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
constexpr int kTestFeatureVectorSize = kNumBands + 3 * kNumLowerBands + 1;
@ -66,8 +65,6 @@ float* GetCepstralVariability(
constexpr float kInitialFeatureVal = -9999.f;
} // namespace
// Checks that silence is detected when the input signal is 0 and that the
// feature vector is written only if the input signal is not tagged as silence.
TEST(RnnVadTest, SpectralFeaturesWithAndWithoutSilence) {
@ -159,6 +156,6 @@ TEST(RnnVadTest, CepstralFeaturesConstantAverageZeroDerivative) {
feature_vector_last[kNumBands + 3 * kNumLowerBands]);
}
} // namespace test
} // namespace
} // namespace rnn_vad
} // namespace webrtc

View File

@ -15,7 +15,6 @@
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
template <typename T, int S>
@ -44,8 +43,6 @@ bool CheckPairsWithValueExist(
return false;
}
} // namespace
// Test that shows how to combine RingBuffer and SymmetricMatrixBuffer to
// efficiently compute pair-wise scores. This test verifies that the evolution
// of a SymmetricMatrixBuffer instance follows that of RingBuffer.
@ -105,6 +102,6 @@ TEST(RnnVadTest, SymmetricMatrixBufferUseCase) {
}
}
} // namespace test
} // namespace
} // namespace rnn_vad
} // namespace webrtc

View File

@ -11,7 +11,10 @@
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include <algorithm>
#include <fstream>
#include <memory>
#include <type_traits>
#include <vector>
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
@ -20,11 +23,46 @@
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {
using ReaderPairType =
std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>;
// File reader for binary files that contain a sequence of values with
// arithmetic type `T`. The values of type `T` that are read are cast to float.
// The file is opened at construction time; a file that cannot be opened
// triggers `RTC_CHECK`.
template <typename T>
class FloatFileReader : public FileReader {
 public:
  static_assert(std::is_arithmetic<T>::value, "");
  // Opens `filename` in binary mode with the position at the end
  // (`std::ios::ate`) so that `tellg()` gives the file size in bytes, from
  // which the number of `T` values is derived. Then rewinds to the beginning.
  explicit FloatFileReader(const std::string& filename)
      : is_(filename, std::ios::binary | std::ios::ate),
        size_(is_.tellg() / sizeof(T)) {
    RTC_CHECK(is_);
    SeekBeginning();
  }
  FloatFileReader(const FloatFileReader&) = delete;
  FloatFileReader& operator=(const FloatFileReader&) = delete;
  ~FloatFileReader() override = default;

  // Returns the number of `T` values stored in the file.
  int size() const override { return size_; }

  // Reads `dst.size()` values of type `T` and stores them in `dst` as floats.
  // Returns true only if all the requested bytes have been read; on a short
  // read `dst` may be partially overwritten.
  bool ReadChunk(rtc::ArrayView<float> dst) override {
    const std::streamsize bytes_to_read = dst.size() * sizeof(T);
    if (std::is_same<T, float>::value) {
      // No conversion needed: read straight into the destination buffer.
      is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
    } else {
      // Read into a staging buffer of `T` values, then cast each one to float.
      buffer_.resize(dst.size());
      is_.read(reinterpret_cast<char*>(buffer_.data()), bytes_to_read);
      std::transform(buffer_.begin(), buffer_.end(), dst.begin(),
                     [](const T& v) -> float { return static_cast<float>(v); });
    }
    return is_.gcount() == bytes_to_read;
  }

  // Reads a single value. `dst` is written only if the read succeeds, so that
  // a failed read leaves the caller's variable untouched.
  bool ReadValue(float& dst) override {
    float value;
    if (!ReadChunk({&value, 1})) {
      return false;
    }
    dst = value;
    return true;
  }

  // Moves the read position by `hop` values of type `T` relative to the
  // current position. The offset is computed in `std::streamoff` so that the
  // multiplication is signed.
  void SeekForward(int hop) override {
    is_.seekg(static_cast<std::streamoff>(hop) *
                  static_cast<std::streamoff>(sizeof(T)),
              std::ios::cur);
  }

  // Moves the read position back to the beginning of the file. The stream
  // state is cleared first: a previous short read leaves `failbit` set, and
  // `seekg()` does nothing while the stream is in a failed state.
  void SeekBeginning() override {
    is_.clear();
    is_.seekg(0, std::ios::beg);
  }

 private:
  std::ifstream is_;
  // Number of `T` values in the file, computed once at construction.
  const int size_;
  // Staging buffer used by `ReadChunk()` when `T` is not float.
  std::vector<T> buffer_;
};
} // namespace
@ -49,66 +87,49 @@ void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
}
}
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const int>
CreatePcmSamplesReader(const int frame_length) {
auto ptr = std::make_unique<BinaryFileReader<int16_t, float>>(
test::ResourcePath("audio_processing/agc2/rnn_vad/samples", "pcm"),
frame_length);
// The last incomplete frame is ignored.
return {std::move(ptr), ptr->data_length() / frame_length};
std::unique_ptr<FileReader> CreatePcmSamplesReader() {
return std::make_unique<FloatFileReader<int16_t>>(
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/samples",
"pcm"));
}
ReaderPairType CreatePitchBuffer24kHzReader() {
constexpr int cols = 864;
auto ptr = std::make_unique<BinaryFileReader<float>>(
ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), cols);
return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), cols)};
ChunksFileReader CreatePitchBuffer24kHzReader() {
auto reader = std::make_unique<FloatFileReader<float>>(
/*filename=*/test::ResourcePath(
"audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"));
const int num_chunks = rtc::CheckedDivExact(reader->size(), kBufSize24kHz);
return {/*chunk_size=*/kBufSize24kHz, num_chunks, std::move(reader)};
}
ReaderPairType CreateLpResidualAndPitchPeriodGainReader() {
constexpr int num_lp_residual_coeffs = 864;
auto ptr = std::make_unique<BinaryFileReader<float>>(
ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"),
num_lp_residual_coeffs);
return {std::move(ptr),
rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs)};
ChunksFileReader CreateLpResidualAndPitchInfoReader() {
constexpr int kPitchInfoSize = 2; // Pitch period and strength.
constexpr int kChunkSize = kBufSize24kHz + kPitchInfoSize;
auto reader = std::make_unique<FloatFileReader<float>>(
/*filename=*/test::ResourcePath(
"audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"));
const int num_chunks = rtc::CheckedDivExact(reader->size(), kChunkSize);
return {kChunkSize, num_chunks, std::move(reader)};
}
ReaderPairType CreateVadProbsReader() {
auto ptr = std::make_unique<BinaryFileReader<float>>(
test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob", "dat"));
return {std::move(ptr), ptr->data_length()};
std::unique_ptr<FileReader> CreateVadProbsReader() {
return std::make_unique<FloatFileReader<float>>(
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob",
"dat"));
}
PitchTestData::PitchTestData() {
BinaryFileReader<float> test_data_reader(
ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
1396);
test_data_reader.ReadChunk(test_data_);
FloatFileReader<float> reader(
/*filename=*/ResourcePath(
"audio_processing/agc2/rnn_vad/pitch_search_int", "dat"));
reader.ReadChunk(pitch_buffer_24k_);
reader.ReadChunk(square_energies_24k_);
reader.ReadChunk(auto_correlation_12k_);
// Reverse the order of the squared energy values.
// Required after the WebRTC CL 191703 which switched to forward computation.
std::reverse(test_data_.begin() + kBufSize24kHz,
test_data_.begin() + kBufSize24kHz + kNumPitchBufSquareEnergies);
std::reverse(square_energies_24k_.begin(), square_energies_24k_.end());
}
PitchTestData::~PitchTestData() = default;
rtc::ArrayView<const float, kBufSize24kHz> PitchTestData::GetPitchBufView()
const {
return {test_data_.data(), kBufSize24kHz};
}
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
PitchTestData::GetPitchBufSquareEnergiesView() const {
return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies};
}
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
PitchTestData::GetPitchBufAutoCorrCoeffsView() const {
return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies,
kNumPitchBufAutoCorrCoeffs};
}
} // namespace test
} // namespace rnn_vad
} // namespace webrtc

View File

@ -11,15 +11,10 @@
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
#include <algorithm>
#include <array>
#include <fstream>
#include <limits>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
@ -28,7 +23,6 @@
namespace webrtc {
namespace rnn_vad {
namespace test {
constexpr float kFloatMin = std::numeric_limits<float>::min();
@ -43,98 +37,48 @@ void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed,
float tolerance);
// Reader for binary files consisting of an arbitrary long sequence of elements
// having type T. It is possible to read and cast to another type D at once.
template <typename T, typename D = T>
class BinaryFileReader {
// File reader interface.
class FileReader {
public:
BinaryFileReader(const std::string& file_path, int chunk_size = 0)
: is_(file_path, std::ios::binary | std::ios::ate),
data_length_(is_.tellg() / sizeof(T)),
chunk_size_(chunk_size) {
RTC_CHECK(is_);
SeekBeginning();
buf_.resize(chunk_size_);
}
BinaryFileReader(const BinaryFileReader&) = delete;
BinaryFileReader& operator=(const BinaryFileReader&) = delete;
~BinaryFileReader() = default;
int data_length() const { return data_length_; }
bool ReadValue(D* dst) {
if (std::is_same<T, D>::value) {
is_.read(reinterpret_cast<char*>(dst), sizeof(T));
} else {
T v;
is_.read(reinterpret_cast<char*>(&v), sizeof(T));
*dst = static_cast<D>(v);
}
return is_.gcount() == sizeof(T);
}
// If |chunk_size| was specified in the ctor, it will check that the size of
// |dst| equals |chunk_size|.
bool ReadChunk(rtc::ArrayView<D> dst) {
RTC_DCHECK((chunk_size_ == 0) || rtc::SafeEq(chunk_size_, dst.size()));
const std::streamsize bytes_to_read = dst.size() * sizeof(T);
if (std::is_same<T, D>::value) {
is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
} else {
is_.read(reinterpret_cast<char*>(buf_.data()), bytes_to_read);
std::transform(buf_.begin(), buf_.end(), dst.begin(),
[](const T& v) -> D { return static_cast<D>(v); });
}
return is_.gcount() == bytes_to_read;
}
void SeekForward(int items) { is_.seekg(items * sizeof(T), is_.cur); }
void SeekBeginning() { is_.seekg(0, is_.beg); }
private:
std::ifstream is_;
const int data_length_;
const int chunk_size_;
std::vector<T> buf_;
virtual ~FileReader() = default;
// Number of values in the file.
virtual int size() const = 0;
// Reads `dst.size()` float values into `dst`, advances the internal file
// position according to the number of read bytes and returns true if the
// values are correctly read. If the number of remaining bytes in the file is
// not sufficient to read `dst.size()` float values, `dst` is partially
// modified and false is returned.
virtual bool ReadChunk(rtc::ArrayView<float> dst) = 0;
// Reads a single float value, advances the internal file position according
// to the number of read bytes and returns true if the value is correctly
// read. If the number of remaining bytes in the file is not sufficient to
// read one float, `dst` is not modified and false is returned.
virtual bool ReadValue(float& dst) = 0;
// Advances the internal file position by `hop` float values.
virtual void SeekForward(int hop) = 0;
// Resets the internal file position to BOF.
virtual void SeekBeginning() = 0;
};
// Writer for binary files.
template <typename T>
class BinaryFileWriter {
public:
explicit BinaryFileWriter(const std::string& file_path)
: os_(file_path, std::ios::binary) {}
BinaryFileWriter(const BinaryFileWriter&) = delete;
BinaryFileWriter& operator=(const BinaryFileWriter&) = delete;
~BinaryFileWriter() = default;
static_assert(std::is_arithmetic<T>::value, "");
void WriteChunk(rtc::ArrayView<const T> value) {
const std::streamsize bytes_to_write = value.size() * sizeof(T);
os_.write(reinterpret_cast<const char*>(value.data()), bytes_to_write);
}
private:
std::ofstream os_;
// File reader for files that contain `num_chunks` chunks with size equal to
// `chunk_size`. Bundles the reader with the chunk layout of the file it reads.
struct ChunksFileReader {
// Number of values in each chunk.
const int chunk_size;
// Number of chunks stored in the file.
const int num_chunks;
// Reader for the underlying file; owned by this struct.
std::unique_ptr<FileReader> reader;
};
// Factories for resource file readers.
// The functions below return a pair where the first item is a reader unique
// pointer and the second the number of chunks that can be read from the file.
// Creates a reader for the PCM samples that casts from S16 to float and reads
// chunks with length |frame_length|.
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const int>
CreatePcmSamplesReader(const int frame_length);
// Creates a reader for the pitch buffer content at 24 kHz.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>
CreatePitchBuffer24kHzReader();
// Creates a reader for the LP residual coefficients and the pitch period
// and gain values.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>
CreateLpResidualAndPitchPeriodGainReader();
// Creates a reader for the VAD probabilities.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>
CreateVadProbsReader();
// Creates a reader for the PCM S16 samples file.
std::unique_ptr<FileReader> CreatePcmSamplesReader();
constexpr int kNumPitchBufAutoCorrCoeffs = 147;
constexpr int kNumPitchBufSquareEnergies = 385;
constexpr int kPitchTestDataSize =
kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs;
// Creates a reader for the 24 kHz pitch buffer test data.
ChunksFileReader CreatePitchBuffer24kHzReader();
// Creates a reader for the LP residual and pitch information test data.
ChunksFileReader CreateLpResidualAndPitchInfoReader();
// Creates a reader for the VAD probabilities test data.
std::unique_ptr<FileReader> CreateVadProbsReader();
// Class to retrieve a test pitch buffer content and the expected output for the
// analysis steps.
@ -142,17 +86,40 @@ class PitchTestData {
public:
PitchTestData();
~PitchTestData();
rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView() const;
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
GetPitchBufSquareEnergiesView() const;
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
GetPitchBufAutoCorrCoeffsView() const;
rtc::ArrayView<const float, kBufSize24kHz> PitchBuffer24kHzView() const {
return pitch_buffer_24k_;
}
rtc::ArrayView<const float, kRefineNumLags24kHz> SquareEnergies24kHzView()
const {
return square_energies_24k_;
}
rtc::ArrayView<const float, kNumLags12kHz> AutoCorrelation12kHzView() const {
return auto_correlation_12k_;
}
private:
std::array<float, kPitchTestDataSize> test_data_;
std::array<float, kBufSize24kHz> pitch_buffer_24k_;
std::array<float, kRefineNumLags24kHz> square_energies_24k_;
std::array<float, kNumLags12kHz> auto_correlation_12k_;
};
// Minimal writer that dumps float values to a binary file.
class FileWriter {
 public:
  // Opens `file_path` for binary output.
  explicit FileWriter(const std::string& file_path)
      : os_(file_path, std::ios::binary) {}
  FileWriter(const FileWriter&) = delete;
  FileWriter& operator=(const FileWriter&) = delete;
  ~FileWriter() = default;

  // Writes the raw bytes of every element of `value` at the current stream
  // position.
  void WriteChunk(rtc::ArrayView<const float> value) {
    const char* const raw_bytes = reinterpret_cast<const char*>(value.data());
    const std::streamsize num_bytes = value.size() * sizeof(float);
    os_.write(raw_bytes, num_bytes);
  }

 private:
  std::ofstream os_;
};
} // namespace test
} // namespace rnn_vad
} // namespace webrtc