AGC2 RNN VAD size_t -> int
Motivation: read "On Unsigned Integers" section in https://google.github.io/styleguide/cppguide.html#Integer_Types Plus, improved readability by getting rid of a bunch of `static_cast<int>`s. Bug: webrtc:10480 Change-Id: I911aa8cd08f5ccde4ee6f23534240d1faa84cdea Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/190880 Commit-Queue: Alessio Bazzica <alessiob@webrtc.org> Reviewed-by: Karl Wiberg <kwiberg@webrtc.org> Cr-Commit-Position: refs/heads/master@{#32524}
This commit is contained in:
committed by
Commit Bot
parent
bee5983b8f
commit
f622ba725e
@ -33,6 +33,8 @@ rtc_library("rnn_vad") {
|
||||
"../../../../api:function_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:logging",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../../../rtc_base:safe_conversions",
|
||||
"../../../../rtc_base/system:arch",
|
||||
"//third_party/rnnoise:rnn_vad",
|
||||
]
|
||||
@ -93,6 +95,7 @@ rtc_library("rnn_vad_pitch") {
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../../../rtc_base:safe_conversions",
|
||||
]
|
||||
}
|
||||
|
||||
@ -125,6 +128,7 @@ rtc_library("rnn_vad_spectral_features") {
|
||||
":rnn_vad_symmetric_matrix_buffer",
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../utility:pffft_wrapper",
|
||||
]
|
||||
}
|
||||
@ -134,6 +138,7 @@ rtc_source_set("rnn_vad_symmetric_matrix_buffer") {
|
||||
deps = [
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
]
|
||||
}
|
||||
|
||||
@ -150,6 +155,7 @@ if (rtc_include_tests) {
|
||||
"../../../../api:array_view",
|
||||
"../../../../api:scoped_refptr",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../../../rtc_base/system:arch",
|
||||
"../../../../system_wrappers",
|
||||
"../../../../test:fileutils",
|
||||
@ -206,6 +212,8 @@ if (rtc_include_tests) {
|
||||
"../../../../common_audio/",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base:logging",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../../../rtc_base:safe_conversions",
|
||||
"../../../../rtc_base/system:arch",
|
||||
"../../../../test:test_support",
|
||||
"../../utility:pffft_wrapper",
|
||||
@ -227,6 +235,7 @@ if (rtc_include_tests) {
|
||||
"../../../../api:array_view",
|
||||
"../../../../common_audio",
|
||||
"../../../../rtc_base:rtc_base_approved",
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../../../test:test_support",
|
||||
"//third_party/abseil-cpp/absl/flags:flag",
|
||||
"//third_party/abseil-cpp/absl/flags:parse",
|
||||
|
||||
@ -48,8 +48,8 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer(
|
||||
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
|
||||
RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
|
||||
RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
|
||||
constexpr size_t kFftFrameSize = 1 << kAutoCorrelationFftOrder;
|
||||
constexpr size_t kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
|
||||
constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder;
|
||||
constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
|
||||
static_assert(kConvolutionLength == kFrameSize20ms12kHz,
|
||||
"Mismatch between pitch buffer size, frame size and maximum "
|
||||
"pitch period.");
|
||||
|
||||
@ -54,7 +54,7 @@ TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) {
|
||||
}
|
||||
// The expected output is a vector filled with the same expected
|
||||
// auto-correlation value. The latter equals the length of a 20 ms frame.
|
||||
constexpr size_t kFrameSize20ms12kHz = kFrameSize20ms24kHz / 2;
|
||||
constexpr int kFrameSize20ms12kHz = kFrameSize20ms24kHz / 2;
|
||||
std::array<float, kNumPitchBufAutoCorrCoeffs> expected_output;
|
||||
std::fill(expected_output.begin(), expected_output.end(),
|
||||
static_cast<float>(kFrameSize20ms12kHz));
|
||||
|
||||
@ -18,52 +18,52 @@ namespace rnn_vad {
|
||||
|
||||
constexpr double kPi = 3.14159265358979323846;
|
||||
|
||||
constexpr size_t kSampleRate24kHz = 24000;
|
||||
constexpr size_t kFrameSize10ms24kHz = kSampleRate24kHz / 100;
|
||||
constexpr size_t kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2;
|
||||
constexpr int kSampleRate24kHz = 24000;
|
||||
constexpr int kFrameSize10ms24kHz = kSampleRate24kHz / 100;
|
||||
constexpr int kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2;
|
||||
|
||||
// Pitch buffer.
|
||||
constexpr size_t kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s.
|
||||
constexpr size_t kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s.
|
||||
constexpr size_t kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz;
|
||||
constexpr int kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s.
|
||||
constexpr int kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s.
|
||||
constexpr int kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz;
|
||||
static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even.");
|
||||
|
||||
// 24 kHz analysis.
|
||||
// Define a higher minimum pitch period for the initial search. This is used to
|
||||
// avoid searching for very short periods, for which a refinement step is
|
||||
// responsible.
|
||||
constexpr size_t kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
|
||||
constexpr int kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
|
||||
static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, "");
|
||||
static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, "");
|
||||
static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
|
||||
constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
|
||||
constexpr int kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
|
||||
|
||||
// 12 kHz analysis.
|
||||
constexpr size_t kSampleRate12kHz = 12000;
|
||||
constexpr size_t kFrameSize10ms12kHz = kSampleRate12kHz / 100;
|
||||
constexpr size_t kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2;
|
||||
constexpr size_t kBufSize12kHz = kBufSize24kHz / 2;
|
||||
constexpr size_t kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
|
||||
constexpr size_t kMaxPitch12kHz = kMaxPitch24kHz / 2;
|
||||
constexpr int kSampleRate12kHz = 12000;
|
||||
constexpr int kFrameSize10ms12kHz = kSampleRate12kHz / 100;
|
||||
constexpr int kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2;
|
||||
constexpr int kBufSize12kHz = kBufSize24kHz / 2;
|
||||
constexpr int kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
|
||||
constexpr int kMaxPitch12kHz = kMaxPitch24kHz / 2;
|
||||
static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
|
||||
// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|,
|
||||
// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|].
|
||||
constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
|
||||
constexpr int kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
|
||||
|
||||
// 48 kHz constants.
|
||||
constexpr size_t kMinPitch48kHz = kMinPitch24kHz * 2;
|
||||
constexpr size_t kMaxPitch48kHz = kMaxPitch24kHz * 2;
|
||||
constexpr int kMinPitch48kHz = kMinPitch24kHz * 2;
|
||||
constexpr int kMaxPitch48kHz = kMaxPitch24kHz * 2;
|
||||
|
||||
// Spectral features.
|
||||
constexpr size_t kNumBands = 22;
|
||||
constexpr size_t kNumLowerBands = 6;
|
||||
constexpr int kNumBands = 22;
|
||||
constexpr int kNumLowerBands = 6;
|
||||
static_assert((0 < kNumLowerBands) && (kNumLowerBands < kNumBands), "");
|
||||
constexpr size_t kCepstralCoeffsHistorySize = 8;
|
||||
constexpr int kCepstralCoeffsHistorySize = 8;
|
||||
static_assert(kCepstralCoeffsHistorySize > 2,
|
||||
"The history size must at least be 3 to compute first and second "
|
||||
"derivatives.");
|
||||
|
||||
constexpr size_t kFeatureVectorSize = 42;
|
||||
constexpr int kFeatureVectorSize = 42;
|
||||
|
||||
enum class Optimization { kNone, kSse2, kNeon };
|
||||
|
||||
|
||||
@ -69,7 +69,7 @@ bool FeaturesExtractor::CheckSilenceComputeFeatures(
|
||||
// into the output vector (normalization based on training data stats).
|
||||
pitch_info_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
|
||||
feature_vector[kFeatureVectorSize - 2] =
|
||||
0.01f * (static_cast<int>(pitch_info_48kHz_.period) - 300);
|
||||
0.01f * (pitch_info_48kHz_.period - 300);
|
||||
// Extract lagged frames (according to the estimated pitch period).
|
||||
RTC_DCHECK_LE(pitch_info_48kHz_.period / 2, kMaxPitch24kHz);
|
||||
auto lagged_frame = pitch_buf_24kHz_view_.subview(
|
||||
|
||||
@ -14,6 +14,8 @@
|
||||
#include <vector>
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// #include "test/fpe_observer.h"
|
||||
#include "test/gtest.h"
|
||||
@ -23,26 +25,25 @@ namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
constexpr size_t ceil(size_t n, size_t m) {
|
||||
constexpr int ceil(int n, int m) {
|
||||
return (n + m - 1) / m;
|
||||
}
|
||||
|
||||
// Number of 10 ms frames required to fill a pitch buffer having size
|
||||
// |kBufSize24kHz|.
|
||||
constexpr size_t kNumTestDataFrames = ceil(kBufSize24kHz, kFrameSize10ms24kHz);
|
||||
constexpr int kNumTestDataFrames = ceil(kBufSize24kHz, kFrameSize10ms24kHz);
|
||||
// Number of samples for the test data.
|
||||
constexpr size_t kNumTestDataSize = kNumTestDataFrames * kFrameSize10ms24kHz;
|
||||
constexpr int kNumTestDataSize = kNumTestDataFrames * kFrameSize10ms24kHz;
|
||||
|
||||
// Verifies that the pitch in Hz is in the detectable range.
|
||||
bool PitchIsValid(float pitch_hz) {
|
||||
const size_t pitch_period =
|
||||
static_cast<size_t>(static_cast<float>(kSampleRate24kHz) / pitch_hz);
|
||||
const int pitch_period = static_cast<float>(kSampleRate24kHz) / pitch_hz;
|
||||
return kInitialMinPitch24kHz <= pitch_period &&
|
||||
pitch_period <= kMaxPitch24kHz;
|
||||
}
|
||||
|
||||
void CreatePureTone(float amplitude, float freq_hz, rtc::ArrayView<float> dst) {
|
||||
for (size_t i = 0; i < dst.size(); ++i) {
|
||||
for (int i = 0; rtc::SafeLt(i, dst.size()); ++i) {
|
||||
dst[i] = amplitude * std::sin(2.f * kPi * freq_hz * i / kSampleRate24kHz);
|
||||
}
|
||||
}
|
||||
@ -56,8 +57,8 @@ bool FeedTestData(FeaturesExtractor* features_extractor,
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
bool is_silence = true;
|
||||
const size_t num_frames = samples.size() / kFrameSize10ms24kHz;
|
||||
for (size_t i = 0; i < num_frames; ++i) {
|
||||
const int num_frames = samples.size() / kFrameSize10ms24kHz;
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
is_silence = features_extractor->CheckSilenceComputeFeatures(
|
||||
{samples.data() + i * kFrameSize10ms24kHz, kFrameSize10ms24kHz},
|
||||
feature_vector);
|
||||
@ -79,13 +80,13 @@ TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
|
||||
FeaturesExtractor features_extractor;
|
||||
std::vector<float> samples(kNumTestDataSize);
|
||||
std::vector<float> feature_vector(kFeatureVectorSize);
|
||||
ASSERT_EQ(kFeatureVectorSize, feature_vector.size());
|
||||
ASSERT_EQ(kFeatureVectorSize, rtc::dchecked_cast<int>(feature_vector.size()));
|
||||
rtc::ArrayView<float, kFeatureVectorSize> feature_vector_view(
|
||||
feature_vector.data(), kFeatureVectorSize);
|
||||
|
||||
// Extract the normalized scalar feature that is proportional to the estimated
|
||||
// pitch period.
|
||||
constexpr size_t pitch_feature_index = kFeatureVectorSize - 2;
|
||||
constexpr int pitch_feature_index = kFeatureVectorSize - 2;
|
||||
// Low frequency tone - i.e., high period.
|
||||
CreatePureTone(amplitude, low_pitch_hz, samples);
|
||||
ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view));
|
||||
|
||||
@ -28,9 +28,9 @@ namespace {
|
||||
void ComputeAutoCorrelation(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float, kNumLpcCoefficients> auto_corr) {
|
||||
constexpr size_t max_lag = auto_corr.size();
|
||||
constexpr int max_lag = auto_corr.size();
|
||||
RTC_DCHECK_LT(max_lag, x.size());
|
||||
for (size_t lag = 0; lag < max_lag; ++lag) {
|
||||
for (int lag = 0; lag < max_lag; ++lag) {
|
||||
auto_corr[lag] =
|
||||
std::inner_product(x.begin(), x.end() - lag, x.begin() + lag, 0.f);
|
||||
}
|
||||
@ -56,9 +56,9 @@ void ComputeInitialInverseFilterCoefficients(
|
||||
rtc::ArrayView<const float, kNumLpcCoefficients> auto_corr,
|
||||
rtc::ArrayView<float, kNumLpcCoefficients - 1> lpc_coeffs) {
|
||||
float error = auto_corr[0];
|
||||
for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) {
|
||||
for (int i = 0; i < kNumLpcCoefficients - 1; ++i) {
|
||||
float reflection_coeff = 0.f;
|
||||
for (size_t j = 0; j < i; ++j) {
|
||||
for (int j = 0; j < i; ++j) {
|
||||
reflection_coeff += lpc_coeffs[j] * auto_corr[i - j];
|
||||
}
|
||||
reflection_coeff += auto_corr[i + 1];
|
||||
@ -72,7 +72,7 @@ void ComputeInitialInverseFilterCoefficients(
|
||||
reflection_coeff /= -error;
|
||||
// Update LPC coefficients and total error.
|
||||
lpc_coeffs[i] = reflection_coeff;
|
||||
for (size_t j = 0; j<(i + 1)>> 1; ++j) {
|
||||
for (int j = 0; j < ((i + 1) >> 1); ++j) {
|
||||
const float tmp1 = lpc_coeffs[j];
|
||||
const float tmp2 = lpc_coeffs[i - 1 - j];
|
||||
lpc_coeffs[j] = tmp1 + reflection_coeff * tmp2;
|
||||
|
||||
@ -53,14 +53,14 @@ TEST(RnnVadTest, LpResidualPipelineBitExactness) {
|
||||
std::vector<float> expected_lp_residual(kBufSize24kHz);
|
||||
|
||||
// Test length.
|
||||
const size_t num_frames = std::min(pitch_buf_24kHz_reader.second,
|
||||
static_cast<size_t>(300)); // Max 3 s.
|
||||
const int num_frames =
|
||||
std::min(pitch_buf_24kHz_reader.second, 300); // Max 3 s.
|
||||
ASSERT_GE(lp_residual_reader.second, num_frames);
|
||||
|
||||
{
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
for (size_t i = 0; i < num_frames; ++i) {
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
// Read input.
|
||||
ASSERT_TRUE(pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data));
|
||||
// Read expected output (ignore pitch gain and period).
|
||||
|
||||
@ -35,9 +35,8 @@ PitchInfo PitchEstimator::Estimate(
|
||||
Decimate2x(pitch_buf, pitch_buf_decimated_view_);
|
||||
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_,
|
||||
auto_corr_view_);
|
||||
CandidatePitchPeriods pitch_candidates_inverted_lags =
|
||||
FindBestPitchPeriods(auto_corr_view_, pitch_buf_decimated_view_,
|
||||
static_cast<int>(kMaxPitch12kHz));
|
||||
CandidatePitchPeriods pitch_candidates_inverted_lags = FindBestPitchPeriods(
|
||||
auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz);
|
||||
// Refine the pitch period estimation.
|
||||
// The refinement is done using the pitch buffer that contains 24 kHz samples.
|
||||
// Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12
|
||||
@ -47,10 +46,9 @@ PitchInfo PitchEstimator::Estimate(
|
||||
const int pitch_inv_lag_48kHz =
|
||||
RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inverted_lags);
|
||||
// Look for stronger harmonics to find the final pitch period and its gain.
|
||||
RTC_DCHECK_LT(pitch_inv_lag_48kHz, static_cast<int>(kMaxPitch48kHz));
|
||||
RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz);
|
||||
last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain(
|
||||
pitch_buf, static_cast<int>(kMaxPitch48kHz) - pitch_inv_lag_48kHz,
|
||||
last_pitch_48kHz_);
|
||||
pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_);
|
||||
return last_pitch_48kHz_;
|
||||
}
|
||||
|
||||
|
||||
@ -20,29 +20,27 @@
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
constexpr int kMaxPitch24kHzInt = static_cast<int>(kMaxPitch24kHz);
|
||||
|
||||
// Converts a lag to an inverted lag (only for 24kHz).
|
||||
int GetInvertedLag(int lag) {
|
||||
RTC_DCHECK_LE(lag, kMaxPitch24kHzInt);
|
||||
return kMaxPitch24kHzInt - lag;
|
||||
RTC_DCHECK_LE(lag, kMaxPitch24kHz);
|
||||
return kMaxPitch24kHz - lag;
|
||||
}
|
||||
|
||||
float ComputeAutoCorrelationCoeff(rtc::ArrayView<const float> pitch_buf,
|
||||
int inv_lag,
|
||||
int max_pitch_period) {
|
||||
RTC_DCHECK_LT(inv_lag, static_cast<int>(pitch_buf.size()));
|
||||
RTC_DCHECK_LT(max_pitch_period, static_cast<int>(pitch_buf.size()));
|
||||
RTC_DCHECK_LE(inv_lag, static_cast<int>(max_pitch_period));
|
||||
RTC_DCHECK_LT(inv_lag, pitch_buf.size());
|
||||
RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
|
||||
RTC_DCHECK_LE(inv_lag, max_pitch_period);
|
||||
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
|
||||
return std::inner_product(
|
||||
pitch_buf.begin() + static_cast<size_t>(max_pitch_period),
|
||||
pitch_buf.end(), pitch_buf.begin() + static_cast<size_t>(inv_lag), 0.f);
|
||||
return std::inner_product(pitch_buf.begin() + max_pitch_period,
|
||||
pitch_buf.end(), pitch_buf.begin() + inv_lag, 0.f);
|
||||
}
|
||||
|
||||
// Given the auto-correlation coefficients for a lag and its neighbors, computes
|
||||
@ -76,14 +74,14 @@ int PitchPseudoInterpolationLagPitchBuf(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
|
||||
int offset = 0;
|
||||
// Cannot apply pseudo-interpolation at the boundaries.
|
||||
if (lag > 0 && lag < kMaxPitch24kHzInt) {
|
||||
if (lag > 0 && lag < kMaxPitch24kHz) {
|
||||
offset = GetPitchPseudoInterpolationOffset(
|
||||
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag - 1),
|
||||
kMaxPitch24kHzInt),
|
||||
kMaxPitch24kHz),
|
||||
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag),
|
||||
kMaxPitch24kHzInt),
|
||||
kMaxPitch24kHz),
|
||||
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag + 1),
|
||||
kMaxPitch24kHzInt));
|
||||
kMaxPitch24kHz));
|
||||
}
|
||||
return 2 * lag + offset;
|
||||
}
|
||||
@ -96,7 +94,7 @@ int PitchPseudoInterpolationInvLagAutoCorr(
|
||||
rtc::ArrayView<const float> auto_corr) {
|
||||
int offset = 0;
|
||||
// Cannot apply pseudo-interpolation at the boundaries.
|
||||
if (inv_lag > 0 && inv_lag < static_cast<int>(auto_corr.size()) - 1) {
|
||||
if (inv_lag > 0 && inv_lag < rtc::dchecked_cast<int>(auto_corr.size()) - 1) {
|
||||
offset = GetPitchPseudoInterpolationOffset(
|
||||
auto_corr[inv_lag + 1], auto_corr[inv_lag], auto_corr[inv_lag - 1]);
|
||||
}
|
||||
@ -143,7 +141,7 @@ void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
|
||||
rtc::ArrayView<float, kBufSize12kHz> dst) {
|
||||
// TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter.
|
||||
static_assert(2 * dst.size() == src.size(), "");
|
||||
for (size_t i = 0; i < dst.size(); ++i) {
|
||||
for (int i = 0; rtc::SafeLt(i, dst.size()); ++i) {
|
||||
dst[i] = src[2 * i];
|
||||
}
|
||||
}
|
||||
@ -186,10 +184,10 @@ float ComputePitchGainThreshold(int candidate_pitch_period,
|
||||
// reduce the chance of false positives caused by a bias towards high
|
||||
// frequencies (originating from short-term correlations).
|
||||
float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term);
|
||||
if (static_cast<size_t>(t1) < 3 * kMinPitch24kHz) {
|
||||
if (t1 < 3 * kMinPitch24kHz) {
|
||||
// High frequency.
|
||||
threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term);
|
||||
} else if (static_cast<size_t>(t1) < 2 * kMinPitch24kHz) {
|
||||
} else if (t1 < 2 * kMinPitch24kHz) {
|
||||
// Even higher frequency.
|
||||
threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term);
|
||||
}
|
||||
@ -199,10 +197,10 @@ float ComputePitchGainThreshold(int candidate_pitch_period,
|
||||
void ComputeSlidingFrameSquareEnergies(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||
rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values) {
|
||||
float yy = ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHzInt,
|
||||
kMaxPitch24kHzInt);
|
||||
float yy =
|
||||
ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHz, kMaxPitch24kHz);
|
||||
yy_values[0] = yy;
|
||||
for (size_t i = 1; i < yy_values.size(); ++i) {
|
||||
for (int i = 1; rtc::SafeLt(i, yy_values.size()); ++i) {
|
||||
RTC_DCHECK_LE(i, kMaxPitch24kHz + kFrameSize20ms24kHz);
|
||||
RTC_DCHECK_LE(i, kMaxPitch24kHz);
|
||||
const float old_coeff = pitch_buf[kMaxPitch24kHz + kFrameSize20ms24kHz - i];
|
||||
@ -233,9 +231,10 @@ CandidatePitchPeriods FindBestPitchPeriods(
|
||||
}
|
||||
};
|
||||
|
||||
RTC_DCHECK_GT(max_pitch_period, static_cast<int>(auto_corr.size()));
|
||||
RTC_DCHECK_LT(max_pitch_period, static_cast<int>(pitch_buf.size()));
|
||||
const int frame_size = static_cast<int>(pitch_buf.size()) - max_pitch_period;
|
||||
RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
|
||||
RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
|
||||
const int frame_size =
|
||||
rtc::dchecked_cast<int>(pitch_buf.size()) - max_pitch_period;
|
||||
RTC_DCHECK_GT(frame_size, 0);
|
||||
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
|
||||
float yy =
|
||||
@ -247,7 +246,7 @@ CandidatePitchPeriods FindBestPitchPeriods(
|
||||
PitchCandidate best;
|
||||
PitchCandidate second_best;
|
||||
second_best.period_inverted_lag = 1;
|
||||
for (int inv_lag = 0; inv_lag < static_cast<int>(auto_corr.size());
|
||||
for (int inv_lag = 0; inv_lag < rtc::dchecked_cast<int>(auto_corr.size());
|
||||
++inv_lag) {
|
||||
// A pitch candidate must have positive correlation.
|
||||
if (auto_corr[inv_lag] > 0) {
|
||||
@ -290,12 +289,12 @@ int RefinePitchPeriod48kHz(
|
||||
++inverted_lag) {
|
||||
if (is_neighbor(inverted_lag, pitch_candidates_inverted_lags.best) ||
|
||||
is_neighbor(inverted_lag, pitch_candidates_inverted_lags.second_best))
|
||||
auto_correlation[inverted_lag] = ComputeAutoCorrelationCoeff(
|
||||
pitch_buf, inverted_lag, kMaxPitch24kHzInt);
|
||||
auto_correlation[inverted_lag] =
|
||||
ComputeAutoCorrelationCoeff(pitch_buf, inverted_lag, kMaxPitch24kHz);
|
||||
}
|
||||
// Find best pitch at 24 kHz.
|
||||
const CandidatePitchPeriods pitch_candidates_24kHz =
|
||||
FindBestPitchPeriods(auto_correlation, pitch_buf, kMaxPitch24kHzInt);
|
||||
FindBestPitchPeriods(auto_correlation, pitch_buf, kMaxPitch24kHz);
|
||||
// Pseudo-interpolation.
|
||||
return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidates_24kHz.best,
|
||||
auto_correlation);
|
||||
@ -334,9 +333,9 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
|
||||
// Initial pitch candidate gain.
|
||||
RefinedPitchCandidate best_pitch;
|
||||
best_pitch.period_24kHz =
|
||||
std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHzInt - 1);
|
||||
std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1);
|
||||
best_pitch.xy = ComputeAutoCorrelationCoeff(
|
||||
pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHzInt);
|
||||
pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz);
|
||||
best_pitch.yy = yy_values[best_pitch.period_24kHz];
|
||||
best_pitch.gain = pitch_gain(best_pitch.xy, best_pitch.yy, xx);
|
||||
|
||||
@ -351,11 +350,10 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
|
||||
};
|
||||
// |max_k| such that alternative_period(initial_pitch_period, max_k, 1) equals
|
||||
// kMinPitch24kHz.
|
||||
const int max_k =
|
||||
(2 * initial_pitch_period) / (2 * static_cast<int>(kMinPitch24kHz) - 1);
|
||||
const int max_k = (2 * initial_pitch_period) / (2 * kMinPitch24kHz - 1);
|
||||
for (int k = 2; k <= max_k; ++k) {
|
||||
int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1);
|
||||
RTC_DCHECK_GE(candidate_pitch_period, static_cast<int>(kMinPitch24kHz));
|
||||
RTC_DCHECK_GE(candidate_pitch_period, kMinPitch24kHz);
|
||||
// When looking at |candidate_pitch_period|, we also look at one of its
|
||||
// sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look.
|
||||
// |k| == 2 is a special case since |candidate_pitch_secondary_period| might
|
||||
@ -363,7 +361,7 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
|
||||
int candidate_pitch_secondary_period = alternative_period(
|
||||
initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]);
|
||||
RTC_DCHECK_GT(candidate_pitch_secondary_period, 0);
|
||||
if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHzInt) {
|
||||
if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHz) {
|
||||
candidate_pitch_secondary_period = initial_pitch_period;
|
||||
}
|
||||
RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period)
|
||||
@ -373,10 +371,10 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
|
||||
// |candidate_pitch_period| by also looking at its possible sub-harmonic
|
||||
// |candidate_pitch_secondary_period|.
|
||||
float xy_primary_period = ComputeAutoCorrelationCoeff(
|
||||
pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHzInt);
|
||||
pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHz);
|
||||
float xy_secondary_period = ComputeAutoCorrelationCoeff(
|
||||
pitch_buf, GetInvertedLag(candidate_pitch_secondary_period),
|
||||
kMaxPitch24kHzInt);
|
||||
kMaxPitch24kHz);
|
||||
float xy = 0.5f * (xy_primary_period + xy_secondary_period);
|
||||
float yy = 0.5f * (yy_values[candidate_pitch_period] +
|
||||
yy_values[candidate_pitch_secondary_period]);
|
||||
@ -399,7 +397,7 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
|
||||
: best_pitch.xy / (best_pitch.yy + 1.f);
|
||||
final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain);
|
||||
int final_pitch_period_48kHz = std::max(
|
||||
static_cast<int>(kMinPitch48kHz),
|
||||
kMinPitch48kHz,
|
||||
PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf));
|
||||
|
||||
return {final_pitch_period_48kHz, final_pitch_gain};
|
||||
|
||||
@ -34,11 +34,11 @@ constexpr float kTestPitchGainsHigh = 0.75f;
|
||||
class ComputePitchGainThresholdTest
|
||||
: public ::testing::Test,
|
||||
public ::testing::WithParamInterface<std::tuple<
|
||||
/*candidate_pitch_period=*/size_t,
|
||||
/*pitch_period_ratio=*/size_t,
|
||||
/*initial_pitch_period=*/size_t,
|
||||
/*candidate_pitch_period=*/int,
|
||||
/*pitch_period_ratio=*/int,
|
||||
/*initial_pitch_period=*/int,
|
||||
/*initial_pitch_gain=*/float,
|
||||
/*prev_pitch_period=*/size_t,
|
||||
/*prev_pitch_period=*/int,
|
||||
/*prev_pitch_gain=*/float,
|
||||
/*threshold=*/float>> {};
|
||||
|
||||
@ -46,11 +46,11 @@ class ComputePitchGainThresholdTest
|
||||
// data.
|
||||
TEST_P(ComputePitchGainThresholdTest, WithinTolerance) {
|
||||
const auto params = GetParam();
|
||||
const size_t candidate_pitch_period = std::get<0>(params);
|
||||
const size_t pitch_period_ratio = std::get<1>(params);
|
||||
const size_t initial_pitch_period = std::get<2>(params);
|
||||
const int candidate_pitch_period = std::get<0>(params);
|
||||
const int pitch_period_ratio = std::get<1>(params);
|
||||
const int initial_pitch_period = std::get<2>(params);
|
||||
const float initial_pitch_gain = std::get<3>(params);
|
||||
const size_t prev_pitch_period = std::get<4>(params);
|
||||
const int prev_pitch_period = std::get<4>(params);
|
||||
const float prev_pitch_gain = std::get<5>(params);
|
||||
const float threshold = std::get<6>(params);
|
||||
{
|
||||
|
||||
@ -28,22 +28,21 @@ namespace test {
|
||||
// pitch gain is within tolerance given test input data.
|
||||
TEST(RnnVadTest, PitchSearchWithinTolerance) {
|
||||
auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader();
|
||||
const size_t num_frames = std::min(lp_residual_reader.second,
|
||||
static_cast<size_t>(300)); // Max 3 s.
|
||||
const int num_frames = std::min(lp_residual_reader.second, 300); // Max 3 s.
|
||||
std::vector<float> lp_residual(kBufSize24kHz);
|
||||
float expected_pitch_period, expected_pitch_gain;
|
||||
PitchEstimator pitch_estimator;
|
||||
{
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
for (size_t i = 0; i < num_frames; ++i) {
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
lp_residual_reader.first->ReadChunk(lp_residual);
|
||||
lp_residual_reader.first->ReadValue(&expected_pitch_period);
|
||||
lp_residual_reader.first->ReadValue(&expected_pitch_gain);
|
||||
PitchInfo pitch_info =
|
||||
pitch_estimator.Estimate({lp_residual.data(), kBufSize24kHz});
|
||||
EXPECT_EQ(static_cast<int>(expected_pitch_period), pitch_info.period);
|
||||
EXPECT_EQ(expected_pitch_period, pitch_info.period);
|
||||
EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f);
|
||||
}
|
||||
}
|
||||
|
||||
@ -21,7 +21,7 @@ namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Ring buffer for N arrays of type T each one with size S.
|
||||
template <typename T, size_t S, size_t N>
|
||||
template <typename T, int S, int N>
|
||||
class RingBuffer {
|
||||
static_assert(S > 0, "");
|
||||
static_assert(N > 0, "");
|
||||
@ -45,11 +45,10 @@ class RingBuffer {
|
||||
// Return an array view onto the array with a given delay. A view on the last
|
||||
// and least recently push array is returned when |delay| is 0 and N - 1
|
||||
// respectively.
|
||||
rtc::ArrayView<const T, S> GetArrayView(size_t delay) const {
|
||||
const int delay_int = static_cast<int>(delay);
|
||||
RTC_DCHECK_LE(0, delay_int);
|
||||
RTC_DCHECK_LT(delay_int, N);
|
||||
int offset = tail_ - 1 - delay_int;
|
||||
rtc::ArrayView<const T, S> GetArrayView(int delay) const {
|
||||
RTC_DCHECK_LE(0, delay);
|
||||
RTC_DCHECK_LT(delay, N);
|
||||
int offset = tail_ - 1 - delay;
|
||||
if (offset < 0)
|
||||
offset += N;
|
||||
return {buffer_.data() + S * offset, S};
|
||||
|
||||
@ -20,14 +20,14 @@ namespace {
|
||||
// Compare the elements of two given array views.
|
||||
template <typename T, std::ptrdiff_t S>
|
||||
void ExpectEq(rtc::ArrayView<const T, S> a, rtc::ArrayView<const T, S> b) {
|
||||
for (size_t i = 0; i < S; ++i) {
|
||||
for (int i = 0; i < S; ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
EXPECT_EQ(a[i], b[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Test push/read sequences.
|
||||
template <typename T, size_t S, size_t N>
|
||||
template <typename T, int S, int N>
|
||||
void TestRingBuffer() {
|
||||
SCOPED_TRACE(N);
|
||||
SCOPED_TRACE(S);
|
||||
@ -56,7 +56,7 @@ void TestRingBuffer() {
|
||||
}
|
||||
|
||||
// Check buffer.
|
||||
for (size_t delay = 2; delay < N; ++delay) {
|
||||
for (int delay = 2; delay < N; ++delay) {
|
||||
SCOPED_TRACE(delay);
|
||||
T expected_value = N - static_cast<T>(delay);
|
||||
pushed_array.fill(expected_value);
|
||||
@ -68,18 +68,18 @@ void TestRingBuffer() {
|
||||
|
||||
// Check that for different delays, different views are returned.
|
||||
TEST(RnnVadTest, RingBufferArrayViews) {
|
||||
constexpr size_t s = 3;
|
||||
constexpr size_t n = 4;
|
||||
constexpr int s = 3;
|
||||
constexpr int n = 4;
|
||||
RingBuffer<int, s, n> ring_buf;
|
||||
std::array<int, s> pushed_array;
|
||||
pushed_array.fill(1);
|
||||
for (size_t k = 0; k <= n; ++k) { // Push data n + 1 times.
|
||||
for (int k = 0; k <= n; ++k) { // Push data n + 1 times.
|
||||
SCOPED_TRACE(k);
|
||||
// Check array views.
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
auto view_i = ring_buf.GetArrayView(i);
|
||||
for (size_t j = i + 1; j < n; ++j) {
|
||||
for (int j = i + 1; j < n; ++j) {
|
||||
SCOPED_TRACE(j);
|
||||
auto view_j = ring_buf.GetArrayView(j);
|
||||
EXPECT_NE(view_i, view_j);
|
||||
|
||||
@ -26,6 +26,7 @@
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/logging.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
#include "third_party/rnnoise/src/rnn_activations.h"
|
||||
#include "third_party/rnnoise/src/rnn_vad_weights.h"
|
||||
|
||||
@ -77,15 +78,16 @@ std::vector<float> GetScaledParams(rtc::ArrayView<const int8_t> params) {
|
||||
// Casts and scales |weights| and re-arranges the layout.
|
||||
std::vector<float> GetPreprocessedFcWeights(
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
size_t output_size) {
|
||||
int output_size) {
|
||||
if (output_size == 1) {
|
||||
return GetScaledParams(weights);
|
||||
}
|
||||
// Transpose, scale and cast.
|
||||
const size_t input_size = rtc::CheckedDivExact(weights.size(), output_size);
|
||||
const int input_size = rtc::CheckedDivExact(
|
||||
rtc::dchecked_cast<int>(weights.size()), output_size);
|
||||
std::vector<float> w(weights.size());
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
for (int i = 0; i < input_size; ++i) {
|
||||
w[o * input_size + i] = rnnoise::kWeightsScale *
|
||||
static_cast<float>(weights[i * output_size + o]);
|
||||
}
|
||||
@ -93,7 +95,7 @@ std::vector<float> GetPreprocessedFcWeights(
|
||||
return w;
|
||||
}
|
||||
|
||||
constexpr size_t kNumGruGates = 3; // Update, reset, output.
|
||||
constexpr int kNumGruGates = 3; // Update, reset, output.
|
||||
|
||||
// TODO(bugs.chromium.org/10480): Hard-coded optimized layout and remove this
|
||||
// function to improve setup time.
|
||||
@ -101,17 +103,17 @@ constexpr size_t kNumGruGates = 3; // Update, reset, output.
|
||||
// It works both for weights, recurrent weights and bias.
|
||||
std::vector<float> GetPreprocessedGruTensor(
|
||||
rtc::ArrayView<const int8_t> tensor_src,
|
||||
size_t output_size) {
|
||||
int output_size) {
|
||||
// Transpose, cast and scale.
|
||||
// |n| is the size of the first dimension of the 3-dim tensor |weights|.
|
||||
const size_t n =
|
||||
rtc::CheckedDivExact(tensor_src.size(), output_size * kNumGruGates);
|
||||
const size_t stride_src = kNumGruGates * output_size;
|
||||
const size_t stride_dst = n * output_size;
|
||||
const int n = rtc::CheckedDivExact(rtc::dchecked_cast<int>(tensor_src.size()),
|
||||
output_size * kNumGruGates);
|
||||
const int stride_src = kNumGruGates * output_size;
|
||||
const int stride_dst = n * output_size;
|
||||
std::vector<float> tensor_dst(tensor_src.size());
|
||||
for (size_t g = 0; g < kNumGruGates; ++g) {
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
for (int g = 0; g < kNumGruGates; ++g) {
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
tensor_dst[g * stride_dst + o * n + i] =
|
||||
rnnoise::kWeightsScale *
|
||||
static_cast<float>(
|
||||
@ -122,28 +124,28 @@ std::vector<float> GetPreprocessedGruTensor(
|
||||
return tensor_dst;
|
||||
}
|
||||
|
||||
void ComputeGruUpdateResetGates(size_t input_size,
|
||||
size_t output_size,
|
||||
void ComputeGruUpdateResetGates(int input_size,
|
||||
int output_size,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::ArrayView<const float> recurrent_weights,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> state,
|
||||
rtc::ArrayView<float> gate) {
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
gate[o] = bias[o];
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
for (int i = 0; i < input_size; ++i) {
|
||||
gate[o] += input[i] * weights[o * input_size + i];
|
||||
}
|
||||
for (size_t s = 0; s < output_size; ++s) {
|
||||
for (int s = 0; s < output_size; ++s) {
|
||||
gate[o] += state[s] * recurrent_weights[o * output_size + s];
|
||||
}
|
||||
gate[o] = SigmoidApproximated(gate[o]);
|
||||
}
|
||||
}
|
||||
|
||||
void ComputeGruOutputGate(size_t input_size,
|
||||
size_t output_size,
|
||||
void ComputeGruOutputGate(int input_size,
|
||||
int output_size,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::ArrayView<const float> recurrent_weights,
|
||||
rtc::ArrayView<const float> bias,
|
||||
@ -151,12 +153,12 @@ void ComputeGruOutputGate(size_t input_size,
|
||||
rtc::ArrayView<const float> state,
|
||||
rtc::ArrayView<const float> reset,
|
||||
rtc::ArrayView<float> gate) {
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
gate[o] = bias[o];
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
for (int i = 0; i < input_size; ++i) {
|
||||
gate[o] += input[i] * weights[o * input_size + i];
|
||||
}
|
||||
for (size_t s = 0; s < output_size; ++s) {
|
||||
for (int s = 0; s < output_size; ++s) {
|
||||
gate[o] += state[s] * recurrent_weights[o * output_size + s] * reset[s];
|
||||
}
|
||||
gate[o] = RectifiedLinearUnit(gate[o]);
|
||||
@ -164,8 +166,8 @@ void ComputeGruOutputGate(size_t input_size,
|
||||
}
|
||||
|
||||
// Gated recurrent unit (GRU) layer un-optimized implementation.
|
||||
void ComputeGruLayerOutput(size_t input_size,
|
||||
size_t output_size,
|
||||
void ComputeGruLayerOutput(int input_size,
|
||||
int output_size,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::ArrayView<const float> recurrent_weights,
|
||||
@ -173,8 +175,8 @@ void ComputeGruLayerOutput(size_t input_size,
|
||||
rtc::ArrayView<float> state) {
|
||||
RTC_DCHECK_EQ(input_size, input.size());
|
||||
// Stride and offset used to read parameter arrays.
|
||||
const size_t stride_in = input_size * output_size;
|
||||
const size_t stride_out = output_size * output_size;
|
||||
const int stride_in = input_size * output_size;
|
||||
const int stride_out = output_size * output_size;
|
||||
|
||||
// Update gate.
|
||||
std::array<float, kRecurrentLayersMaxUnits> update;
|
||||
@ -198,7 +200,7 @@ void ComputeGruLayerOutput(size_t input_size,
|
||||
bias.subview(2 * output_size, output_size), input, state, reset, output);
|
||||
|
||||
// Update output through the update gates and update the state.
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
output[o] = update[o] * state[o] + (1.f - update[o]) * output[o];
|
||||
state[o] = output[o];
|
||||
}
|
||||
@ -206,8 +208,8 @@ void ComputeGruLayerOutput(size_t input_size,
|
||||
|
||||
// Fully connected layer un-optimized implementation.
|
||||
void ComputeFullyConnectedLayerOutput(
|
||||
size_t input_size,
|
||||
size_t output_size,
|
||||
int input_size,
|
||||
int output_size,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> weights,
|
||||
@ -216,11 +218,11 @@ void ComputeFullyConnectedLayerOutput(
|
||||
RTC_DCHECK_EQ(input.size(), input_size);
|
||||
RTC_DCHECK_EQ(bias.size(), output_size);
|
||||
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
output[o] = bias[o];
|
||||
// TODO(bugs.chromium.org/9076): Benchmark how different layouts for
|
||||
// |weights_| change the performance across different platforms.
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
for (int i = 0; i < input_size; ++i) {
|
||||
output[o] += input[i] * weights[o * input_size + i];
|
||||
}
|
||||
output[o] = activation_function(output[o]);
|
||||
@ -230,8 +232,8 @@ void ComputeFullyConnectedLayerOutput(
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
// Fully connected layer SSE2 implementation.
|
||||
void ComputeFullyConnectedLayerOutputSse2(
|
||||
size_t input_size,
|
||||
size_t output_size,
|
||||
int input_size,
|
||||
int output_size,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> weights,
|
||||
@ -240,16 +242,16 @@ void ComputeFullyConnectedLayerOutputSse2(
|
||||
RTC_DCHECK_EQ(input.size(), input_size);
|
||||
RTC_DCHECK_EQ(bias.size(), output_size);
|
||||
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
|
||||
const size_t input_size_by_4 = input_size >> 2;
|
||||
const size_t offset = input_size & ~3;
|
||||
const int input_size_by_4 = input_size >> 2;
|
||||
const int offset = input_size & ~3;
|
||||
__m128 sum_wx_128;
|
||||
const float* v = reinterpret_cast<const float*>(&sum_wx_128);
|
||||
for (size_t o = 0; o < output_size; ++o) {
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
// Perform 128 bit vector operations.
|
||||
sum_wx_128 = _mm_set1_ps(0);
|
||||
const float* x_p = input.data();
|
||||
const float* w_p = weights.data() + o * input_size;
|
||||
for (size_t i = 0; i < input_size_by_4; ++i, x_p += 4, w_p += 4) {
|
||||
for (int i = 0; i < input_size_by_4; ++i, x_p += 4, w_p += 4) {
|
||||
sum_wx_128 = _mm_add_ps(sum_wx_128,
|
||||
_mm_mul_ps(_mm_loadu_ps(x_p), _mm_loadu_ps(w_p)));
|
||||
}
|
||||
@ -266,8 +268,8 @@ void ComputeFullyConnectedLayerOutputSse2(
|
||||
} // namespace
|
||||
|
||||
FullyConnectedLayer::FullyConnectedLayer(
|
||||
const size_t input_size,
|
||||
const size_t output_size,
|
||||
const int input_size,
|
||||
const int output_size,
|
||||
const rtc::ArrayView<const int8_t> bias,
|
||||
const rtc::ArrayView<const int8_t> weights,
|
||||
rtc::FunctionView<float(float)> activation_function,
|
||||
@ -316,8 +318,8 @@ void FullyConnectedLayer::ComputeOutput(rtc::ArrayView<const float> input) {
|
||||
}
|
||||
|
||||
GatedRecurrentLayer::GatedRecurrentLayer(
|
||||
const size_t input_size,
|
||||
const size_t output_size,
|
||||
const int input_size,
|
||||
const int output_size,
|
||||
const rtc::ArrayView<const int8_t> bias,
|
||||
const rtc::ArrayView<const int8_t> weights,
|
||||
const rtc::ArrayView<const int8_t> recurrent_weights,
|
||||
|
||||
@ -29,19 +29,19 @@ namespace rnn_vad {
|
||||
// over-allocate space for fully-connected layers output vectors (implemented as
|
||||
// std::array). The value should equal the number of units of the largest
|
||||
// fully-connected layer.
|
||||
constexpr size_t kFullyConnectedLayersMaxUnits = 24;
|
||||
constexpr int kFullyConnectedLayersMaxUnits = 24;
|
||||
|
||||
// Maximum number of units for a recurrent layer. This value is used to
|
||||
// over-allocate space for recurrent layers state vectors (implemented as
|
||||
// std::array). The value should equal the number of units of the largest
|
||||
// recurrent layer.
|
||||
constexpr size_t kRecurrentLayersMaxUnits = 24;
|
||||
constexpr int kRecurrentLayersMaxUnits = 24;
|
||||
|
||||
// Fully-connected layer.
|
||||
class FullyConnectedLayer {
|
||||
public:
|
||||
FullyConnectedLayer(size_t input_size,
|
||||
size_t output_size,
|
||||
FullyConnectedLayer(int input_size,
|
||||
int output_size,
|
||||
rtc::ArrayView<const int8_t> bias,
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
rtc::FunctionView<float(float)> activation_function,
|
||||
@ -49,16 +49,16 @@ class FullyConnectedLayer {
|
||||
FullyConnectedLayer(const FullyConnectedLayer&) = delete;
|
||||
FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
|
||||
~FullyConnectedLayer();
|
||||
size_t input_size() const { return input_size_; }
|
||||
size_t output_size() const { return output_size_; }
|
||||
int input_size() const { return input_size_; }
|
||||
int output_size() const { return output_size_; }
|
||||
Optimization optimization() const { return optimization_; }
|
||||
rtc::ArrayView<const float> GetOutput() const;
|
||||
// Computes the fully-connected layer output.
|
||||
void ComputeOutput(rtc::ArrayView<const float> input);
|
||||
|
||||
private:
|
||||
const size_t input_size_;
|
||||
const size_t output_size_;
|
||||
const int input_size_;
|
||||
const int output_size_;
|
||||
const std::vector<float> bias_;
|
||||
const std::vector<float> weights_;
|
||||
rtc::FunctionView<float(float)> activation_function_;
|
||||
@ -72,8 +72,8 @@ class FullyConnectedLayer {
|
||||
// activation functions for the update/reset and output gates respectively.
|
||||
class GatedRecurrentLayer {
|
||||
public:
|
||||
GatedRecurrentLayer(size_t input_size,
|
||||
size_t output_size,
|
||||
GatedRecurrentLayer(int input_size,
|
||||
int output_size,
|
||||
rtc::ArrayView<const int8_t> bias,
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
rtc::ArrayView<const int8_t> recurrent_weights,
|
||||
@ -81,8 +81,8 @@ class GatedRecurrentLayer {
|
||||
GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
|
||||
GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
|
||||
~GatedRecurrentLayer();
|
||||
size_t input_size() const { return input_size_; }
|
||||
size_t output_size() const { return output_size_; }
|
||||
int input_size() const { return input_size_; }
|
||||
int output_size() const { return output_size_; }
|
||||
Optimization optimization() const { return optimization_; }
|
||||
rtc::ArrayView<const float> GetOutput() const;
|
||||
void Reset();
|
||||
@ -90,8 +90,8 @@ class GatedRecurrentLayer {
|
||||
void ComputeOutput(rtc::ArrayView<const float> input);
|
||||
|
||||
private:
|
||||
const size_t input_size_;
|
||||
const size_t output_size_;
|
||||
const int input_size_;
|
||||
const int output_size_;
|
||||
const std::vector<float> bias_;
|
||||
const std::vector<float> weights_;
|
||||
const std::vector<float> recurrent_weights_;
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include "modules/audio_processing/test/performance_timer.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/logging.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
#include "test/gtest.h"
|
||||
#include "third_party/rnnoise/src/rnn_activations.h"
|
||||
@ -43,15 +44,16 @@ void TestGatedRecurrentLayer(
|
||||
rtc::ArrayView<const float> expected_output_sequence) {
|
||||
RTC_CHECK(gru);
|
||||
auto gru_output_view = gru->GetOutput();
|
||||
const size_t input_sequence_length =
|
||||
rtc::CheckedDivExact(input_sequence.size(), gru->input_size());
|
||||
const size_t output_sequence_length =
|
||||
rtc::CheckedDivExact(expected_output_sequence.size(), gru->output_size());
|
||||
const int input_sequence_length = rtc::CheckedDivExact(
|
||||
rtc::dchecked_cast<int>(input_sequence.size()), gru->input_size());
|
||||
const int output_sequence_length = rtc::CheckedDivExact(
|
||||
rtc::dchecked_cast<int>(expected_output_sequence.size()),
|
||||
gru->output_size());
|
||||
ASSERT_EQ(input_sequence_length, output_sequence_length)
|
||||
<< "The test data length is invalid.";
|
||||
// Feed the GRU layer and check the output at every step.
|
||||
gru->Reset();
|
||||
for (size_t i = 0; i < input_sequence_length; ++i) {
|
||||
for (int i = 0; i < input_sequence_length; ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
gru->ComputeOutput(
|
||||
input_sequence.subview(i * gru->input_size(), gru->input_size()));
|
||||
@ -77,8 +79,8 @@ constexpr std::array<float, 24> kFullyConnectedExpectedOutput = {
|
||||
0.875092f, 0.999846f, 0.997707f, -0.999382f, 0.973153f, -0.966605f};
|
||||
|
||||
// Gated recurrent units layer test data.
|
||||
constexpr size_t kGruInputSize = 5;
|
||||
constexpr size_t kGruOutputSize = 4;
|
||||
constexpr int kGruInputSize = 5;
|
||||
constexpr int kGruOutputSize = 4;
|
||||
constexpr std::array<int8_t, 12> kGruBias = {96, -99, -81, -114, 49, 119,
|
||||
-118, 68, -76, 91, 121, 125};
|
||||
constexpr std::array<int8_t, 60> kGruWeights = {
|
||||
@ -213,10 +215,10 @@ TEST(RnnVadTest, DISABLED_BenchmarkFullyConnectedLayer) {
|
||||
}
|
||||
|
||||
std::vector<Result> results;
|
||||
constexpr size_t number_of_tests = 10000;
|
||||
constexpr int number_of_tests = 10000;
|
||||
for (auto& fc : implementations) {
|
||||
::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
|
||||
for (size_t k = 0; k < number_of_tests; ++k) {
|
||||
for (int k = 0; k < number_of_tests; ++k) {
|
||||
perf_timer.StartTimer();
|
||||
fc->ComputeOutput(kFullyConnectedInputVector);
|
||||
perf_timer.StopTimer();
|
||||
@ -240,17 +242,17 @@ TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) {
|
||||
|
||||
rtc::ArrayView<const float> input_sequence(kGruInputSequence);
|
||||
static_assert(kGruInputSequence.size() % kGruInputSize == 0, "");
|
||||
constexpr size_t input_sequence_length =
|
||||
constexpr int input_sequence_length =
|
||||
kGruInputSequence.size() / kGruInputSize;
|
||||
|
||||
std::vector<Result> results;
|
||||
constexpr size_t number_of_tests = 10000;
|
||||
constexpr int number_of_tests = 10000;
|
||||
for (auto& gru : implementations) {
|
||||
::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
|
||||
gru->Reset();
|
||||
for (size_t k = 0; k < number_of_tests; ++k) {
|
||||
for (int k = 0; k < number_of_tests; ++k) {
|
||||
perf_timer.StartTimer();
|
||||
for (size_t i = 0; i < input_sequence_length; ++i) {
|
||||
for (int i = 0; i < input_sequence_length; ++i) {
|
||||
gru->ComputeOutput(
|
||||
input_sequence.subview(i * gru->input_size(), gru->input_size()));
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
|
||||
#include "rtc_base/logging.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
ABSL_FLAG(std::string, i, "", "Path to the input wav file");
|
||||
ABSL_FLAG(std::string, f, "", "Path to the output features file");
|
||||
@ -56,7 +57,7 @@ int main(int argc, char* argv[]) {
|
||||
}
|
||||
|
||||
// Initialize.
|
||||
const size_t frame_size_10ms =
|
||||
const int frame_size_10ms =
|
||||
rtc::CheckedDivExact(wav_reader.sample_rate(), 100);
|
||||
std::vector<float> samples_10ms;
|
||||
samples_10ms.resize(frame_size_10ms);
|
||||
@ -69,9 +70,9 @@ int main(int argc, char* argv[]) {
|
||||
// Compute VAD probabilities.
|
||||
while (true) {
|
||||
// Read frame at the input sample rate.
|
||||
const auto read_samples =
|
||||
const size_t read_samples =
|
||||
wav_reader.ReadSamples(frame_size_10ms, samples_10ms.data());
|
||||
if (read_samples < frame_size_10ms) {
|
||||
if (rtc::SafeLt(read_samples, frame_size_10ms)) {
|
||||
break; // EOF.
|
||||
}
|
||||
// Resample input.
|
||||
|
||||
@ -28,10 +28,10 @@ namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
constexpr size_t kFrameSize10ms48kHz = 480;
|
||||
constexpr int kFrameSize10ms48kHz = 480;
|
||||
|
||||
void DumpPerfStats(size_t num_samples,
|
||||
size_t sample_rate,
|
||||
void DumpPerfStats(int num_samples,
|
||||
int sample_rate,
|
||||
double average_us,
|
||||
double standard_deviation) {
|
||||
float audio_track_length_ms =
|
||||
@ -70,7 +70,7 @@ TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) {
|
||||
auto expected_vad_prob_reader = CreateVadProbsReader();
|
||||
|
||||
// Input length.
|
||||
const size_t num_frames = samples_reader.second;
|
||||
const int num_frames = samples_reader.second;
|
||||
ASSERT_GE(expected_vad_prob_reader.second, num_frames);
|
||||
|
||||
// Init buffers.
|
||||
@ -85,7 +85,7 @@ TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) {
|
||||
|
||||
// Compute VAD probabilities on the downsampled input.
|
||||
float cumulative_error = 0.f;
|
||||
for (size_t i = 0; i < num_frames; ++i) {
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
samples_reader.first->ReadChunk(samples_48k);
|
||||
decimator.Resample(samples_48k.data(), samples_48k.size(),
|
||||
samples_24k.data(), samples_24k.size());
|
||||
@ -114,13 +114,13 @@ TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) {
|
||||
TEST(RnnVadTest, DISABLED_RnnVadPerformance) {
|
||||
// PCM samples reader and buffers.
|
||||
auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
|
||||
const size_t num_frames = samples_reader.second;
|
||||
const int num_frames = samples_reader.second;
|
||||
std::array<float, kFrameSize10ms48kHz> samples;
|
||||
// Pre-fetch and decimate samples.
|
||||
PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
|
||||
std::vector<float> prefetched_decimated_samples;
|
||||
prefetched_decimated_samples.resize(num_frames * kFrameSize10ms24kHz);
|
||||
for (size_t i = 0; i < num_frames; ++i) {
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
samples_reader.first->ReadChunk(samples);
|
||||
decimator.Resample(samples.data(), samples.size(),
|
||||
&prefetched_decimated_samples[i * kFrameSize10ms24kHz],
|
||||
@ -130,14 +130,14 @@ TEST(RnnVadTest, DISABLED_RnnVadPerformance) {
|
||||
FeaturesExtractor features_extractor;
|
||||
std::array<float, kFeatureVectorSize> feature_vector;
|
||||
RnnBasedVad rnn_vad;
|
||||
constexpr size_t number_of_tests = 100;
|
||||
constexpr int number_of_tests = 100;
|
||||
::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
|
||||
for (size_t k = 0; k < number_of_tests; ++k) {
|
||||
for (int k = 0; k < number_of_tests; ++k) {
|
||||
features_extractor.Reset();
|
||||
rnn_vad.Reset();
|
||||
// Process frames.
|
||||
perf_timer.StartTimer();
|
||||
for (size_t i = 0; i < num_frames; ++i) {
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
bool is_silence = features_extractor.CheckSilenceComputeFeatures(
|
||||
{&prefetched_decimated_samples[i * kFrameSize10ms24kHz],
|
||||
kFrameSize10ms24kHz},
|
||||
|
||||
@ -29,7 +29,7 @@ namespace rnn_vad {
|
||||
// values are written at the end of the buffer.
|
||||
// The class also provides a view on the most recent M values, where 0 < M <= S
|
||||
// and by default M = N.
|
||||
template <typename T, size_t S, size_t N, size_t M = N>
|
||||
template <typename T, int S, int N, int M = N>
|
||||
class SequenceBuffer {
|
||||
static_assert(N <= S,
|
||||
"The new chunk size cannot be larger than the sequence buffer "
|
||||
@ -45,8 +45,8 @@ class SequenceBuffer {
|
||||
SequenceBuffer(const SequenceBuffer&) = delete;
|
||||
SequenceBuffer& operator=(const SequenceBuffer&) = delete;
|
||||
~SequenceBuffer() = default;
|
||||
size_t size() const { return S; }
|
||||
size_t chunks_size() const { return N; }
|
||||
int size() const { return S; }
|
||||
int chunks_size() const { return N; }
|
||||
// Sets the sequence buffer values to zero.
|
||||
void Reset() { std::fill(buffer_.begin(), buffer_.end(), 0); }
|
||||
// Returns a view on the whole buffer.
|
||||
|
||||
@ -20,7 +20,7 @@ namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
template <typename T, size_t S, size_t N>
|
||||
template <typename T, int S, int N>
|
||||
void TestSequenceBufferPushOp() {
|
||||
SCOPED_TRACE(S);
|
||||
SCOPED_TRACE(N);
|
||||
@ -32,8 +32,8 @@ void TestSequenceBufferPushOp() {
|
||||
chunk.fill(1);
|
||||
seq_buf.Push(chunk);
|
||||
chunk.fill(0);
|
||||
constexpr size_t required_push_ops = (S % N) ? S / N + 1 : S / N;
|
||||
for (size_t i = 0; i < required_push_ops - 1; ++i) {
|
||||
constexpr int required_push_ops = (S % N) ? S / N + 1 : S / N;
|
||||
for (int i = 0; i < required_push_ops - 1; ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
seq_buf.Push(chunk);
|
||||
// Still in the buffer.
|
||||
@ -48,12 +48,12 @@ void TestSequenceBufferPushOp() {
|
||||
// Check that the last item moves left by N positions after a push op.
|
||||
if (S > N) {
|
||||
// Fill in with non-zero values.
|
||||
for (size_t i = 0; i < N; ++i)
|
||||
for (int i = 0; i < N; ++i)
|
||||
chunk[i] = static_cast<T>(i + 1);
|
||||
seq_buf.Push(chunk);
|
||||
// With the next Push(), |last| will be moved left by N positions.
|
||||
const T last = chunk[N - 1];
|
||||
for (size_t i = 0; i < N; ++i)
|
||||
for (int i = 0; i < N; ++i)
|
||||
chunk[i] = static_cast<T>(last + i + 1);
|
||||
seq_buf.Push(chunk);
|
||||
EXPECT_EQ(last, seq_buf_view[S - N - 1]);
|
||||
@ -63,8 +63,8 @@ void TestSequenceBufferPushOp() {
|
||||
} // namespace
|
||||
|
||||
TEST(RnnVadTest, SequenceBufferGetters) {
|
||||
constexpr size_t buffer_size = 8;
|
||||
constexpr size_t chunk_size = 8;
|
||||
constexpr int buffer_size = 8;
|
||||
constexpr int chunk_size = 8;
|
||||
SequenceBuffer<int, buffer_size, chunk_size> seq_buf;
|
||||
EXPECT_EQ(buffer_size, seq_buf.size());
|
||||
EXPECT_EQ(chunk_size, seq_buf.chunks_size());
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
@ -32,11 +33,11 @@ void UpdateCepstralDifferenceStats(
|
||||
RTC_DCHECK(sym_matrix_buf);
|
||||
// Compute the new cepstral distance stats.
|
||||
std::array<float, kCepstralCoeffsHistorySize - 1> distances;
|
||||
for (size_t i = 0; i < kCepstralCoeffsHistorySize - 1; ++i) {
|
||||
const size_t delay = i + 1;
|
||||
for (int i = 0; i < kCepstralCoeffsHistorySize - 1; ++i) {
|
||||
const int delay = i + 1;
|
||||
auto old_cepstral_coeffs = ring_buf.GetArrayView(delay);
|
||||
distances[i] = 0.f;
|
||||
for (size_t k = 0; k < kNumBands; ++k) {
|
||||
for (int k = 0; k < kNumBands; ++k) {
|
||||
const float c = new_cepstral_coeffs[k] - old_cepstral_coeffs[k];
|
||||
distances[i] += c * c;
|
||||
}
|
||||
@ -48,9 +49,9 @@ void UpdateCepstralDifferenceStats(
|
||||
// Computes the first half of the Vorbis window.
|
||||
std::array<float, kFrameSize20ms24kHz / 2> ComputeScaledHalfVorbisWindow(
|
||||
float scaling = 1.f) {
|
||||
constexpr size_t kHalfSize = kFrameSize20ms24kHz / 2;
|
||||
constexpr int kHalfSize = kFrameSize20ms24kHz / 2;
|
||||
std::array<float, kHalfSize> half_window{};
|
||||
for (size_t i = 0; i < kHalfSize; ++i) {
|
||||
for (int i = 0; i < kHalfSize; ++i) {
|
||||
half_window[i] =
|
||||
scaling *
|
||||
std::sin(0.5 * kPi * std::sin(0.5 * kPi * (i + 0.5) / kHalfSize) *
|
||||
@ -71,8 +72,8 @@ void ComputeWindowedForwardFft(
|
||||
RTC_DCHECK_EQ(frame.size(), 2 * half_window.size());
|
||||
// Apply windowing.
|
||||
auto in = fft_input_buffer->GetView();
|
||||
for (size_t i = 0, j = kFrameSize20ms24kHz - 1; i < half_window.size();
|
||||
++i, --j) {
|
||||
for (int i = 0, j = kFrameSize20ms24kHz - 1;
|
||||
rtc::SafeLt(i, half_window.size()); ++i, --j) {
|
||||
in[i] = frame[i] * half_window[i];
|
||||
in[j] = frame[j] * half_window[i];
|
||||
}
|
||||
@ -162,7 +163,7 @@ void SpectralFeaturesExtractor::ComputeAvgAndDerivatives(
|
||||
RTC_DCHECK_EQ(average.size(), first_derivative.size());
|
||||
RTC_DCHECK_EQ(first_derivative.size(), second_derivative.size());
|
||||
RTC_DCHECK_LE(average.size(), curr.size());
|
||||
for (size_t i = 0; i < average.size(); ++i) {
|
||||
for (int i = 0; rtc::SafeLt(i, average.size()); ++i) {
|
||||
// Average, kernel: [1, 1, 1].
|
||||
average[i] = curr[i] + prev1[i] + prev2[i];
|
||||
// First derivative, kernel: [1, 0, - 1].
|
||||
@ -178,7 +179,7 @@ void SpectralFeaturesExtractor::ComputeNormalizedCepstralCorrelation(
|
||||
reference_frame_fft_->GetConstView(), lagged_frame_fft_->GetConstView(),
|
||||
bands_cross_corr_);
|
||||
// Normalize.
|
||||
for (size_t i = 0; i < bands_cross_corr_.size(); ++i) {
|
||||
for (int i = 0; rtc::SafeLt(i, bands_cross_corr_.size()); ++i) {
|
||||
bands_cross_corr_[i] =
|
||||
bands_cross_corr_[i] /
|
||||
std::sqrt(0.001f + reference_frame_bands_energy_[i] *
|
||||
@ -194,9 +195,9 @@ void SpectralFeaturesExtractor::ComputeNormalizedCepstralCorrelation(
|
||||
float SpectralFeaturesExtractor::ComputeVariability() const {
|
||||
// Compute cepstral variability score.
|
||||
float variability = 0.f;
|
||||
for (size_t delay1 = 0; delay1 < kCepstralCoeffsHistorySize; ++delay1) {
|
||||
for (int delay1 = 0; delay1 < kCepstralCoeffsHistorySize; ++delay1) {
|
||||
float min_dist = std::numeric_limits<float>::max();
|
||||
for (size_t delay2 = 0; delay2 < kCepstralCoeffsHistorySize; ++delay2) {
|
||||
for (int delay2 = 0; delay2 < kCepstralCoeffsHistorySize; ++delay2) {
|
||||
if (delay1 == delay2) // The distance would be 0.
|
||||
continue;
|
||||
min_dist =
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
#include <cstddef>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
@ -105,9 +106,9 @@ void SpectralCorrelator::ComputeCrossCorrelation(
|
||||
RTC_DCHECK_EQ(x[1], 0.f) << "The Nyquist coefficient must be zeroed.";
|
||||
RTC_DCHECK_EQ(y[1], 0.f) << "The Nyquist coefficient must be zeroed.";
|
||||
constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms();
|
||||
size_t k = 0; // Next Fourier coefficient index.
|
||||
int k = 0; // Next Fourier coefficient index.
|
||||
cross_corr[0] = 0.f;
|
||||
for (size_t i = 0; i < kOpusBands24kHz - 1; ++i) {
|
||||
for (int i = 0; i < kOpusBands24kHz - 1; ++i) {
|
||||
cross_corr[i + 1] = 0.f;
|
||||
for (int j = 0; j < kOpusScaleNumBins24kHz20ms[i]; ++j) { // Band size.
|
||||
const float v = x[2 * k] * y[2 * k] + x[2 * k + 1] * y[2 * k + 1];
|
||||
@ -137,11 +138,11 @@ void ComputeSmoothedLogMagnitudeSpectrum(
|
||||
return x;
|
||||
};
|
||||
// Smoothing over the bands for which the band energy is defined.
|
||||
for (size_t i = 0; i < bands_energy.size(); ++i) {
|
||||
for (int i = 0; rtc::SafeLt(i, bands_energy.size()); ++i) {
|
||||
log_bands_energy[i] = smooth(std::log10(kOneByHundred + bands_energy[i]));
|
||||
}
|
||||
// Smoothing over the remaining bands (zero energy).
|
||||
for (size_t i = bands_energy.size(); i < kNumBands; ++i) {
|
||||
for (int i = bands_energy.size(); i < kNumBands; ++i) {
|
||||
log_bands_energy[i] = smooth(kLogOneByHundred);
|
||||
}
|
||||
}
|
||||
@ -149,8 +150,8 @@ void ComputeSmoothedLogMagnitudeSpectrum(
|
||||
std::array<float, kNumBands * kNumBands> ComputeDctTable() {
|
||||
std::array<float, kNumBands * kNumBands> dct_table;
|
||||
const double k = std::sqrt(0.5);
|
||||
for (size_t i = 0; i < kNumBands; ++i) {
|
||||
for (size_t j = 0; j < kNumBands; ++j)
|
||||
for (int i = 0; i < kNumBands; ++i) {
|
||||
for (int j = 0; j < kNumBands; ++j)
|
||||
dct_table[i * kNumBands + j] = std::cos((i + 0.5) * j * kPi / kNumBands);
|
||||
dct_table[i * kNumBands] *= k;
|
||||
}
|
||||
@ -173,9 +174,9 @@ void ComputeDct(rtc::ArrayView<const float> in,
|
||||
RTC_DCHECK_LE(in.size(), kNumBands);
|
||||
RTC_DCHECK_LE(1, out.size());
|
||||
RTC_DCHECK_LE(out.size(), in.size());
|
||||
for (size_t i = 0; i < out.size(); ++i) {
|
||||
for (int i = 0; rtc::SafeLt(i, out.size()); ++i) {
|
||||
out[i] = 0.f;
|
||||
for (size_t j = 0; j < in.size(); ++j) {
|
||||
for (int j = 0; rtc::SafeLt(j, in.size()); ++j) {
|
||||
out[i] += in[j] * dct_table[j * kNumBands + i];
|
||||
}
|
||||
// TODO(bugs.webrtc.org/10480): Scaling factor in the DCT table.
|
||||
|
||||
@ -25,7 +25,7 @@ namespace rnn_vad {
|
||||
// At a sample rate of 24 kHz, the last 3 Opus bands are beyond the Nyquist
|
||||
// frequency. However, band #19 gets the contributions from band #18 because
|
||||
// of the symmetric triangular filter with peak response at 12 kHz.
|
||||
constexpr size_t kOpusBands24kHz = 20;
|
||||
constexpr int kOpusBands24kHz = 20;
|
||||
static_assert(kOpusBands24kHz < kNumBands,
|
||||
"The number of bands at 24 kHz must be less than those defined "
|
||||
"in the Opus scale at 48 kHz.");
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
||||
#include "modules/audio_processing/utility/pffft_wrapper.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// #include "test/fpe_observer.h"
|
||||
#include "test/gtest.h"
|
||||
@ -34,13 +35,13 @@ namespace {
|
||||
std::vector<float> ComputeTriangularFiltersWeights() {
|
||||
constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms();
|
||||
const auto& v = kOpusScaleNumBins24kHz20ms; // Alias.
|
||||
const size_t num_weights = std::accumulate(
|
||||
kOpusScaleNumBins24kHz20ms.begin(), kOpusScaleNumBins24kHz20ms.end(), 0);
|
||||
const int num_weights = std::accumulate(kOpusScaleNumBins24kHz20ms.begin(),
|
||||
kOpusScaleNumBins24kHz20ms.end(), 0);
|
||||
std::vector<float> weights(num_weights);
|
||||
size_t next_fft_coeff_index = 0;
|
||||
for (size_t band = 0; band < v.size(); ++band) {
|
||||
const size_t band_size = v[band];
|
||||
for (size_t j = 0; j < band_size; ++j) {
|
||||
int next_fft_coeff_index = 0;
|
||||
for (int band = 0; rtc::SafeLt(band, v.size()); ++band) {
|
||||
const int band_size = v[band];
|
||||
for (int j = 0; rtc::SafeLt(j, band_size); ++j) {
|
||||
weights[next_fft_coeff_index + j] = static_cast<float>(j) / band_size;
|
||||
}
|
||||
next_fft_coeff_index += band_size;
|
||||
@ -58,7 +59,7 @@ TEST(RnnVadTest, TestOpusScaleBoundaries) {
|
||||
3200, 4000, 4800, 5600, 6800, 8000, 9600, 12000, 15600, 20000};
|
||||
constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms();
|
||||
int prev = 0;
|
||||
for (size_t i = 0; i < kOpusScaleNumBins24kHz20ms.size(); ++i) {
|
||||
for (int i = 0; rtc::SafeLt(i, kOpusScaleNumBins24kHz20ms.size()); ++i) {
|
||||
int boundary =
|
||||
kBandFrequencyBoundariesHz[i] * kFrameSize20ms24kHz / kSampleRate24kHz;
|
||||
EXPECT_EQ(kOpusScaleNumBins24kHz20ms[i], boundary - prev);
|
||||
@ -72,8 +73,8 @@ TEST(RnnVadTest, TestOpusScaleBoundaries) {
|
||||
// is updated accordingly.
|
||||
TEST(RnnVadTest, DISABLED_TestOpusScaleWeights) {
|
||||
auto weights = ComputeTriangularFiltersWeights();
|
||||
size_t i = 0;
|
||||
for (size_t band_size : GetOpusScaleNumBins24kHz20ms()) {
|
||||
int i = 0;
|
||||
for (int band_size : GetOpusScaleNumBins24kHz20ms()) {
|
||||
SCOPED_TRACE(band_size);
|
||||
rtc::ArrayView<float> band_weights(weights.data() + i, band_size);
|
||||
float prev = -1.f;
|
||||
@ -98,7 +99,7 @@ TEST(RnnVadTest, SpectralCorrelatorValidOutput) {
|
||||
// Compute and check output.
|
||||
SpectralCorrelator e;
|
||||
e.ComputeAutoCorrelation(in_view, out);
|
||||
for (size_t i = 0; i < kOpusBands24kHz; ++i) {
|
||||
for (int i = 0; i < kOpusBands24kHz; ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
EXPECT_GT(out[i], 0.f);
|
||||
}
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// #include "test/fpe_observer.h"
|
||||
#include "test/gtest.h"
|
||||
@ -23,11 +24,11 @@ namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
constexpr size_t kTestFeatureVectorSize = kNumBands + 3 * kNumLowerBands + 1;
|
||||
constexpr int kTestFeatureVectorSize = kNumBands + 3 * kNumLowerBands + 1;
|
||||
|
||||
// Writes non-zero sample values.
|
||||
void WriteTestData(rtc::ArrayView<float> samples) {
|
||||
for (size_t i = 0; i < samples.size(); ++i) {
|
||||
for (int i = 0; rtc::SafeLt(i, samples.size()); ++i) {
|
||||
samples[i] = i % 100;
|
||||
}
|
||||
}
|
||||
@ -124,7 +125,7 @@ TEST(RnnVadTest, CepstralFeaturesConstantAverageZeroDerivative) {
|
||||
|
||||
// Fill the spectral features with test data.
|
||||
std::array<float, kTestFeatureVectorSize> feature_vector;
|
||||
for (size_t i = 0; i < kCepstralCoeffsHistorySize; ++i) {
|
||||
for (int i = 0; i < kCepstralCoeffsHistorySize; ++i) {
|
||||
is_silence = sfe.CheckSilenceComputeFeatures(
|
||||
samples_view, samples_view, GetHigherBandsSpectrum(&feature_vector),
|
||||
GetAverage(&feature_vector), GetFirstDerivative(&feature_vector),
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
@ -29,7 +30,7 @@ namespace rnn_vad {
|
||||
// removed when one of the two corresponding items that have been compared is
|
||||
// removed from the ring buffer. It is assumed that the comparison is symmetric
|
||||
// and that comparing an item with itself is not needed.
|
||||
template <typename T, size_t S>
|
||||
template <typename T, int S>
|
||||
class SymmetricMatrixBuffer {
|
||||
static_assert(S > 2, "");
|
||||
|
||||
@ -55,9 +56,9 @@ class SymmetricMatrixBuffer {
|
||||
// column left.
|
||||
std::memmove(buf_.data(), buf_.data() + S, (buf_.size() - S) * sizeof(T));
|
||||
// Copy new values in the last column in the right order.
|
||||
for (size_t i = 0; i < values.size(); ++i) {
|
||||
const size_t index = (S - 1 - i) * (S - 1) - 1;
|
||||
RTC_DCHECK_LE(static_cast<size_t>(0), index);
|
||||
for (int i = 0; rtc::SafeLt(i, values.size()); ++i) {
|
||||
const int index = (S - 1 - i) * (S - 1) - 1;
|
||||
RTC_DCHECK_GE(index, 0);
|
||||
RTC_DCHECK_LT(index, buf_.size());
|
||||
buf_[index] = values[i];
|
||||
}
|
||||
@ -65,9 +66,9 @@ class SymmetricMatrixBuffer {
|
||||
// Reads the value that corresponds to comparison of two items in the ring
|
||||
// buffer having delay |delay1| and |delay2|. The two arguments must not be
|
||||
// equal and both must be in {0, ..., S - 1}.
|
||||
T GetValue(size_t delay1, size_t delay2) const {
|
||||
int row = S - 1 - static_cast<int>(delay1);
|
||||
int col = S - 1 - static_cast<int>(delay2);
|
||||
T GetValue(int delay1, int delay2) const {
|
||||
int row = S - 1 - delay1;
|
||||
int col = S - 1 - delay2;
|
||||
RTC_DCHECK_NE(row, col) << "The diagonal cannot be accessed.";
|
||||
if (row > col)
|
||||
std::swap(row, col); // Swap to access the upper-right triangular part.
|
||||
|
||||
@ -18,10 +18,10 @@ namespace rnn_vad {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
template <typename T, size_t S>
|
||||
template <typename T, int S>
|
||||
void CheckSymmetry(const SymmetricMatrixBuffer<T, S>* sym_matrix_buf) {
|
||||
for (size_t row = 0; row < S - 1; ++row)
|
||||
for (size_t col = row + 1; col < S; ++col)
|
||||
for (int row = 0; row < S - 1; ++row)
|
||||
for (int col = row + 1; col < S; ++col)
|
||||
EXPECT_EQ(sym_matrix_buf->GetValue(row, col),
|
||||
sym_matrix_buf->GetValue(col, row));
|
||||
}
|
||||
@ -30,12 +30,12 @@ using PairType = std::pair<int, int>;
|
||||
|
||||
// Checks that the symmetric matrix buffer contains any pair with a value equal
|
||||
// to the given one.
|
||||
template <size_t S>
|
||||
template <int S>
|
||||
bool CheckPairsWithValueExist(
|
||||
const SymmetricMatrixBuffer<PairType, S>* sym_matrix_buf,
|
||||
const int value) {
|
||||
for (size_t row = 0; row < S - 1; ++row) {
|
||||
for (size_t col = row + 1; col < S; ++col) {
|
||||
for (int row = 0; row < S - 1; ++row) {
|
||||
for (int col = row + 1; col < S; ++col) {
|
||||
auto p = sym_matrix_buf->GetValue(row, col);
|
||||
if (p.first == value || p.second == value)
|
||||
return true;
|
||||
@ -52,7 +52,7 @@ bool CheckPairsWithValueExist(
|
||||
TEST(RnnVadTest, SymmetricMatrixBufferUseCase) {
|
||||
// Instance a ring buffer which will be fed with a series of integer values.
|
||||
constexpr int kRingBufSize = 10;
|
||||
RingBuffer<int, 1, static_cast<size_t>(kRingBufSize)> ring_buf;
|
||||
RingBuffer<int, 1, kRingBufSize> ring_buf;
|
||||
// Instance a symmetric matrix buffer for the ring buffer above. It stores
|
||||
// pairs of integers with which this test can easily check that the evolution
|
||||
// of RingBuffer and SymmetricMatrixBuffer match.
|
||||
@ -81,8 +81,8 @@ TEST(RnnVadTest, SymmetricMatrixBufferUseCase) {
|
||||
CheckSymmetry(&sym_matrix_buf);
|
||||
// Check that the pairs resulting from the content in the ring buffer are
|
||||
// in the right position.
|
||||
for (size_t delay1 = 0; delay1 < kRingBufSize - 1; ++delay1) {
|
||||
for (size_t delay2 = delay1 + 1; delay2 < kRingBufSize; ++delay2) {
|
||||
for (int delay1 = 0; delay1 < kRingBufSize - 1; ++delay1) {
|
||||
for (int delay2 = delay1 + 1; delay2 < kRingBufSize; ++delay2) {
|
||||
const auto t1 = ring_buf.GetArrayView(delay1)[0];
|
||||
const auto t2 = ring_buf.GetArrayView(delay2)[0];
|
||||
ASSERT_LE(t2, t1);
|
||||
@ -93,7 +93,7 @@ TEST(RnnVadTest, SymmetricMatrixBufferUseCase) {
|
||||
}
|
||||
// Check that every older element in the ring buffer still has a
|
||||
// corresponding pair in the symmetric matrix buffer.
|
||||
for (size_t delay = 1; delay < kRingBufSize; ++delay) {
|
||||
for (int delay = 1; delay < kRingBufSize; ++delay) {
|
||||
const auto t_prev = ring_buf.GetArrayView(delay)[0];
|
||||
EXPECT_TRUE(CheckPairsWithValueExist(&sym_matrix_buf, t_prev));
|
||||
}
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include <memory>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
#include "system_wrappers/include/cpu_features_wrapper.h"
|
||||
#include "test/gtest.h"
|
||||
@ -24,7 +25,7 @@ namespace test {
|
||||
namespace {
|
||||
|
||||
using ReaderPairType =
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>;
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>;
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -33,7 +34,7 @@ using webrtc::test::ResourcePath;
|
||||
void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
|
||||
rtc::ArrayView<const float> computed) {
|
||||
ASSERT_EQ(expected.size(), computed.size());
|
||||
for (size_t i = 0; i < expected.size(); ++i) {
|
||||
for (int i = 0; rtc::SafeLt(i, expected.size()); ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
EXPECT_FLOAT_EQ(expected[i], computed[i]);
|
||||
}
|
||||
@ -43,14 +44,14 @@ void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
|
||||
rtc::ArrayView<const float> computed,
|
||||
float tolerance) {
|
||||
ASSERT_EQ(expected.size(), computed.size());
|
||||
for (size_t i = 0; i < expected.size(); ++i) {
|
||||
for (int i = 0; rtc::SafeLt(i, expected.size()); ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
EXPECT_NEAR(expected[i], computed[i], tolerance);
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t>
|
||||
CreatePcmSamplesReader(const size_t frame_length) {
|
||||
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const int>
|
||||
CreatePcmSamplesReader(const int frame_length) {
|
||||
auto ptr = std::make_unique<BinaryFileReader<int16_t, float>>(
|
||||
test::ResourcePath("audio_processing/agc2/rnn_vad/samples", "pcm"),
|
||||
frame_length);
|
||||
@ -59,14 +60,14 @@ CreatePcmSamplesReader(const size_t frame_length) {
|
||||
}
|
||||
|
||||
ReaderPairType CreatePitchBuffer24kHzReader() {
|
||||
constexpr size_t cols = 864;
|
||||
constexpr int cols = 864;
|
||||
auto ptr = std::make_unique<BinaryFileReader<float>>(
|
||||
ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), cols);
|
||||
return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), cols)};
|
||||
}
|
||||
|
||||
ReaderPairType CreateLpResidualAndPitchPeriodGainReader() {
|
||||
constexpr size_t num_lp_residual_coeffs = 864;
|
||||
constexpr int num_lp_residual_coeffs = 864;
|
||||
auto ptr = std::make_unique<BinaryFileReader<float>>(
|
||||
ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"),
|
||||
num_lp_residual_coeffs);
|
||||
@ -83,7 +84,7 @@ ReaderPairType CreateVadProbsReader() {
|
||||
PitchTestData::PitchTestData() {
|
||||
BinaryFileReader<float> test_data_reader(
|
||||
ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
|
||||
static_cast<size_t>(1396));
|
||||
1396);
|
||||
test_data_reader.ReadChunk(test_data_);
|
||||
}
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
@ -47,7 +48,7 @@ void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
|
||||
template <typename T, typename D = T>
|
||||
class BinaryFileReader {
|
||||
public:
|
||||
explicit BinaryFileReader(const std::string& file_path, size_t chunk_size = 0)
|
||||
BinaryFileReader(const std::string& file_path, int chunk_size = 0)
|
||||
: is_(file_path, std::ios::binary | std::ios::ate),
|
||||
data_length_(is_.tellg() / sizeof(T)),
|
||||
chunk_size_(chunk_size) {
|
||||
@ -58,7 +59,7 @@ class BinaryFileReader {
|
||||
BinaryFileReader(const BinaryFileReader&) = delete;
|
||||
BinaryFileReader& operator=(const BinaryFileReader&) = delete;
|
||||
~BinaryFileReader() = default;
|
||||
size_t data_length() const { return data_length_; }
|
||||
int data_length() const { return data_length_; }
|
||||
bool ReadValue(D* dst) {
|
||||
if (std::is_same<T, D>::value) {
|
||||
is_.read(reinterpret_cast<char*>(dst), sizeof(T));
|
||||
@ -72,7 +73,7 @@ class BinaryFileReader {
|
||||
// If |chunk_size| was specified in the ctor, it will check that the size of
|
||||
// |dst| equals |chunk_size|.
|
||||
bool ReadChunk(rtc::ArrayView<D> dst) {
|
||||
RTC_DCHECK((chunk_size_ == 0) || (chunk_size_ == dst.size()));
|
||||
RTC_DCHECK((chunk_size_ == 0) || rtc::SafeEq(chunk_size_, dst.size()));
|
||||
const std::streamsize bytes_to_read = dst.size() * sizeof(T);
|
||||
if (std::is_same<T, D>::value) {
|
||||
is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
|
||||
@ -83,13 +84,13 @@ class BinaryFileReader {
|
||||
}
|
||||
return is_.gcount() == bytes_to_read;
|
||||
}
|
||||
void SeekForward(size_t items) { is_.seekg(items * sizeof(T), is_.cur); }
|
||||
void SeekForward(int items) { is_.seekg(items * sizeof(T), is_.cur); }
|
||||
void SeekBeginning() { is_.seekg(0, is_.beg); }
|
||||
|
||||
private:
|
||||
std::ifstream is_;
|
||||
const size_t data_length_;
|
||||
const size_t chunk_size_;
|
||||
const int data_length_;
|
||||
const int chunk_size_;
|
||||
std::vector<T> buf_;
|
||||
};
|
||||
|
||||
@ -117,22 +118,22 @@ class BinaryFileWriter {
|
||||
// pointer and the second the number of chunks that can be read from the file.
|
||||
// Creates a reader for the PCM samples that casts from S16 to float and reads
|
||||
// chunks with length |frame_length|.
|
||||
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t>
|
||||
CreatePcmSamplesReader(const size_t frame_length);
|
||||
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const int>
|
||||
CreatePcmSamplesReader(const int frame_length);
|
||||
// Creates a reader for the pitch buffer content at 24 kHz.
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>
|
||||
CreatePitchBuffer24kHzReader();
|
||||
// Creates a reader for the the LP residual coefficients and the pitch period
|
||||
// and gain values.
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>
|
||||
CreateLpResidualAndPitchPeriodGainReader();
|
||||
// Creates a reader for the VAD probabilities.
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
|
||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>
|
||||
CreateVadProbsReader();
|
||||
|
||||
constexpr size_t kNumPitchBufAutoCorrCoeffs = 147;
|
||||
constexpr size_t kNumPitchBufSquareEnergies = 385;
|
||||
constexpr size_t kPitchTestDataSize =
|
||||
constexpr int kNumPitchBufAutoCorrCoeffs = 147;
|
||||
constexpr int kNumPitchBufSquareEnergies = 385;
|
||||
constexpr int kPitchTestDataSize =
|
||||
kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs;
|
||||
|
||||
// Class to retrieve a test pitch buffer content and the expected output for the
|
||||
|
||||
Reference in New Issue
Block a user