RNN VAD: unit test code clean-up
- test_utils.h/.cc simplified - webrtc::rnnvad::test -> webrtc::rnnvad - all unit test code inside the anonymous namespace - names improved Bug: webrtc:10480 Change-Id: I0a0f056f9728bb8a1b93006b95d7ed5bf5bd4adb Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/196509 Commit-Queue: Alessio Bazzica <alessiob@webrtc.org> Reviewed-by: Sam Zackrisson <saza@webrtc.org> Cr-Commit-Position: refs/heads/master@{#32789}
This commit is contained in:

committed by
Commit Bot

parent
bc7e5ac1c9
commit
bb1a28de3c
@ -17,15 +17,15 @@
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
// Checks that the auto correlation function produces output within tolerance
|
||||
// given test input data.
|
||||
TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) {
|
||||
PitchTestData test_data;
|
||||
std::array<float, kBufSize12kHz> pitch_buf_decimated;
|
||||
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
|
||||
std::array<float, kNumPitchBufAutoCorrCoeffs> computed_output;
|
||||
Decimate2x(test_data.PitchBuffer24kHzView(), pitch_buf_decimated);
|
||||
std::array<float, kNumLags12kHz> computed_output;
|
||||
{
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
@ -33,7 +33,7 @@ TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) {
|
||||
auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated,
|
||||
computed_output);
|
||||
}
|
||||
auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
|
||||
auto auto_corr_view = test_data.AutoCorrelation12kHzView();
|
||||
ExpectNearAbsolute({auto_corr_view.data(), auto_corr_view.size()},
|
||||
computed_output, 3e-3f);
|
||||
}
|
||||
@ -44,7 +44,7 @@ TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) {
|
||||
// Create constant signal with no pitch.
|
||||
std::array<float, kBufSize12kHz> pitch_buf_decimated;
|
||||
std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f);
|
||||
std::array<float, kNumPitchBufAutoCorrCoeffs> computed_output;
|
||||
std::array<float, kNumLags12kHz> computed_output;
|
||||
{
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
@ -55,12 +55,12 @@ TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) {
|
||||
// The expected output is a vector filled with the same expected
|
||||
// auto-correlation value. The latter equals the length of a 20 ms frame.
|
||||
constexpr int kFrameSize20ms12kHz = kFrameSize20ms24kHz / 2;
|
||||
std::array<float, kNumPitchBufAutoCorrCoeffs> expected_output;
|
||||
std::array<float, kNumLags12kHz> expected_output;
|
||||
std::fill(expected_output.begin(), expected_output.end(),
|
||||
static_cast<float>(kFrameSize20ms12kHz));
|
||||
ExpectNearAbsolute(expected_output, computed_output, 4e-5f);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -14,7 +14,6 @@
|
||||
#include <vector>
|
||||
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
@ -23,7 +22,6 @@
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
constexpr int ceil(int n, int m) {
|
||||
@ -52,7 +50,7 @@ void CreatePureTone(float amplitude, float freq_hz, rtc::ArrayView<float> dst) {
|
||||
// Feeds |features_extractor| with |samples| splitting it in 10 ms frames.
|
||||
// For every frame, the output is written into |feature_vector|. Returns true
|
||||
// if silence is detected in the last frame.
|
||||
bool FeedTestData(FeaturesExtractor* features_extractor,
|
||||
bool FeedTestData(FeaturesExtractor& features_extractor,
|
||||
rtc::ArrayView<const float> samples,
|
||||
rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
@ -60,15 +58,13 @@ bool FeedTestData(FeaturesExtractor* features_extractor,
|
||||
bool is_silence = true;
|
||||
const int num_frames = samples.size() / kFrameSize10ms24kHz;
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
is_silence = features_extractor->CheckSilenceComputeFeatures(
|
||||
is_silence = features_extractor.CheckSilenceComputeFeatures(
|
||||
{samples.data() + i * kFrameSize10ms24kHz, kFrameSize10ms24kHz},
|
||||
feature_vector);
|
||||
}
|
||||
return is_silence;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Extracts the features for two pure tones and verifies that the pitch field
|
||||
// values reflect the known tone frequencies.
|
||||
TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
|
||||
@ -91,17 +87,17 @@ TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
|
||||
constexpr int pitch_feature_index = kFeatureVectorSize - 2;
|
||||
// Low frequency tone - i.e., high period.
|
||||
CreatePureTone(amplitude, low_pitch_hz, samples);
|
||||
ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view));
|
||||
ASSERT_FALSE(FeedTestData(features_extractor, samples, feature_vector_view));
|
||||
float high_pitch_period = feature_vector_view[pitch_feature_index];
|
||||
// High frequency tone - i.e., low period.
|
||||
features_extractor.Reset();
|
||||
CreatePureTone(amplitude, high_pitch_hz, samples);
|
||||
ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view));
|
||||
ASSERT_FALSE(FeedTestData(features_extractor, samples, feature_vector_view));
|
||||
float low_pitch_period = feature_vector_view[pitch_feature_index];
|
||||
// Check.
|
||||
EXPECT_LT(low_pitch_period, high_pitch_period);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -18,7 +18,7 @@
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// LPC inverse filter length.
|
||||
// Linear predictive coding (LPC) inverse filter length.
|
||||
constexpr int kNumLpcCoefficients = 5;
|
||||
|
||||
// Given a frame |x|, computes a post-processed version of LPC coefficients
|
||||
|
@ -22,7 +22,7 @@
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
// Checks that the LP residual can be computed on an empty frame.
|
||||
TEST(RnnVadTest, LpResidualOfEmptyFrame) {
|
||||
@ -33,55 +33,48 @@ TEST(RnnVadTest, LpResidualOfEmptyFrame) {
|
||||
std::array<float, kFrameSize10ms24kHz> empty_frame;
|
||||
empty_frame.fill(0.f);
|
||||
// Compute inverse filter coefficients.
|
||||
std::array<float, kNumLpcCoefficients> lpc_coeffs;
|
||||
ComputeAndPostProcessLpcCoefficients(empty_frame, lpc_coeffs);
|
||||
std::array<float, kNumLpcCoefficients> lpc;
|
||||
ComputeAndPostProcessLpcCoefficients(empty_frame, lpc);
|
||||
// Compute LP residual.
|
||||
std::array<float, kFrameSize10ms24kHz> lp_residual;
|
||||
ComputeLpResidual(lpc_coeffs, empty_frame, lp_residual);
|
||||
ComputeLpResidual(lpc, empty_frame, lp_residual);
|
||||
}
|
||||
|
||||
// Checks that the computed LP residual is bit-exact given test input data.
|
||||
TEST(RnnVadTest, LpResidualPipelineBitExactness) {
|
||||
// Input and expected output readers.
|
||||
auto pitch_buf_24kHz_reader = CreatePitchBuffer24kHzReader();
|
||||
auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader();
|
||||
ChunksFileReader pitch_buffer_reader = CreatePitchBuffer24kHzReader();
|
||||
ChunksFileReader lp_pitch_reader = CreateLpResidualAndPitchInfoReader();
|
||||
|
||||
// Buffers.
|
||||
std::vector<float> pitch_buf_data(kBufSize24kHz);
|
||||
std::array<float, kNumLpcCoefficients> lpc_coeffs;
|
||||
std::vector<float> pitch_buffer_24kHz(kBufSize24kHz);
|
||||
std::array<float, kNumLpcCoefficients> lpc;
|
||||
std::vector<float> computed_lp_residual(kBufSize24kHz);
|
||||
std::vector<float> expected_lp_residual(kBufSize24kHz);
|
||||
|
||||
// Test length.
|
||||
const int num_frames =
|
||||
std::min(pitch_buf_24kHz_reader.second, 300); // Max 3 s.
|
||||
ASSERT_GE(lp_residual_reader.second, num_frames);
|
||||
std::min(pitch_buffer_reader.num_chunks, 300); // Max 3 s.
|
||||
ASSERT_GE(lp_pitch_reader.num_chunks, num_frames);
|
||||
|
||||
{
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
// Read input.
|
||||
ASSERT_TRUE(pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data));
|
||||
// Read expected output (ignore pitch gain and period).
|
||||
ASSERT_TRUE(lp_residual_reader.first->ReadChunk(expected_lp_residual));
|
||||
float unused;
|
||||
ASSERT_TRUE(lp_residual_reader.first->ReadValue(&unused));
|
||||
ASSERT_TRUE(lp_residual_reader.first->ReadValue(&unused));
|
||||
|
||||
// Check every 200 ms.
|
||||
if (i % 20 != 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
SCOPED_TRACE(i);
|
||||
ComputeAndPostProcessLpcCoefficients(pitch_buf_data, lpc_coeffs);
|
||||
ComputeLpResidual(lpc_coeffs, pitch_buf_data, computed_lp_residual);
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
// Read input.
|
||||
ASSERT_TRUE(pitch_buffer_reader.reader->ReadChunk(pitch_buffer_24kHz));
|
||||
// Read expected output (ignore pitch gain and period).
|
||||
ASSERT_TRUE(lp_pitch_reader.reader->ReadChunk(expected_lp_residual));
|
||||
lp_pitch_reader.reader->SeekForward(2); // Pitch period and strength.
|
||||
// Check every 200 ms.
|
||||
if (i % 20 == 0) {
|
||||
ComputeAndPostProcessLpcCoefficients(pitch_buffer_24kHz, lpc);
|
||||
ComputeLpResidual(lpc, pitch_buffer_24kHz, computed_lp_residual);
|
||||
ExpectNearAbsolute(expected_lp_residual, computed_lp_residual, kFloatMin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -22,7 +22,6 @@
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
constexpr int kTestPitchPeriodsLow = 3 * kMinPitch48kHz / 2;
|
||||
@ -63,12 +62,12 @@ TEST(RnnVadTest, ComputeSlidingFrameSquareEnergies24kHzWithinTolerance) {
|
||||
const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
|
||||
|
||||
PitchTestData test_data;
|
||||
std::array<float, kNumPitchBufSquareEnergies> computed_output;
|
||||
std::array<float, kRefineNumLags24kHz> computed_output;
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
|
||||
ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
|
||||
computed_output, cpu_features);
|
||||
auto square_energies_view = test_data.GetPitchBufSquareEnergiesView();
|
||||
auto square_energies_view = test_data.SquareEnergies24kHzView();
|
||||
ExpectNearAbsolute({square_energies_view.data(), square_energies_view.size()},
|
||||
computed_output, 1e-3f);
|
||||
}
|
||||
@ -79,13 +78,12 @@ TEST(RnnVadTest, ComputePitchPeriod12kHzBitExactness) {
|
||||
|
||||
PitchTestData test_data;
|
||||
std::array<float, kBufSize12kHz> pitch_buf_decimated;
|
||||
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
|
||||
Decimate2x(test_data.PitchBuffer24kHzView(), pitch_buf_decimated);
|
||||
CandidatePitchPeriods pitch_candidates;
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
|
||||
pitch_candidates = ComputePitchPeriod12kHz(pitch_buf_decimated,
|
||||
auto_corr_view, cpu_features);
|
||||
pitch_candidates = ComputePitchPeriod12kHz(
|
||||
pitch_buf_decimated, test_data.AutoCorrelation12kHzView(), cpu_features);
|
||||
EXPECT_EQ(pitch_candidates.best, 140);
|
||||
EXPECT_EQ(pitch_candidates.second_best, 142);
|
||||
}
|
||||
@ -98,16 +96,16 @@ TEST(RnnVadTest, ComputePitchPeriod48kHzBitExactness) {
|
||||
std::vector<float> y_energy(kRefineNumLags24kHz);
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
|
||||
kRefineNumLags24kHz);
|
||||
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
|
||||
ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
|
||||
y_energy_view, cpu_features);
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
EXPECT_EQ(
|
||||
ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
|
||||
ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
|
||||
/*pitch_candidates=*/{280, 284}, cpu_features),
|
||||
560);
|
||||
EXPECT_EQ(
|
||||
ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
|
||||
ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
|
||||
/*pitch_candidates=*/{260, 284}, cpu_features),
|
||||
568);
|
||||
}
|
||||
@ -132,12 +130,12 @@ TEST_P(PitchCandidatesParametrization,
|
||||
std::vector<float> y_energy(kRefineNumLags24kHz);
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
|
||||
kRefineNumLags24kHz);
|
||||
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
|
||||
ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
|
||||
y_energy_view, params.cpu_features);
|
||||
EXPECT_EQ(
|
||||
ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
|
||||
ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
|
||||
params.pitch_candidates, params.cpu_features),
|
||||
ComputePitchPeriod48kHz(test_data.GetPitchBufView(), y_energy_view,
|
||||
ComputePitchPeriod48kHz(test_data.PitchBuffer24kHzView(), y_energy_view,
|
||||
swapped_pitch_candidates, params.cpu_features));
|
||||
}
|
||||
|
||||
@ -179,13 +177,13 @@ TEST_P(ExtendedPitchPeriodSearchParametrizaion,
|
||||
std::vector<float> y_energy(kRefineNumLags24kHz);
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_view(y_energy.data(),
|
||||
kRefineNumLags24kHz);
|
||||
ComputeSlidingFrameSquareEnergies24kHz(test_data.GetPitchBufView(),
|
||||
ComputeSlidingFrameSquareEnergies24kHz(test_data.PitchBuffer24kHzView(),
|
||||
y_energy_view, params.cpu_features);
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
const auto computed_output = ComputeExtendedPitchPeriod48kHz(
|
||||
test_data.GetPitchBufView(), y_energy_view, params.initial_pitch_period,
|
||||
params.last_pitch, params.cpu_features);
|
||||
test_data.PitchBuffer24kHzView(), y_energy_view,
|
||||
params.initial_pitch_period, params.last_pitch, params.cpu_features);
|
||||
EXPECT_EQ(params.expected_pitch.period, computed_output.period);
|
||||
EXPECT_NEAR(params.expected_pitch.strength, computed_output.strength, 1e-6f);
|
||||
}
|
||||
@ -219,6 +217,5 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
PrintTestIndexAndCpuFeatures<ExtendedPitchPeriodSearchParameters>);
|
||||
|
||||
} // namespace
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -26,8 +26,8 @@ namespace rnn_vad {
|
||||
// Checks that the computed pitch period is bit-exact and that the computed
|
||||
// pitch gain is within tolerance given test input data.
|
||||
TEST(RnnVadTest, PitchSearchWithinTolerance) {
|
||||
auto lp_residual_reader = test::CreateLpResidualAndPitchPeriodGainReader();
|
||||
const int num_frames = std::min(lp_residual_reader.second, 300); // Max 3 s.
|
||||
ChunksFileReader reader = CreateLpResidualAndPitchInfoReader();
|
||||
const int num_frames = std::min(reader.num_chunks, 300); // Max 3 s.
|
||||
std::vector<float> lp_residual(kBufSize24kHz);
|
||||
float expected_pitch_period, expected_pitch_strength;
|
||||
const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
|
||||
@ -37,9 +37,9 @@ TEST(RnnVadTest, PitchSearchWithinTolerance) {
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
lp_residual_reader.first->ReadChunk(lp_residual);
|
||||
lp_residual_reader.first->ReadValue(&expected_pitch_period);
|
||||
lp_residual_reader.first->ReadValue(&expected_pitch_strength);
|
||||
ASSERT_TRUE(reader.reader->ReadChunk(lp_residual));
|
||||
ASSERT_TRUE(reader.reader->ReadValue(expected_pitch_period));
|
||||
ASSERT_TRUE(reader.reader->ReadValue(expected_pitch_strength));
|
||||
int pitch_period =
|
||||
pitch_estimator.Estimate({lp_residual.data(), kBufSize24kHz});
|
||||
EXPECT_EQ(expected_pitch_period, pitch_period);
|
||||
|
@ -14,7 +14,6 @@
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
// Compare the elements of two given array views.
|
||||
@ -64,8 +63,6 @@ void TestRingBuffer() {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Check that for different delays, different views are returned.
|
||||
TEST(RnnVadTest, RingBufferArrayViews) {
|
||||
constexpr int s = 3;
|
||||
@ -110,6 +107,6 @@ TEST(RnnVadTest, RingBufferFloating) {
|
||||
TestRingBuffer<float, 5, 5>();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -24,7 +24,6 @@
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
using ::rnnoise::kInputDenseBias;
|
||||
@ -104,6 +103,5 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
});
|
||||
|
||||
} // namespace
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -21,7 +21,6 @@
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
void TestGatedRecurrentLayer(
|
||||
@ -135,6 +134,5 @@ TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) {
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -17,7 +17,6 @@
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
constexpr std::array<float, kFeatureVectorSize> kFeatures = {
|
||||
@ -67,6 +66,5 @@ TEST(RnnVadTest, CheckRnnVadSilence) {
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -9,6 +9,7 @@
|
||||
*/
|
||||
|
||||
#include <array>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -26,7 +27,6 @@
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
constexpr int kFrameSize10ms48kHz = 480;
|
||||
@ -49,8 +49,6 @@ void DumpPerfStats(int num_samples,
|
||||
// constant below to true in order to write new expected output binary files.
|
||||
constexpr bool kWriteComputedOutputToFile = false;
|
||||
|
||||
} // namespace
|
||||
|
||||
// Avoids that one forgets to set |kWriteComputedOutputToFile| back to false
|
||||
// when the expected output files are re-exported.
|
||||
TEST(RnnVadTest, CheckWriteComputedOutputIsFalse) {
|
||||
@ -71,12 +69,11 @@ TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) {
|
||||
RnnVad rnn_vad(cpu_features);
|
||||
|
||||
// Init input samples and expected output readers.
|
||||
auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
|
||||
auto expected_vad_prob_reader = CreateVadProbsReader();
|
||||
std::unique_ptr<FileReader> samples_reader = CreatePcmSamplesReader();
|
||||
std::unique_ptr<FileReader> expected_vad_prob_reader = CreateVadProbsReader();
|
||||
|
||||
// Input length.
|
||||
const int num_frames = samples_reader.second;
|
||||
ASSERT_GE(expected_vad_prob_reader.second, num_frames);
|
||||
// Input length. The last incomplete frame is ignored.
|
||||
const int num_frames = samples_reader->size() / kFrameSize10ms48kHz;
|
||||
|
||||
// Init buffers.
|
||||
std::vector<float> samples_48k(kFrameSize10ms48kHz);
|
||||
@ -86,12 +83,12 @@ TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) {
|
||||
std::vector<float> expected_vad_prob(num_frames);
|
||||
|
||||
// Read expected output.
|
||||
ASSERT_TRUE(expected_vad_prob_reader.first->ReadChunk(expected_vad_prob));
|
||||
ASSERT_TRUE(expected_vad_prob_reader->ReadChunk(expected_vad_prob));
|
||||
|
||||
// Compute VAD probabilities on the downsampled input.
|
||||
float cumulative_error = 0.f;
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
samples_reader.first->ReadChunk(samples_48k);
|
||||
ASSERT_TRUE(samples_reader->ReadChunk(samples_48k));
|
||||
decimator.Resample(samples_48k.data(), samples_48k.size(),
|
||||
samples_24k.data(), samples_24k.size());
|
||||
bool is_silence = features_extractor.CheckSilenceComputeFeatures(
|
||||
@ -106,7 +103,7 @@ TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) {
|
||||
EXPECT_LT(cumulative_error / num_frames, 1e-4f);
|
||||
|
||||
if (kWriteComputedOutputToFile) {
|
||||
BinaryFileWriter<float> vad_prob_writer("new_vad_prob.dat");
|
||||
FileWriter vad_prob_writer("new_vad_prob.dat");
|
||||
vad_prob_writer.WriteChunk(computed_vad_prob);
|
||||
}
|
||||
}
|
||||
@ -118,15 +115,16 @@ TEST_P(RnnVadProbabilityParametrization, RnnVadProbabilityWithinTolerance) {
|
||||
// - on android: run the this unit test adding "--logcat-output-file".
|
||||
TEST_P(RnnVadProbabilityParametrization, DISABLED_RnnVadPerformance) {
|
||||
// PCM samples reader and buffers.
|
||||
auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
|
||||
const int num_frames = samples_reader.second;
|
||||
std::unique_ptr<FileReader> samples_reader = CreatePcmSamplesReader();
|
||||
// The last incomplete frame is ignored.
|
||||
const int num_frames = samples_reader->size() / kFrameSize10ms48kHz;
|
||||
std::array<float, kFrameSize10ms48kHz> samples;
|
||||
// Pre-fetch and decimate samples.
|
||||
PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
|
||||
std::vector<float> prefetched_decimated_samples;
|
||||
prefetched_decimated_samples.resize(num_frames * kFrameSize10ms24kHz);
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
samples_reader.first->ReadChunk(samples);
|
||||
ASSERT_TRUE(samples_reader->ReadChunk(samples));
|
||||
decimator.Resample(samples.data(), samples.size(),
|
||||
&prefetched_decimated_samples[i * kFrameSize10ms24kHz],
|
||||
kFrameSize10ms24kHz);
|
||||
@ -151,7 +149,6 @@ TEST_P(RnnVadProbabilityParametrization, DISABLED_RnnVadPerformance) {
|
||||
rnn_vad.ComputeVadProbability(feature_vector, is_silence);
|
||||
}
|
||||
perf_timer.StopTimer();
|
||||
samples_reader.first->SeekBeginning();
|
||||
}
|
||||
DumpPerfStats(num_frames * kFrameSize10ms24kHz, kSampleRate24kHz,
|
||||
perf_timer.GetDurationAverage(),
|
||||
@ -180,6 +177,6 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
return info.param.ToString();
|
||||
});
|
||||
|
||||
} // namespace test
|
||||
} // namespace
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -17,7 +17,6 @@
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
template <typename T, int S, int N>
|
||||
@ -60,8 +59,6 @@ void TestSequenceBufferPushOp() {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST(RnnVadTest, SequenceBufferGetters) {
|
||||
constexpr int buffer_size = 8;
|
||||
constexpr int chunk_size = 8;
|
||||
@ -100,6 +97,6 @@ TEST(RnnVadTest, SequenceBufferPushOpsFloating) {
|
||||
TestSequenceBufferPushOp<float, 23, 7>(); // Non-integer ratio.
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -26,7 +26,6 @@
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
// Generates the values for the array named |kOpusBandWeights24kHz20ms| in the
|
||||
@ -49,8 +48,6 @@ std::vector<float> ComputeTriangularFiltersWeights() {
|
||||
return weights;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Checks that the values returned by GetOpusScaleNumBins24kHz20ms() match the
|
||||
// Opus scale frequency boundaries.
|
||||
TEST(RnnVadTest, TestOpusScaleBoundaries) {
|
||||
@ -158,6 +155,6 @@ TEST(RnnVadTest, ComputeDctWithinTolerance) {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -21,7 +21,6 @@
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
constexpr int kTestFeatureVectorSize = kNumBands + 3 * kNumLowerBands + 1;
|
||||
@ -66,8 +65,6 @@ float* GetCepstralVariability(
|
||||
|
||||
constexpr float kInitialFeatureVal = -9999.f;
|
||||
|
||||
} // namespace
|
||||
|
||||
// Checks that silence is detected when the input signal is 0 and that the
|
||||
// feature vector is written only if the input signal is not tagged as silence.
|
||||
TEST(RnnVadTest, SpectralFeaturesWithAndWithoutSilence) {
|
||||
@ -159,6 +156,6 @@ TEST(RnnVadTest, CepstralFeaturesConstantAverageZeroDerivative) {
|
||||
feature_vector_last[kNumBands + 3 * kNumLowerBands]);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -15,7 +15,6 @@
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
template <typename T, int S>
|
||||
@ -44,8 +43,6 @@ bool CheckPairsWithValueExist(
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Test that shows how to combine RingBuffer and SymmetricMatrixBuffer to
|
||||
// efficiently compute pair-wise scores. This test verifies that the evolution
|
||||
// of a SymmetricMatrixBuffer instance follows that of RingBuffer.
|
||||
@ -105,6 +102,6 @@ TEST(RnnVadTest, SymmetricMatrixBufferUseCase) {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -11,7 +11,10 @@
|
||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
@ -20,11 +23,46 @@
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
using ReaderPairType =
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>;
|
||||
// File reader for binary files that contain a sequence of values with
|
||||
// arithmetic type `T`. The values of type `T` that are read are cast to float.
|
||||
template <typename T>
|
||||
class FloatFileReader : public FileReader {
|
||||
public:
|
||||
static_assert(std::is_arithmetic<T>::value, "");
|
||||
FloatFileReader(const std::string& filename)
|
||||
: is_(filename, std::ios::binary | std::ios::ate),
|
||||
size_(is_.tellg() / sizeof(T)) {
|
||||
RTC_CHECK(is_);
|
||||
SeekBeginning();
|
||||
}
|
||||
FloatFileReader(const FloatFileReader&) = delete;
|
||||
FloatFileReader& operator=(const FloatFileReader&) = delete;
|
||||
~FloatFileReader() = default;
|
||||
|
||||
int size() const override { return size_; }
|
||||
bool ReadChunk(rtc::ArrayView<float> dst) override {
|
||||
const std::streamsize bytes_to_read = dst.size() * sizeof(T);
|
||||
if (std::is_same<T, float>::value) {
|
||||
is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
|
||||
} else {
|
||||
buffer_.resize(dst.size());
|
||||
is_.read(reinterpret_cast<char*>(buffer_.data()), bytes_to_read);
|
||||
std::transform(buffer_.begin(), buffer_.end(), dst.begin(),
|
||||
[](const T& v) -> float { return static_cast<float>(v); });
|
||||
}
|
||||
return is_.gcount() == bytes_to_read;
|
||||
}
|
||||
bool ReadValue(float& dst) override { return ReadChunk({&dst, 1}); }
|
||||
void SeekForward(int hop) override { is_.seekg(hop * sizeof(T), is_.cur); }
|
||||
void SeekBeginning() override { is_.seekg(0, is_.beg); }
|
||||
|
||||
private:
|
||||
std::ifstream is_;
|
||||
const int size_;
|
||||
std::vector<T> buffer_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -49,66 +87,49 @@ void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const int>
|
||||
CreatePcmSamplesReader(const int frame_length) {
|
||||
auto ptr = std::make_unique<BinaryFileReader<int16_t, float>>(
|
||||
test::ResourcePath("audio_processing/agc2/rnn_vad/samples", "pcm"),
|
||||
frame_length);
|
||||
// The last incomplete frame is ignored.
|
||||
return {std::move(ptr), ptr->data_length() / frame_length};
|
||||
std::unique_ptr<FileReader> CreatePcmSamplesReader() {
|
||||
return std::make_unique<FloatFileReader<int16_t>>(
|
||||
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/samples",
|
||||
"pcm"));
|
||||
}
|
||||
|
||||
ReaderPairType CreatePitchBuffer24kHzReader() {
|
||||
constexpr int cols = 864;
|
||||
auto ptr = std::make_unique<BinaryFileReader<float>>(
|
||||
ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), cols);
|
||||
return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), cols)};
|
||||
ChunksFileReader CreatePitchBuffer24kHzReader() {
|
||||
auto reader = std::make_unique<FloatFileReader<float>>(
|
||||
/*filename=*/test::ResourcePath(
|
||||
"audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"));
|
||||
const int num_chunks = rtc::CheckedDivExact(reader->size(), kBufSize24kHz);
|
||||
return {/*chunk_size=*/kBufSize24kHz, num_chunks, std::move(reader)};
|
||||
}
|
||||
|
||||
ReaderPairType CreateLpResidualAndPitchPeriodGainReader() {
|
||||
constexpr int num_lp_residual_coeffs = 864;
|
||||
auto ptr = std::make_unique<BinaryFileReader<float>>(
|
||||
ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"),
|
||||
num_lp_residual_coeffs);
|
||||
return {std::move(ptr),
|
||||
rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs)};
|
||||
ChunksFileReader CreateLpResidualAndPitchInfoReader() {
|
||||
constexpr int kPitchInfoSize = 2; // Pitch period and strength.
|
||||
constexpr int kChunkSize = kBufSize24kHz + kPitchInfoSize;
|
||||
auto reader = std::make_unique<FloatFileReader<float>>(
|
||||
/*filename=*/test::ResourcePath(
|
||||
"audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"));
|
||||
const int num_chunks = rtc::CheckedDivExact(reader->size(), kChunkSize);
|
||||
return {kChunkSize, num_chunks, std::move(reader)};
|
||||
}
|
||||
|
||||
ReaderPairType CreateVadProbsReader() {
|
||||
auto ptr = std::make_unique<BinaryFileReader<float>>(
|
||||
test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob", "dat"));
|
||||
return {std::move(ptr), ptr->data_length()};
|
||||
std::unique_ptr<FileReader> CreateVadProbsReader() {
|
||||
return std::make_unique<FloatFileReader<float>>(
|
||||
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob",
|
||||
"dat"));
|
||||
}
|
||||
|
||||
PitchTestData::PitchTestData() {
|
||||
BinaryFileReader<float> test_data_reader(
|
||||
ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
|
||||
1396);
|
||||
test_data_reader.ReadChunk(test_data_);
|
||||
FloatFileReader<float> reader(
|
||||
/*filename=*/ResourcePath(
|
||||
"audio_processing/agc2/rnn_vad/pitch_search_int", "dat"));
|
||||
reader.ReadChunk(pitch_buffer_24k_);
|
||||
reader.ReadChunk(square_energies_24k_);
|
||||
reader.ReadChunk(auto_correlation_12k_);
|
||||
// Reverse the order of the squared energy values.
|
||||
// Required after the WebRTC CL 191703 which switched to forward computation.
|
||||
std::reverse(test_data_.begin() + kBufSize24kHz,
|
||||
test_data_.begin() + kBufSize24kHz + kNumPitchBufSquareEnergies);
|
||||
std::reverse(square_energies_24k_.begin(), square_energies_24k_.end());
|
||||
}
|
||||
|
||||
PitchTestData::~PitchTestData() = default;
|
||||
|
||||
rtc::ArrayView<const float, kBufSize24kHz> PitchTestData::GetPitchBufView()
|
||||
const {
|
||||
return {test_data_.data(), kBufSize24kHz};
|
||||
}
|
||||
|
||||
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
|
||||
PitchTestData::GetPitchBufSquareEnergiesView() const {
|
||||
return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies};
|
||||
}
|
||||
|
||||
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
|
||||
PitchTestData::GetPitchBufAutoCorrCoeffsView() const {
|
||||
return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies,
|
||||
kNumPitchBufAutoCorrCoeffs};
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
@ -11,15 +11,10 @@
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
@ -28,7 +23,6 @@
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
|
||||
constexpr float kFloatMin = std::numeric_limits<float>::min();
|
||||
|
||||
@ -43,98 +37,48 @@ void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
|
||||
rtc::ArrayView<const float> computed,
|
||||
float tolerance);
|
||||
|
||||
// Reader for binary files consisting of an arbitrary long sequence of elements
|
||||
// having type T. It is possible to read and cast to another type D at once.
|
||||
template <typename T, typename D = T>
|
||||
class BinaryFileReader {
|
||||
// File reader interface.
|
||||
class FileReader {
|
||||
public:
|
||||
BinaryFileReader(const std::string& file_path, int chunk_size = 0)
|
||||
: is_(file_path, std::ios::binary | std::ios::ate),
|
||||
data_length_(is_.tellg() / sizeof(T)),
|
||||
chunk_size_(chunk_size) {
|
||||
RTC_CHECK(is_);
|
||||
SeekBeginning();
|
||||
buf_.resize(chunk_size_);
|
||||
}
|
||||
BinaryFileReader(const BinaryFileReader&) = delete;
|
||||
BinaryFileReader& operator=(const BinaryFileReader&) = delete;
|
||||
~BinaryFileReader() = default;
|
||||
int data_length() const { return data_length_; }
|
||||
bool ReadValue(D* dst) {
|
||||
if (std::is_same<T, D>::value) {
|
||||
is_.read(reinterpret_cast<char*>(dst), sizeof(T));
|
||||
} else {
|
||||
T v;
|
||||
is_.read(reinterpret_cast<char*>(&v), sizeof(T));
|
||||
*dst = static_cast<D>(v);
|
||||
}
|
||||
return is_.gcount() == sizeof(T);
|
||||
}
|
||||
// If |chunk_size| was specified in the ctor, it will check that the size of
|
||||
// |dst| equals |chunk_size|.
|
||||
bool ReadChunk(rtc::ArrayView<D> dst) {
|
||||
RTC_DCHECK((chunk_size_ == 0) || rtc::SafeEq(chunk_size_, dst.size()));
|
||||
const std::streamsize bytes_to_read = dst.size() * sizeof(T);
|
||||
if (std::is_same<T, D>::value) {
|
||||
is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
|
||||
} else {
|
||||
is_.read(reinterpret_cast<char*>(buf_.data()), bytes_to_read);
|
||||
std::transform(buf_.begin(), buf_.end(), dst.begin(),
|
||||
[](const T& v) -> D { return static_cast<D>(v); });
|
||||
}
|
||||
return is_.gcount() == bytes_to_read;
|
||||
}
|
||||
void SeekForward(int items) { is_.seekg(items * sizeof(T), is_.cur); }
|
||||
void SeekBeginning() { is_.seekg(0, is_.beg); }
|
||||
|
||||
private:
|
||||
std::ifstream is_;
|
||||
const int data_length_;
|
||||
const int chunk_size_;
|
||||
std::vector<T> buf_;
|
||||
virtual ~FileReader() = default;
|
||||
// Number of values in the file.
|
||||
virtual int size() const = 0;
|
||||
// Reads `dst.size()` float values into `dst`, advances the internal file
|
||||
// position according to the number of read bytes and returns true if the
|
||||
// values are correctly read. If the number of remaining bytes in the file is
|
||||
// not sufficient to read `dst.size()` float values, `dst` is partially
|
||||
// modified and false is returned.
|
||||
virtual bool ReadChunk(rtc::ArrayView<float> dst) = 0;
|
||||
// Reads a single float value, advances the internal file position according
|
||||
// to the number of read bytes and returns true if the value is correctly
|
||||
// read. If the number of remaining bytes in the file is not sufficient to
|
||||
// read one float, `dst` is not modified and false is returned.
|
||||
virtual bool ReadValue(float& dst) = 0;
|
||||
// Advances the internal file position by `hop` float values.
|
||||
virtual void SeekForward(int hop) = 0;
|
||||
// Resets the internal file position to BOF.
|
||||
virtual void SeekBeginning() = 0;
|
||||
};
|
||||
|
||||
// Writer for binary files.
|
||||
template <typename T>
|
||||
class BinaryFileWriter {
|
||||
public:
|
||||
explicit BinaryFileWriter(const std::string& file_path)
|
||||
: os_(file_path, std::ios::binary) {}
|
||||
BinaryFileWriter(const BinaryFileWriter&) = delete;
|
||||
BinaryFileWriter& operator=(const BinaryFileWriter&) = delete;
|
||||
~BinaryFileWriter() = default;
|
||||
static_assert(std::is_arithmetic<T>::value, "");
|
||||
void WriteChunk(rtc::ArrayView<const T> value) {
|
||||
const std::streamsize bytes_to_write = value.size() * sizeof(T);
|
||||
os_.write(reinterpret_cast<const char*>(value.data()), bytes_to_write);
|
||||
}
|
||||
|
||||
private:
|
||||
std::ofstream os_;
|
||||
// File reader for files that contain `num_chunks` chunks with size equal to
|
||||
// `chunk_size`.
|
||||
struct ChunksFileReader {
|
||||
const int chunk_size;
|
||||
const int num_chunks;
|
||||
std::unique_ptr<FileReader> reader;
|
||||
};
|
||||
|
||||
// Factories for resource file readers.
|
||||
// The functions below return a pair where the first item is a reader unique
|
||||
// pointer and the second the number of chunks that can be read from the file.
|
||||
// Creates a reader for the PCM samples that casts from S16 to float and reads
|
||||
// chunks with length |frame_length|.
|
||||
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const int>
|
||||
CreatePcmSamplesReader(const int frame_length);
|
||||
// Creates a reader for the pitch buffer content at 24 kHz.
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>
|
||||
CreatePitchBuffer24kHzReader();
|
||||
// Creates a reader for the the LP residual coefficients and the pitch period
|
||||
// and gain values.
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>
|
||||
CreateLpResidualAndPitchPeriodGainReader();
|
||||
// Creates a reader for the VAD probabilities.
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>
|
||||
CreateVadProbsReader();
|
||||
// Creates a reader for the PCM S16 samples file.
|
||||
std::unique_ptr<FileReader> CreatePcmSamplesReader();
|
||||
|
||||
constexpr int kNumPitchBufAutoCorrCoeffs = 147;
|
||||
constexpr int kNumPitchBufSquareEnergies = 385;
|
||||
constexpr int kPitchTestDataSize =
|
||||
kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs;
|
||||
// Creates a reader for the 24 kHz pitch buffer test data.
|
||||
ChunksFileReader CreatePitchBuffer24kHzReader();
|
||||
|
||||
// Creates a reader for the LP residual and pitch information test data.
|
||||
ChunksFileReader CreateLpResidualAndPitchInfoReader();
|
||||
|
||||
// Creates a reader for the VAD probabilities test data.
|
||||
std::unique_ptr<FileReader> CreateVadProbsReader();
|
||||
|
||||
// Class to retrieve a test pitch buffer content and the expected output for the
|
||||
// analysis steps.
|
||||
@ -142,17 +86,40 @@ class PitchTestData {
|
||||
public:
|
||||
PitchTestData();
|
||||
~PitchTestData();
|
||||
rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView() const;
|
||||
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
|
||||
GetPitchBufSquareEnergiesView() const;
|
||||
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
|
||||
GetPitchBufAutoCorrCoeffsView() const;
|
||||
rtc::ArrayView<const float, kBufSize24kHz> PitchBuffer24kHzView() const {
|
||||
return pitch_buffer_24k_;
|
||||
}
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> SquareEnergies24kHzView()
|
||||
const {
|
||||
return square_energies_24k_;
|
||||
}
|
||||
rtc::ArrayView<const float, kNumLags12kHz> AutoCorrelation12kHzView() const {
|
||||
return auto_correlation_12k_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::array<float, kPitchTestDataSize> test_data_;
|
||||
std::array<float, kBufSize24kHz> pitch_buffer_24k_;
|
||||
std::array<float, kRefineNumLags24kHz> square_energies_24k_;
|
||||
std::array<float, kNumLags12kHz> auto_correlation_12k_;
|
||||
};
|
||||
|
||||
// Writer for binary files.
|
||||
class FileWriter {
|
||||
public:
|
||||
explicit FileWriter(const std::string& file_path)
|
||||
: os_(file_path, std::ios::binary) {}
|
||||
FileWriter(const FileWriter&) = delete;
|
||||
FileWriter& operator=(const FileWriter&) = delete;
|
||||
~FileWriter() = default;
|
||||
void WriteChunk(rtc::ArrayView<const float> value) {
|
||||
const std::streamsize bytes_to_write = value.size() * sizeof(float);
|
||||
os_.write(reinterpret_cast<const char*>(value.data()), bytes_to_write);
|
||||
}
|
||||
|
||||
private:
|
||||
std::ofstream os_;
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
|
Reference in New Issue
Block a user