AGC2 RNN VAD size_t -> int

Motivation: read "On Unsigned Integers" section in
https://google.github.io/styleguide/cppguide.html#Integer_Types

Plus, improved readability by getting rid of a bunch of
`static_cast<int>`s.

Bug: webrtc:10480
Change-Id: I911aa8cd08f5ccde4ee6f23534240d1faa84cdea
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/190880
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32524}
This commit is contained in:
Alessio Bazzica
2020-10-29 20:50:13 +01:00
committed by Commit Bot
parent bee5983b8f
commit f622ba725e
30 changed files with 292 additions and 276 deletions

View File

@ -33,6 +33,8 @@ rtc_library("rnn_vad") {
"../../../../api:function_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:logging",
"../../../../rtc_base:safe_compare",
"../../../../rtc_base:safe_conversions",
"../../../../rtc_base/system:arch",
"//third_party/rnnoise:rnn_vad",
]
@ -93,6 +95,7 @@ rtc_library("rnn_vad_pitch") {
"../../../../api:array_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:safe_compare",
"../../../../rtc_base:safe_conversions",
]
}
@ -125,6 +128,7 @@ rtc_library("rnn_vad_spectral_features") {
":rnn_vad_symmetric_matrix_buffer",
"../../../../api:array_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:safe_compare",
"../../utility:pffft_wrapper",
]
}
@ -134,6 +138,7 @@ rtc_source_set("rnn_vad_symmetric_matrix_buffer") {
deps = [
"../../../../api:array_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:safe_compare",
]
}
@ -150,6 +155,7 @@ if (rtc_include_tests) {
"../../../../api:array_view",
"../../../../api:scoped_refptr",
"../../../../rtc_base:checks",
"../../../../rtc_base:safe_compare",
"../../../../rtc_base/system:arch",
"../../../../system_wrappers",
"../../../../test:fileutils",
@ -206,6 +212,8 @@ if (rtc_include_tests) {
"../../../../common_audio/",
"../../../../rtc_base:checks",
"../../../../rtc_base:logging",
"../../../../rtc_base:safe_compare",
"../../../../rtc_base:safe_conversions",
"../../../../rtc_base/system:arch",
"../../../../test:test_support",
"../../utility:pffft_wrapper",
@ -227,6 +235,7 @@ if (rtc_include_tests) {
"../../../../api:array_view",
"../../../../common_audio",
"../../../../rtc_base:rtc_base_approved",
"../../../../rtc_base:safe_compare",
"../../../../test:test_support",
"//third_party/abseil-cpp/absl/flags:flag",
"//third_party/abseil-cpp/absl/flags:parse",

View File

@ -48,8 +48,8 @@ void AutoCorrelationCalculator::ComputeOnPitchBuffer(
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
constexpr size_t kFftFrameSize = 1 << kAutoCorrelationFftOrder;
constexpr size_t kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder;
constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
static_assert(kConvolutionLength == kFrameSize20ms12kHz,
"Mismatch between pitch buffer size, frame size and maximum "
"pitch period.");

View File

@ -54,7 +54,7 @@ TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) {
}
// The expected output is a vector filled with the same expected
// auto-correlation value. The latter equals the length of a 20 ms frame.
constexpr size_t kFrameSize20ms12kHz = kFrameSize20ms24kHz / 2;
constexpr int kFrameSize20ms12kHz = kFrameSize20ms24kHz / 2;
std::array<float, kNumPitchBufAutoCorrCoeffs> expected_output;
std::fill(expected_output.begin(), expected_output.end(),
static_cast<float>(kFrameSize20ms12kHz));

View File

@ -18,52 +18,52 @@ namespace rnn_vad {
constexpr double kPi = 3.14159265358979323846;
constexpr size_t kSampleRate24kHz = 24000;
constexpr size_t kFrameSize10ms24kHz = kSampleRate24kHz / 100;
constexpr size_t kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2;
constexpr int kSampleRate24kHz = 24000;
constexpr int kFrameSize10ms24kHz = kSampleRate24kHz / 100;
constexpr int kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2;
// Pitch buffer.
constexpr size_t kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s.
constexpr size_t kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s.
constexpr size_t kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz;
constexpr int kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s.
constexpr int kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s.
constexpr int kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz;
static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even.");
// 24 kHz analysis.
// Define a higher minimum pitch period for the initial search. This is used to
// avoid searching for very short periods, for which a refinement step is
// responsible.
constexpr size_t kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
constexpr int kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, "");
static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, "");
static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
constexpr int kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
// 12 kHz analysis.
constexpr size_t kSampleRate12kHz = 12000;
constexpr size_t kFrameSize10ms12kHz = kSampleRate12kHz / 100;
constexpr size_t kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2;
constexpr size_t kBufSize12kHz = kBufSize24kHz / 2;
constexpr size_t kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
constexpr size_t kMaxPitch12kHz = kMaxPitch24kHz / 2;
constexpr int kSampleRate12kHz = 12000;
constexpr int kFrameSize10ms12kHz = kSampleRate12kHz / 100;
constexpr int kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2;
constexpr int kBufSize12kHz = kBufSize24kHz / 2;
constexpr int kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
constexpr int kMaxPitch12kHz = kMaxPitch24kHz / 2;
static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|,
// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|].
constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
constexpr int kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
// 48 kHz constants.
constexpr size_t kMinPitch48kHz = kMinPitch24kHz * 2;
constexpr size_t kMaxPitch48kHz = kMaxPitch24kHz * 2;
constexpr int kMinPitch48kHz = kMinPitch24kHz * 2;
constexpr int kMaxPitch48kHz = kMaxPitch24kHz * 2;
// Spectral features.
constexpr size_t kNumBands = 22;
constexpr size_t kNumLowerBands = 6;
constexpr int kNumBands = 22;
constexpr int kNumLowerBands = 6;
static_assert((0 < kNumLowerBands) && (kNumLowerBands < kNumBands), "");
constexpr size_t kCepstralCoeffsHistorySize = 8;
constexpr int kCepstralCoeffsHistorySize = 8;
static_assert(kCepstralCoeffsHistorySize > 2,
"The history size must at least be 3 to compute first and second "
"derivatives.");
constexpr size_t kFeatureVectorSize = 42;
constexpr int kFeatureVectorSize = 42;
enum class Optimization { kNone, kSse2, kNeon };

View File

@ -69,7 +69,7 @@ bool FeaturesExtractor::CheckSilenceComputeFeatures(
// into the output vector (normalization based on training data stats).
pitch_info_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
feature_vector[kFeatureVectorSize - 2] =
0.01f * (static_cast<int>(pitch_info_48kHz_.period) - 300);
0.01f * (pitch_info_48kHz_.period - 300);
// Extract lagged frames (according to the estimated pitch period).
RTC_DCHECK_LE(pitch_info_48kHz_.period / 2, kMaxPitch24kHz);
auto lagged_frame = pitch_buf_24kHz_view_.subview(

View File

@ -14,6 +14,8 @@
#include <vector>
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include "rtc_base/numerics/safe_compare.h"
#include "rtc_base/numerics/safe_conversions.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// #include "test/fpe_observer.h"
#include "test/gtest.h"
@ -23,26 +25,25 @@ namespace rnn_vad {
namespace test {
namespace {
constexpr size_t ceil(size_t n, size_t m) {
constexpr int ceil(int n, int m) {
return (n + m - 1) / m;
}
// Number of 10 ms frames required to fill a pitch buffer having size
// |kBufSize24kHz|.
constexpr size_t kNumTestDataFrames = ceil(kBufSize24kHz, kFrameSize10ms24kHz);
constexpr int kNumTestDataFrames = ceil(kBufSize24kHz, kFrameSize10ms24kHz);
// Number of samples for the test data.
constexpr size_t kNumTestDataSize = kNumTestDataFrames * kFrameSize10ms24kHz;
constexpr int kNumTestDataSize = kNumTestDataFrames * kFrameSize10ms24kHz;
// Verifies that the pitch in Hz is in the detectable range.
bool PitchIsValid(float pitch_hz) {
const size_t pitch_period =
static_cast<size_t>(static_cast<float>(kSampleRate24kHz) / pitch_hz);
const int pitch_period = static_cast<float>(kSampleRate24kHz) / pitch_hz;
return kInitialMinPitch24kHz <= pitch_period &&
pitch_period <= kMaxPitch24kHz;
}
void CreatePureTone(float amplitude, float freq_hz, rtc::ArrayView<float> dst) {
for (size_t i = 0; i < dst.size(); ++i) {
for (int i = 0; rtc::SafeLt(i, dst.size()); ++i) {
dst[i] = amplitude * std::sin(2.f * kPi * freq_hz * i / kSampleRate24kHz);
}
}
@ -56,8 +57,8 @@ bool FeedTestData(FeaturesExtractor* features_extractor,
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
bool is_silence = true;
const size_t num_frames = samples.size() / kFrameSize10ms24kHz;
for (size_t i = 0; i < num_frames; ++i) {
const int num_frames = samples.size() / kFrameSize10ms24kHz;
for (int i = 0; i < num_frames; ++i) {
is_silence = features_extractor->CheckSilenceComputeFeatures(
{samples.data() + i * kFrameSize10ms24kHz, kFrameSize10ms24kHz},
feature_vector);
@ -79,13 +80,13 @@ TEST(RnnVadTest, FeatureExtractionLowHighPitch) {
FeaturesExtractor features_extractor;
std::vector<float> samples(kNumTestDataSize);
std::vector<float> feature_vector(kFeatureVectorSize);
ASSERT_EQ(kFeatureVectorSize, feature_vector.size());
ASSERT_EQ(kFeatureVectorSize, rtc::dchecked_cast<int>(feature_vector.size()));
rtc::ArrayView<float, kFeatureVectorSize> feature_vector_view(
feature_vector.data(), kFeatureVectorSize);
// Extract the normalized scalar feature that is proportional to the estimated
// pitch period.
constexpr size_t pitch_feature_index = kFeatureVectorSize - 2;
constexpr int pitch_feature_index = kFeatureVectorSize - 2;
// Low frequency tone - i.e., high period.
CreatePureTone(amplitude, low_pitch_hz, samples);
ASSERT_FALSE(FeedTestData(&features_extractor, samples, feature_vector_view));

View File

@ -28,9 +28,9 @@ namespace {
void ComputeAutoCorrelation(
rtc::ArrayView<const float> x,
rtc::ArrayView<float, kNumLpcCoefficients> auto_corr) {
constexpr size_t max_lag = auto_corr.size();
constexpr int max_lag = auto_corr.size();
RTC_DCHECK_LT(max_lag, x.size());
for (size_t lag = 0; lag < max_lag; ++lag) {
for (int lag = 0; lag < max_lag; ++lag) {
auto_corr[lag] =
std::inner_product(x.begin(), x.end() - lag, x.begin() + lag, 0.f);
}
@ -56,9 +56,9 @@ void ComputeInitialInverseFilterCoefficients(
rtc::ArrayView<const float, kNumLpcCoefficients> auto_corr,
rtc::ArrayView<float, kNumLpcCoefficients - 1> lpc_coeffs) {
float error = auto_corr[0];
for (size_t i = 0; i < kNumLpcCoefficients - 1; ++i) {
for (int i = 0; i < kNumLpcCoefficients - 1; ++i) {
float reflection_coeff = 0.f;
for (size_t j = 0; j < i; ++j) {
for (int j = 0; j < i; ++j) {
reflection_coeff += lpc_coeffs[j] * auto_corr[i - j];
}
reflection_coeff += auto_corr[i + 1];
@ -72,7 +72,7 @@ void ComputeInitialInverseFilterCoefficients(
reflection_coeff /= -error;
// Update LPC coefficients and total error.
lpc_coeffs[i] = reflection_coeff;
for (size_t j = 0; j<(i + 1)>> 1; ++j) {
for (int j = 0; j < ((i + 1) >> 1); ++j) {
const float tmp1 = lpc_coeffs[j];
const float tmp2 = lpc_coeffs[i - 1 - j];
lpc_coeffs[j] = tmp1 + reflection_coeff * tmp2;

View File

@ -53,14 +53,14 @@ TEST(RnnVadTest, LpResidualPipelineBitExactness) {
std::vector<float> expected_lp_residual(kBufSize24kHz);
// Test length.
const size_t num_frames = std::min(pitch_buf_24kHz_reader.second,
static_cast<size_t>(300)); // Max 3 s.
const int num_frames =
std::min(pitch_buf_24kHz_reader.second, 300); // Max 3 s.
ASSERT_GE(lp_residual_reader.second, num_frames);
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
for (size_t i = 0; i < num_frames; ++i) {
for (int i = 0; i < num_frames; ++i) {
// Read input.
ASSERT_TRUE(pitch_buf_24kHz_reader.first->ReadChunk(pitch_buf_data));
// Read expected output (ignore pitch gain and period).

View File

@ -35,9 +35,8 @@ PitchInfo PitchEstimator::Estimate(
Decimate2x(pitch_buf, pitch_buf_decimated_view_);
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_,
auto_corr_view_);
CandidatePitchPeriods pitch_candidates_inverted_lags =
FindBestPitchPeriods(auto_corr_view_, pitch_buf_decimated_view_,
static_cast<int>(kMaxPitch12kHz));
CandidatePitchPeriods pitch_candidates_inverted_lags = FindBestPitchPeriods(
auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz);
// Refine the pitch period estimation.
// The refinement is done using the pitch buffer that contains 24 kHz samples.
// Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12
@ -47,10 +46,9 @@ PitchInfo PitchEstimator::Estimate(
const int pitch_inv_lag_48kHz =
RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inverted_lags);
// Look for stronger harmonics to find the final pitch period and its gain.
RTC_DCHECK_LT(pitch_inv_lag_48kHz, static_cast<int>(kMaxPitch48kHz));
RTC_DCHECK_LT(pitch_inv_lag_48kHz, kMaxPitch48kHz);
last_pitch_48kHz_ = CheckLowerPitchPeriodsAndComputePitchGain(
pitch_buf, static_cast<int>(kMaxPitch48kHz) - pitch_inv_lag_48kHz,
last_pitch_48kHz_);
pitch_buf, kMaxPitch48kHz - pitch_inv_lag_48kHz, last_pitch_48kHz_);
return last_pitch_48kHz_;
}

View File

@ -20,29 +20,27 @@
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
#include "rtc_base/numerics/safe_conversions.h"
namespace webrtc {
namespace rnn_vad {
namespace {
constexpr int kMaxPitch24kHzInt = static_cast<int>(kMaxPitch24kHz);
// Converts a lag to an inverted lag (only for 24kHz).
int GetInvertedLag(int lag) {
RTC_DCHECK_LE(lag, kMaxPitch24kHzInt);
return kMaxPitch24kHzInt - lag;
RTC_DCHECK_LE(lag, kMaxPitch24kHz);
return kMaxPitch24kHz - lag;
}
float ComputeAutoCorrelationCoeff(rtc::ArrayView<const float> pitch_buf,
int inv_lag,
int max_pitch_period) {
RTC_DCHECK_LT(inv_lag, static_cast<int>(pitch_buf.size()));
RTC_DCHECK_LT(max_pitch_period, static_cast<int>(pitch_buf.size()));
RTC_DCHECK_LE(inv_lag, static_cast<int>(max_pitch_period));
RTC_DCHECK_LT(inv_lag, pitch_buf.size());
RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
RTC_DCHECK_LE(inv_lag, max_pitch_period);
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
return std::inner_product(
pitch_buf.begin() + static_cast<size_t>(max_pitch_period),
pitch_buf.end(), pitch_buf.begin() + static_cast<size_t>(inv_lag), 0.f);
return std::inner_product(pitch_buf.begin() + max_pitch_period,
pitch_buf.end(), pitch_buf.begin() + inv_lag, 0.f);
}
// Given the auto-correlation coefficients for a lag and its neighbors, computes
@ -76,14 +74,14 @@ int PitchPseudoInterpolationLagPitchBuf(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
int offset = 0;
// Cannot apply pseudo-interpolation at the boundaries.
if (lag > 0 && lag < kMaxPitch24kHzInt) {
if (lag > 0 && lag < kMaxPitch24kHz) {
offset = GetPitchPseudoInterpolationOffset(
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag - 1),
kMaxPitch24kHzInt),
kMaxPitch24kHz),
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag),
kMaxPitch24kHzInt),
kMaxPitch24kHz),
ComputeAutoCorrelationCoeff(pitch_buf, GetInvertedLag(lag + 1),
kMaxPitch24kHzInt));
kMaxPitch24kHz));
}
return 2 * lag + offset;
}
@ -96,7 +94,7 @@ int PitchPseudoInterpolationInvLagAutoCorr(
rtc::ArrayView<const float> auto_corr) {
int offset = 0;
// Cannot apply pseudo-interpolation at the boundaries.
if (inv_lag > 0 && inv_lag < static_cast<int>(auto_corr.size()) - 1) {
if (inv_lag > 0 && inv_lag < rtc::dchecked_cast<int>(auto_corr.size()) - 1) {
offset = GetPitchPseudoInterpolationOffset(
auto_corr[inv_lag + 1], auto_corr[inv_lag], auto_corr[inv_lag - 1]);
}
@ -143,7 +141,7 @@ void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
rtc::ArrayView<float, kBufSize12kHz> dst) {
// TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter.
static_assert(2 * dst.size() == src.size(), "");
for (size_t i = 0; i < dst.size(); ++i) {
for (int i = 0; rtc::SafeLt(i, dst.size()); ++i) {
dst[i] = src[2 * i];
}
}
@ -186,10 +184,10 @@ float ComputePitchGainThreshold(int candidate_pitch_period,
// reduce the chance of false positives caused by a bias towards high
// frequencies (originating from short-term correlations).
float threshold = std::max(0.3f, 0.7f * g0 - lower_threshold_term);
if (static_cast<size_t>(t1) < 3 * kMinPitch24kHz) {
if (t1 < 3 * kMinPitch24kHz) {
// High frequency.
threshold = std::max(0.4f, 0.85f * g0 - lower_threshold_term);
} else if (static_cast<size_t>(t1) < 2 * kMinPitch24kHz) {
} else if (t1 < 2 * kMinPitch24kHz) {
// Even higher frequency.
threshold = std::max(0.5f, 0.9f * g0 - lower_threshold_term);
}
@ -199,10 +197,10 @@ float ComputePitchGainThreshold(int candidate_pitch_period,
void ComputeSlidingFrameSquareEnergies(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values) {
float yy = ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHzInt,
kMaxPitch24kHzInt);
float yy =
ComputeAutoCorrelationCoeff(pitch_buf, kMaxPitch24kHz, kMaxPitch24kHz);
yy_values[0] = yy;
for (size_t i = 1; i < yy_values.size(); ++i) {
for (int i = 1; rtc::SafeLt(i, yy_values.size()); ++i) {
RTC_DCHECK_LE(i, kMaxPitch24kHz + kFrameSize20ms24kHz);
RTC_DCHECK_LE(i, kMaxPitch24kHz);
const float old_coeff = pitch_buf[kMaxPitch24kHz + kFrameSize20ms24kHz - i];
@ -233,9 +231,10 @@ CandidatePitchPeriods FindBestPitchPeriods(
}
};
RTC_DCHECK_GT(max_pitch_period, static_cast<int>(auto_corr.size()));
RTC_DCHECK_LT(max_pitch_period, static_cast<int>(pitch_buf.size()));
const int frame_size = static_cast<int>(pitch_buf.size()) - max_pitch_period;
RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
const int frame_size =
rtc::dchecked_cast<int>(pitch_buf.size()) - max_pitch_period;
RTC_DCHECK_GT(frame_size, 0);
// TODO(bugs.webrtc.org/9076): Maybe optimize using vectorization.
float yy =
@ -247,7 +246,7 @@ CandidatePitchPeriods FindBestPitchPeriods(
PitchCandidate best;
PitchCandidate second_best;
second_best.period_inverted_lag = 1;
for (int inv_lag = 0; inv_lag < static_cast<int>(auto_corr.size());
for (int inv_lag = 0; inv_lag < rtc::dchecked_cast<int>(auto_corr.size());
++inv_lag) {
// A pitch candidate must have positive correlation.
if (auto_corr[inv_lag] > 0) {
@ -290,12 +289,12 @@ int RefinePitchPeriod48kHz(
++inverted_lag) {
if (is_neighbor(inverted_lag, pitch_candidates_inverted_lags.best) ||
is_neighbor(inverted_lag, pitch_candidates_inverted_lags.second_best))
auto_correlation[inverted_lag] = ComputeAutoCorrelationCoeff(
pitch_buf, inverted_lag, kMaxPitch24kHzInt);
auto_correlation[inverted_lag] =
ComputeAutoCorrelationCoeff(pitch_buf, inverted_lag, kMaxPitch24kHz);
}
// Find best pitch at 24 kHz.
const CandidatePitchPeriods pitch_candidates_24kHz =
FindBestPitchPeriods(auto_correlation, pitch_buf, kMaxPitch24kHzInt);
FindBestPitchPeriods(auto_correlation, pitch_buf, kMaxPitch24kHz);
// Pseudo-interpolation.
return PitchPseudoInterpolationInvLagAutoCorr(pitch_candidates_24kHz.best,
auto_correlation);
@ -334,9 +333,9 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
// Initial pitch candidate gain.
RefinedPitchCandidate best_pitch;
best_pitch.period_24kHz =
std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHzInt - 1);
std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1);
best_pitch.xy = ComputeAutoCorrelationCoeff(
pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHzInt);
pitch_buf, GetInvertedLag(best_pitch.period_24kHz), kMaxPitch24kHz);
best_pitch.yy = yy_values[best_pitch.period_24kHz];
best_pitch.gain = pitch_gain(best_pitch.xy, best_pitch.yy, xx);
@ -351,11 +350,10 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
};
// |max_k| such that alternative_period(initial_pitch_period, max_k, 1) equals
// kMinPitch24kHz.
const int max_k =
(2 * initial_pitch_period) / (2 * static_cast<int>(kMinPitch24kHz) - 1);
const int max_k = (2 * initial_pitch_period) / (2 * kMinPitch24kHz - 1);
for (int k = 2; k <= max_k; ++k) {
int candidate_pitch_period = alternative_period(initial_pitch_period, k, 1);
RTC_DCHECK_GE(candidate_pitch_period, static_cast<int>(kMinPitch24kHz));
RTC_DCHECK_GE(candidate_pitch_period, kMinPitch24kHz);
// When looking at |candidate_pitch_period|, we also look at one of its
// sub-harmonics. |kSubHarmonicMultipliers| is used to know where to look.
// |k| == 2 is a special case since |candidate_pitch_secondary_period| might
@ -363,7 +361,7 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
int candidate_pitch_secondary_period = alternative_period(
initial_pitch_period, k, kSubHarmonicMultipliers[k - 2]);
RTC_DCHECK_GT(candidate_pitch_secondary_period, 0);
if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHzInt) {
if (k == 2 && candidate_pitch_secondary_period > kMaxPitch24kHz) {
candidate_pitch_secondary_period = initial_pitch_period;
}
RTC_DCHECK_NE(candidate_pitch_period, candidate_pitch_secondary_period)
@ -373,10 +371,10 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
// |candidate_pitch_period| by also looking at its possible sub-harmonic
// |candidate_pitch_secondary_period|.
float xy_primary_period = ComputeAutoCorrelationCoeff(
pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHzInt);
pitch_buf, GetInvertedLag(candidate_pitch_period), kMaxPitch24kHz);
float xy_secondary_period = ComputeAutoCorrelationCoeff(
pitch_buf, GetInvertedLag(candidate_pitch_secondary_period),
kMaxPitch24kHzInt);
kMaxPitch24kHz);
float xy = 0.5f * (xy_primary_period + xy_secondary_period);
float yy = 0.5f * (yy_values[candidate_pitch_period] +
yy_values[candidate_pitch_secondary_period]);
@ -399,7 +397,7 @@ PitchInfo CheckLowerPitchPeriodsAndComputePitchGain(
: best_pitch.xy / (best_pitch.yy + 1.f);
final_pitch_gain = std::min(best_pitch.gain, final_pitch_gain);
int final_pitch_period_48kHz = std::max(
static_cast<int>(kMinPitch48kHz),
kMinPitch48kHz,
PitchPseudoInterpolationLagPitchBuf(best_pitch.period_24kHz, pitch_buf));
return {final_pitch_period_48kHz, final_pitch_gain};

View File

@ -34,11 +34,11 @@ constexpr float kTestPitchGainsHigh = 0.75f;
class ComputePitchGainThresholdTest
: public ::testing::Test,
public ::testing::WithParamInterface<std::tuple<
/*candidate_pitch_period=*/size_t,
/*pitch_period_ratio=*/size_t,
/*initial_pitch_period=*/size_t,
/*candidate_pitch_period=*/int,
/*pitch_period_ratio=*/int,
/*initial_pitch_period=*/int,
/*initial_pitch_gain=*/float,
/*prev_pitch_period=*/size_t,
/*prev_pitch_period=*/int,
/*prev_pitch_gain=*/float,
/*threshold=*/float>> {};
@ -46,11 +46,11 @@ class ComputePitchGainThresholdTest
// data.
TEST_P(ComputePitchGainThresholdTest, WithinTolerance) {
const auto params = GetParam();
const size_t candidate_pitch_period = std::get<0>(params);
const size_t pitch_period_ratio = std::get<1>(params);
const size_t initial_pitch_period = std::get<2>(params);
const int candidate_pitch_period = std::get<0>(params);
const int pitch_period_ratio = std::get<1>(params);
const int initial_pitch_period = std::get<2>(params);
const float initial_pitch_gain = std::get<3>(params);
const size_t prev_pitch_period = std::get<4>(params);
const int prev_pitch_period = std::get<4>(params);
const float prev_pitch_gain = std::get<5>(params);
const float threshold = std::get<6>(params);
{

View File

@ -28,22 +28,21 @@ namespace test {
// pitch gain is within tolerance given test input data.
TEST(RnnVadTest, PitchSearchWithinTolerance) {
auto lp_residual_reader = CreateLpResidualAndPitchPeriodGainReader();
const size_t num_frames = std::min(lp_residual_reader.second,
static_cast<size_t>(300)); // Max 3 s.
const int num_frames = std::min(lp_residual_reader.second, 300); // Max 3 s.
std::vector<float> lp_residual(kBufSize24kHz);
float expected_pitch_period, expected_pitch_gain;
PitchEstimator pitch_estimator;
{
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
for (size_t i = 0; i < num_frames; ++i) {
for (int i = 0; i < num_frames; ++i) {
SCOPED_TRACE(i);
lp_residual_reader.first->ReadChunk(lp_residual);
lp_residual_reader.first->ReadValue(&expected_pitch_period);
lp_residual_reader.first->ReadValue(&expected_pitch_gain);
PitchInfo pitch_info =
pitch_estimator.Estimate({lp_residual.data(), kBufSize24kHz});
EXPECT_EQ(static_cast<int>(expected_pitch_period), pitch_info.period);
EXPECT_EQ(expected_pitch_period, pitch_info.period);
EXPECT_NEAR(expected_pitch_gain, pitch_info.gain, 1e-5f);
}
}

View File

@ -21,7 +21,7 @@ namespace webrtc {
namespace rnn_vad {
// Ring buffer for N arrays of type T each one with size S.
template <typename T, size_t S, size_t N>
template <typename T, int S, int N>
class RingBuffer {
static_assert(S > 0, "");
static_assert(N > 0, "");
@ -45,11 +45,10 @@ class RingBuffer {
// Return an array view onto the array with a given delay. A view on the last
// and least recently push array is returned when |delay| is 0 and N - 1
// respectively.
rtc::ArrayView<const T, S> GetArrayView(size_t delay) const {
const int delay_int = static_cast<int>(delay);
RTC_DCHECK_LE(0, delay_int);
RTC_DCHECK_LT(delay_int, N);
int offset = tail_ - 1 - delay_int;
rtc::ArrayView<const T, S> GetArrayView(int delay) const {
RTC_DCHECK_LE(0, delay);
RTC_DCHECK_LT(delay, N);
int offset = tail_ - 1 - delay;
if (offset < 0)
offset += N;
return {buffer_.data() + S * offset, S};

View File

@ -20,14 +20,14 @@ namespace {
// Compare the elements of two given array views.
template <typename T, std::ptrdiff_t S>
void ExpectEq(rtc::ArrayView<const T, S> a, rtc::ArrayView<const T, S> b) {
for (size_t i = 0; i < S; ++i) {
for (int i = 0; i < S; ++i) {
SCOPED_TRACE(i);
EXPECT_EQ(a[i], b[i]);
}
}
// Test push/read sequences.
template <typename T, size_t S, size_t N>
template <typename T, int S, int N>
void TestRingBuffer() {
SCOPED_TRACE(N);
SCOPED_TRACE(S);
@ -56,7 +56,7 @@ void TestRingBuffer() {
}
// Check buffer.
for (size_t delay = 2; delay < N; ++delay) {
for (int delay = 2; delay < N; ++delay) {
SCOPED_TRACE(delay);
T expected_value = N - static_cast<T>(delay);
pushed_array.fill(expected_value);
@ -68,18 +68,18 @@ void TestRingBuffer() {
// Check that for different delays, different views are returned.
TEST(RnnVadTest, RingBufferArrayViews) {
constexpr size_t s = 3;
constexpr size_t n = 4;
constexpr int s = 3;
constexpr int n = 4;
RingBuffer<int, s, n> ring_buf;
std::array<int, s> pushed_array;
pushed_array.fill(1);
for (size_t k = 0; k <= n; ++k) { // Push data n + 1 times.
for (int k = 0; k <= n; ++k) { // Push data n + 1 times.
SCOPED_TRACE(k);
// Check array views.
for (size_t i = 0; i < n; ++i) {
for (int i = 0; i < n; ++i) {
SCOPED_TRACE(i);
auto view_i = ring_buf.GetArrayView(i);
for (size_t j = i + 1; j < n; ++j) {
for (int j = i + 1; j < n; ++j) {
SCOPED_TRACE(j);
auto view_j = ring_buf.GetArrayView(j);
EXPECT_NE(view_i, view_j);

View File

@ -26,6 +26,7 @@
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "third_party/rnnoise/src/rnn_activations.h"
#include "third_party/rnnoise/src/rnn_vad_weights.h"
@ -77,15 +78,16 @@ std::vector<float> GetScaledParams(rtc::ArrayView<const int8_t> params) {
// Casts and scales |weights| and re-arranges the layout.
std::vector<float> GetPreprocessedFcWeights(
rtc::ArrayView<const int8_t> weights,
size_t output_size) {
int output_size) {
if (output_size == 1) {
return GetScaledParams(weights);
}
// Transpose, scale and cast.
const size_t input_size = rtc::CheckedDivExact(weights.size(), output_size);
const int input_size = rtc::CheckedDivExact(
rtc::dchecked_cast<int>(weights.size()), output_size);
std::vector<float> w(weights.size());
for (size_t o = 0; o < output_size; ++o) {
for (size_t i = 0; i < input_size; ++i) {
for (int o = 0; o < output_size; ++o) {
for (int i = 0; i < input_size; ++i) {
w[o * input_size + i] = rnnoise::kWeightsScale *
static_cast<float>(weights[i * output_size + o]);
}
@ -93,7 +95,7 @@ std::vector<float> GetPreprocessedFcWeights(
return w;
}
constexpr size_t kNumGruGates = 3; // Update, reset, output.
constexpr int kNumGruGates = 3; // Update, reset, output.
// TODO(bugs.chromium.org/10480): Hard-coded optimized layout and remove this
// function to improve setup time.
@ -101,17 +103,17 @@ constexpr size_t kNumGruGates = 3; // Update, reset, output.
// It works both for weights, recurrent weights and bias.
std::vector<float> GetPreprocessedGruTensor(
rtc::ArrayView<const int8_t> tensor_src,
size_t output_size) {
int output_size) {
// Transpose, cast and scale.
// |n| is the size of the first dimension of the 3-dim tensor |weights|.
const size_t n =
rtc::CheckedDivExact(tensor_src.size(), output_size * kNumGruGates);
const size_t stride_src = kNumGruGates * output_size;
const size_t stride_dst = n * output_size;
const int n = rtc::CheckedDivExact(rtc::dchecked_cast<int>(tensor_src.size()),
output_size * kNumGruGates);
const int stride_src = kNumGruGates * output_size;
const int stride_dst = n * output_size;
std::vector<float> tensor_dst(tensor_src.size());
for (size_t g = 0; g < kNumGruGates; ++g) {
for (size_t o = 0; o < output_size; ++o) {
for (size_t i = 0; i < n; ++i) {
for (int g = 0; g < kNumGruGates; ++g) {
for (int o = 0; o < output_size; ++o) {
for (int i = 0; i < n; ++i) {
tensor_dst[g * stride_dst + o * n + i] =
rnnoise::kWeightsScale *
static_cast<float>(
@ -122,28 +124,28 @@ std::vector<float> GetPreprocessedGruTensor(
return tensor_dst;
}
void ComputeGruUpdateResetGates(size_t input_size,
size_t output_size,
void ComputeGruUpdateResetGates(int input_size,
int output_size,
rtc::ArrayView<const float> weights,
rtc::ArrayView<const float> recurrent_weights,
rtc::ArrayView<const float> bias,
rtc::ArrayView<const float> input,
rtc::ArrayView<const float> state,
rtc::ArrayView<float> gate) {
for (size_t o = 0; o < output_size; ++o) {
for (int o = 0; o < output_size; ++o) {
gate[o] = bias[o];
for (size_t i = 0; i < input_size; ++i) {
for (int i = 0; i < input_size; ++i) {
gate[o] += input[i] * weights[o * input_size + i];
}
for (size_t s = 0; s < output_size; ++s) {
for (int s = 0; s < output_size; ++s) {
gate[o] += state[s] * recurrent_weights[o * output_size + s];
}
gate[o] = SigmoidApproximated(gate[o]);
}
}
void ComputeGruOutputGate(size_t input_size,
size_t output_size,
void ComputeGruOutputGate(int input_size,
int output_size,
rtc::ArrayView<const float> weights,
rtc::ArrayView<const float> recurrent_weights,
rtc::ArrayView<const float> bias,
@ -151,12 +153,12 @@ void ComputeGruOutputGate(size_t input_size,
rtc::ArrayView<const float> state,
rtc::ArrayView<const float> reset,
rtc::ArrayView<float> gate) {
for (size_t o = 0; o < output_size; ++o) {
for (int o = 0; o < output_size; ++o) {
gate[o] = bias[o];
for (size_t i = 0; i < input_size; ++i) {
for (int i = 0; i < input_size; ++i) {
gate[o] += input[i] * weights[o * input_size + i];
}
for (size_t s = 0; s < output_size; ++s) {
for (int s = 0; s < output_size; ++s) {
gate[o] += state[s] * recurrent_weights[o * output_size + s] * reset[s];
}
gate[o] = RectifiedLinearUnit(gate[o]);
@ -164,8 +166,8 @@ void ComputeGruOutputGate(size_t input_size,
}
// Gated recurrent unit (GRU) layer un-optimized implementation.
void ComputeGruLayerOutput(size_t input_size,
size_t output_size,
void ComputeGruLayerOutput(int input_size,
int output_size,
rtc::ArrayView<const float> input,
rtc::ArrayView<const float> weights,
rtc::ArrayView<const float> recurrent_weights,
@ -173,8 +175,8 @@ void ComputeGruLayerOutput(size_t input_size,
rtc::ArrayView<float> state) {
RTC_DCHECK_EQ(input_size, input.size());
// Stride and offset used to read parameter arrays.
const size_t stride_in = input_size * output_size;
const size_t stride_out = output_size * output_size;
const int stride_in = input_size * output_size;
const int stride_out = output_size * output_size;
// Update gate.
std::array<float, kRecurrentLayersMaxUnits> update;
@ -198,7 +200,7 @@ void ComputeGruLayerOutput(size_t input_size,
bias.subview(2 * output_size, output_size), input, state, reset, output);
// Update output through the update gates and update the state.
for (size_t o = 0; o < output_size; ++o) {
for (int o = 0; o < output_size; ++o) {
output[o] = update[o] * state[o] + (1.f - update[o]) * output[o];
state[o] = output[o];
}
@ -206,8 +208,8 @@ void ComputeGruLayerOutput(size_t input_size,
// Fully connected layer un-optimized implementation.
void ComputeFullyConnectedLayerOutput(
size_t input_size,
size_t output_size,
int input_size,
int output_size,
rtc::ArrayView<const float> input,
rtc::ArrayView<const float> bias,
rtc::ArrayView<const float> weights,
@ -216,11 +218,11 @@ void ComputeFullyConnectedLayerOutput(
RTC_DCHECK_EQ(input.size(), input_size);
RTC_DCHECK_EQ(bias.size(), output_size);
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
for (size_t o = 0; o < output_size; ++o) {
for (int o = 0; o < output_size; ++o) {
output[o] = bias[o];
// TODO(bugs.chromium.org/9076): Benchmark how different layouts for
// |weights_| change the performance across different platforms.
for (size_t i = 0; i < input_size; ++i) {
for (int i = 0; i < input_size; ++i) {
output[o] += input[i] * weights[o * input_size + i];
}
output[o] = activation_function(output[o]);
@ -230,8 +232,8 @@ void ComputeFullyConnectedLayerOutput(
#if defined(WEBRTC_ARCH_X86_FAMILY)
// Fully connected layer SSE2 implementation.
void ComputeFullyConnectedLayerOutputSse2(
size_t input_size,
size_t output_size,
int input_size,
int output_size,
rtc::ArrayView<const float> input,
rtc::ArrayView<const float> bias,
rtc::ArrayView<const float> weights,
@ -240,16 +242,16 @@ void ComputeFullyConnectedLayerOutputSse2(
RTC_DCHECK_EQ(input.size(), input_size);
RTC_DCHECK_EQ(bias.size(), output_size);
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
const size_t input_size_by_4 = input_size >> 2;
const size_t offset = input_size & ~3;
const int input_size_by_4 = input_size >> 2;
const int offset = input_size & ~3;
__m128 sum_wx_128;
const float* v = reinterpret_cast<const float*>(&sum_wx_128);
for (size_t o = 0; o < output_size; ++o) {
for (int o = 0; o < output_size; ++o) {
// Perform 128 bit vector operations.
sum_wx_128 = _mm_set1_ps(0);
const float* x_p = input.data();
const float* w_p = weights.data() + o * input_size;
for (size_t i = 0; i < input_size_by_4; ++i, x_p += 4, w_p += 4) {
for (int i = 0; i < input_size_by_4; ++i, x_p += 4, w_p += 4) {
sum_wx_128 = _mm_add_ps(sum_wx_128,
_mm_mul_ps(_mm_loadu_ps(x_p), _mm_loadu_ps(w_p)));
}
@ -266,8 +268,8 @@ void ComputeFullyConnectedLayerOutputSse2(
} // namespace
FullyConnectedLayer::FullyConnectedLayer(
const size_t input_size,
const size_t output_size,
const int input_size,
const int output_size,
const rtc::ArrayView<const int8_t> bias,
const rtc::ArrayView<const int8_t> weights,
rtc::FunctionView<float(float)> activation_function,
@ -316,8 +318,8 @@ void FullyConnectedLayer::ComputeOutput(rtc::ArrayView<const float> input) {
}
GatedRecurrentLayer::GatedRecurrentLayer(
const size_t input_size,
const size_t output_size,
const int input_size,
const int output_size,
const rtc::ArrayView<const int8_t> bias,
const rtc::ArrayView<const int8_t> weights,
const rtc::ArrayView<const int8_t> recurrent_weights,

View File

@ -29,19 +29,19 @@ namespace rnn_vad {
// over-allocate space for fully-connected layers output vectors (implemented as
// std::array). The value should equal the number of units of the largest
// fully-connected layer.
constexpr size_t kFullyConnectedLayersMaxUnits = 24;
constexpr int kFullyConnectedLayersMaxUnits = 24;
// Maximum number of units for a recurrent layer. This value is used to
// over-allocate space for recurrent layers state vectors (implemented as
// std::array). The value should equal the number of units of the largest
// recurrent layer.
constexpr size_t kRecurrentLayersMaxUnits = 24;
constexpr int kRecurrentLayersMaxUnits = 24;
// Fully-connected layer.
class FullyConnectedLayer {
public:
FullyConnectedLayer(size_t input_size,
size_t output_size,
FullyConnectedLayer(int input_size,
int output_size,
rtc::ArrayView<const int8_t> bias,
rtc::ArrayView<const int8_t> weights,
rtc::FunctionView<float(float)> activation_function,
@ -49,16 +49,16 @@ class FullyConnectedLayer {
FullyConnectedLayer(const FullyConnectedLayer&) = delete;
FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
~FullyConnectedLayer();
size_t input_size() const { return input_size_; }
size_t output_size() const { return output_size_; }
int input_size() const { return input_size_; }
int output_size() const { return output_size_; }
Optimization optimization() const { return optimization_; }
rtc::ArrayView<const float> GetOutput() const;
// Computes the fully-connected layer output.
void ComputeOutput(rtc::ArrayView<const float> input);
private:
const size_t input_size_;
const size_t output_size_;
const int input_size_;
const int output_size_;
const std::vector<float> bias_;
const std::vector<float> weights_;
rtc::FunctionView<float(float)> activation_function_;
@ -72,8 +72,8 @@ class FullyConnectedLayer {
// activation functions for the update/reset and output gates respectively.
class GatedRecurrentLayer {
public:
GatedRecurrentLayer(size_t input_size,
size_t output_size,
GatedRecurrentLayer(int input_size,
int output_size,
rtc::ArrayView<const int8_t> bias,
rtc::ArrayView<const int8_t> weights,
rtc::ArrayView<const int8_t> recurrent_weights,
@ -81,8 +81,8 @@ class GatedRecurrentLayer {
GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
~GatedRecurrentLayer();
size_t input_size() const { return input_size_; }
size_t output_size() const { return output_size_; }
int input_size() const { return input_size_; }
int output_size() const { return output_size_; }
Optimization optimization() const { return optimization_; }
rtc::ArrayView<const float> GetOutput() const;
void Reset();
@ -90,8 +90,8 @@ class GatedRecurrentLayer {
void ComputeOutput(rtc::ArrayView<const float> input);
private:
const size_t input_size_;
const size_t output_size_;
const int input_size_;
const int output_size_;
const std::vector<float> bias_;
const std::vector<float> weights_;
const std::vector<float> recurrent_weights_;

View File

@ -18,6 +18,7 @@
#include "modules/audio_processing/test/performance_timer.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "rtc_base/system/arch.h"
#include "test/gtest.h"
#include "third_party/rnnoise/src/rnn_activations.h"
@ -43,15 +44,16 @@ void TestGatedRecurrentLayer(
rtc::ArrayView<const float> expected_output_sequence) {
RTC_CHECK(gru);
auto gru_output_view = gru->GetOutput();
const size_t input_sequence_length =
rtc::CheckedDivExact(input_sequence.size(), gru->input_size());
const size_t output_sequence_length =
rtc::CheckedDivExact(expected_output_sequence.size(), gru->output_size());
const int input_sequence_length = rtc::CheckedDivExact(
rtc::dchecked_cast<int>(input_sequence.size()), gru->input_size());
const int output_sequence_length = rtc::CheckedDivExact(
rtc::dchecked_cast<int>(expected_output_sequence.size()),
gru->output_size());
ASSERT_EQ(input_sequence_length, output_sequence_length)
<< "The test data length is invalid.";
// Feed the GRU layer and check the output at every step.
gru->Reset();
for (size_t i = 0; i < input_sequence_length; ++i) {
for (int i = 0; i < input_sequence_length; ++i) {
SCOPED_TRACE(i);
gru->ComputeOutput(
input_sequence.subview(i * gru->input_size(), gru->input_size()));
@ -77,8 +79,8 @@ constexpr std::array<float, 24> kFullyConnectedExpectedOutput = {
0.875092f, 0.999846f, 0.997707f, -0.999382f, 0.973153f, -0.966605f};
// Gated recurrent units layer test data.
constexpr size_t kGruInputSize = 5;
constexpr size_t kGruOutputSize = 4;
constexpr int kGruInputSize = 5;
constexpr int kGruOutputSize = 4;
constexpr std::array<int8_t, 12> kGruBias = {96, -99, -81, -114, 49, 119,
-118, 68, -76, 91, 121, 125};
constexpr std::array<int8_t, 60> kGruWeights = {
@ -213,10 +215,10 @@ TEST(RnnVadTest, DISABLED_BenchmarkFullyConnectedLayer) {
}
std::vector<Result> results;
constexpr size_t number_of_tests = 10000;
constexpr int number_of_tests = 10000;
for (auto& fc : implementations) {
::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
for (size_t k = 0; k < number_of_tests; ++k) {
for (int k = 0; k < number_of_tests; ++k) {
perf_timer.StartTimer();
fc->ComputeOutput(kFullyConnectedInputVector);
perf_timer.StopTimer();
@ -240,17 +242,17 @@ TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) {
rtc::ArrayView<const float> input_sequence(kGruInputSequence);
static_assert(kGruInputSequence.size() % kGruInputSize == 0, "");
constexpr size_t input_sequence_length =
constexpr int input_sequence_length =
kGruInputSequence.size() / kGruInputSize;
std::vector<Result> results;
constexpr size_t number_of_tests = 10000;
constexpr int number_of_tests = 10000;
for (auto& gru : implementations) {
::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
gru->Reset();
for (size_t k = 0; k < number_of_tests; ++k) {
for (int k = 0; k < number_of_tests; ++k) {
perf_timer.StartTimer();
for (size_t i = 0; i < input_sequence_length; ++i) {
for (int i = 0; i < input_sequence_length; ++i) {
gru->ComputeOutput(
input_sequence.subview(i * gru->input_size(), gru->input_size()));
}

View File

@ -20,6 +20,7 @@
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_compare.h"
ABSL_FLAG(std::string, i, "", "Path to the input wav file");
ABSL_FLAG(std::string, f, "", "Path to the output features file");
@ -56,7 +57,7 @@ int main(int argc, char* argv[]) {
}
// Initialize.
const size_t frame_size_10ms =
const int frame_size_10ms =
rtc::CheckedDivExact(wav_reader.sample_rate(), 100);
std::vector<float> samples_10ms;
samples_10ms.resize(frame_size_10ms);
@ -69,9 +70,9 @@ int main(int argc, char* argv[]) {
// Compute VAD probabilities.
while (true) {
// Read frame at the input sample rate.
const auto read_samples =
const size_t read_samples =
wav_reader.ReadSamples(frame_size_10ms, samples_10ms.data());
if (read_samples < frame_size_10ms) {
if (rtc::SafeLt(read_samples, frame_size_10ms)) {
break; // EOF.
}
// Resample input.

View File

@ -28,10 +28,10 @@ namespace rnn_vad {
namespace test {
namespace {
constexpr size_t kFrameSize10ms48kHz = 480;
constexpr int kFrameSize10ms48kHz = 480;
void DumpPerfStats(size_t num_samples,
size_t sample_rate,
void DumpPerfStats(int num_samples,
int sample_rate,
double average_us,
double standard_deviation) {
float audio_track_length_ms =
@ -70,7 +70,7 @@ TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) {
auto expected_vad_prob_reader = CreateVadProbsReader();
// Input length.
const size_t num_frames = samples_reader.second;
const int num_frames = samples_reader.second;
ASSERT_GE(expected_vad_prob_reader.second, num_frames);
// Init buffers.
@ -85,7 +85,7 @@ TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) {
// Compute VAD probabilities on the downsampled input.
float cumulative_error = 0.f;
for (size_t i = 0; i < num_frames; ++i) {
for (int i = 0; i < num_frames; ++i) {
samples_reader.first->ReadChunk(samples_48k);
decimator.Resample(samples_48k.data(), samples_48k.size(),
samples_24k.data(), samples_24k.size());
@ -114,13 +114,13 @@ TEST(RnnVadTest, RnnVadProbabilityWithinTolerance) {
TEST(RnnVadTest, DISABLED_RnnVadPerformance) {
// PCM samples reader and buffers.
auto samples_reader = CreatePcmSamplesReader(kFrameSize10ms48kHz);
const size_t num_frames = samples_reader.second;
const int num_frames = samples_reader.second;
std::array<float, kFrameSize10ms48kHz> samples;
// Pre-fetch and decimate samples.
PushSincResampler decimator(kFrameSize10ms48kHz, kFrameSize10ms24kHz);
std::vector<float> prefetched_decimated_samples;
prefetched_decimated_samples.resize(num_frames * kFrameSize10ms24kHz);
for (size_t i = 0; i < num_frames; ++i) {
for (int i = 0; i < num_frames; ++i) {
samples_reader.first->ReadChunk(samples);
decimator.Resample(samples.data(), samples.size(),
&prefetched_decimated_samples[i * kFrameSize10ms24kHz],
@ -130,14 +130,14 @@ TEST(RnnVadTest, DISABLED_RnnVadPerformance) {
FeaturesExtractor features_extractor;
std::array<float, kFeatureVectorSize> feature_vector;
RnnBasedVad rnn_vad;
constexpr size_t number_of_tests = 100;
constexpr int number_of_tests = 100;
::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
for (size_t k = 0; k < number_of_tests; ++k) {
for (int k = 0; k < number_of_tests; ++k) {
features_extractor.Reset();
rnn_vad.Reset();
// Process frames.
perf_timer.StartTimer();
for (size_t i = 0; i < num_frames; ++i) {
for (int i = 0; i < num_frames; ++i) {
bool is_silence = features_extractor.CheckSilenceComputeFeatures(
{&prefetched_decimated_samples[i * kFrameSize10ms24kHz],
kFrameSize10ms24kHz},

View File

@ -29,7 +29,7 @@ namespace rnn_vad {
// values are written at the end of the buffer.
// The class also provides a view on the most recent M values, where 0 < M <= S
// and by default M = N.
template <typename T, size_t S, size_t N, size_t M = N>
template <typename T, int S, int N, int M = N>
class SequenceBuffer {
static_assert(N <= S,
"The new chunk size cannot be larger than the sequence buffer "
@ -45,8 +45,8 @@ class SequenceBuffer {
SequenceBuffer(const SequenceBuffer&) = delete;
SequenceBuffer& operator=(const SequenceBuffer&) = delete;
~SequenceBuffer() = default;
size_t size() const { return S; }
size_t chunks_size() const { return N; }
int size() const { return S; }
int chunks_size() const { return N; }
// Sets the sequence buffer values to zero.
void Reset() { std::fill(buffer_.begin(), buffer_.end(), 0); }
// Returns a view on the whole buffer.

View File

@ -20,7 +20,7 @@ namespace rnn_vad {
namespace test {
namespace {
template <typename T, size_t S, size_t N>
template <typename T, int S, int N>
void TestSequenceBufferPushOp() {
SCOPED_TRACE(S);
SCOPED_TRACE(N);
@ -32,8 +32,8 @@ void TestSequenceBufferPushOp() {
chunk.fill(1);
seq_buf.Push(chunk);
chunk.fill(0);
constexpr size_t required_push_ops = (S % N) ? S / N + 1 : S / N;
for (size_t i = 0; i < required_push_ops - 1; ++i) {
constexpr int required_push_ops = (S % N) ? S / N + 1 : S / N;
for (int i = 0; i < required_push_ops - 1; ++i) {
SCOPED_TRACE(i);
seq_buf.Push(chunk);
// Still in the buffer.
@ -48,12 +48,12 @@ void TestSequenceBufferPushOp() {
// Check that the last item moves left by N positions after a push op.
if (S > N) {
// Fill in with non-zero values.
for (size_t i = 0; i < N; ++i)
for (int i = 0; i < N; ++i)
chunk[i] = static_cast<T>(i + 1);
seq_buf.Push(chunk);
// With the next Push(), |last| will be moved left by N positions.
const T last = chunk[N - 1];
for (size_t i = 0; i < N; ++i)
for (int i = 0; i < N; ++i)
chunk[i] = static_cast<T>(last + i + 1);
seq_buf.Push(chunk);
EXPECT_EQ(last, seq_buf_view[S - N - 1]);
@ -63,8 +63,8 @@ void TestSequenceBufferPushOp() {
} // namespace
TEST(RnnVadTest, SequenceBufferGetters) {
constexpr size_t buffer_size = 8;
constexpr size_t chunk_size = 8;
constexpr int buffer_size = 8;
constexpr int chunk_size = 8;
SequenceBuffer<int, buffer_size, chunk_size> seq_buf;
EXPECT_EQ(buffer_size, seq_buf.size());
EXPECT_EQ(chunk_size, seq_buf.chunks_size());

View File

@ -16,6 +16,7 @@
#include <numeric>
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace rnn_vad {
@ -32,11 +33,11 @@ void UpdateCepstralDifferenceStats(
RTC_DCHECK(sym_matrix_buf);
// Compute the new cepstral distance stats.
std::array<float, kCepstralCoeffsHistorySize - 1> distances;
for (size_t i = 0; i < kCepstralCoeffsHistorySize - 1; ++i) {
const size_t delay = i + 1;
for (int i = 0; i < kCepstralCoeffsHistorySize - 1; ++i) {
const int delay = i + 1;
auto old_cepstral_coeffs = ring_buf.GetArrayView(delay);
distances[i] = 0.f;
for (size_t k = 0; k < kNumBands; ++k) {
for (int k = 0; k < kNumBands; ++k) {
const float c = new_cepstral_coeffs[k] - old_cepstral_coeffs[k];
distances[i] += c * c;
}
@ -48,9 +49,9 @@ void UpdateCepstralDifferenceStats(
// Computes the first half of the Vorbis window.
std::array<float, kFrameSize20ms24kHz / 2> ComputeScaledHalfVorbisWindow(
float scaling = 1.f) {
constexpr size_t kHalfSize = kFrameSize20ms24kHz / 2;
constexpr int kHalfSize = kFrameSize20ms24kHz / 2;
std::array<float, kHalfSize> half_window{};
for (size_t i = 0; i < kHalfSize; ++i) {
for (int i = 0; i < kHalfSize; ++i) {
half_window[i] =
scaling *
std::sin(0.5 * kPi * std::sin(0.5 * kPi * (i + 0.5) / kHalfSize) *
@ -71,8 +72,8 @@ void ComputeWindowedForwardFft(
RTC_DCHECK_EQ(frame.size(), 2 * half_window.size());
// Apply windowing.
auto in = fft_input_buffer->GetView();
for (size_t i = 0, j = kFrameSize20ms24kHz - 1; i < half_window.size();
++i, --j) {
for (int i = 0, j = kFrameSize20ms24kHz - 1;
rtc::SafeLt(i, half_window.size()); ++i, --j) {
in[i] = frame[i] * half_window[i];
in[j] = frame[j] * half_window[i];
}
@ -162,7 +163,7 @@ void SpectralFeaturesExtractor::ComputeAvgAndDerivatives(
RTC_DCHECK_EQ(average.size(), first_derivative.size());
RTC_DCHECK_EQ(first_derivative.size(), second_derivative.size());
RTC_DCHECK_LE(average.size(), curr.size());
for (size_t i = 0; i < average.size(); ++i) {
for (int i = 0; rtc::SafeLt(i, average.size()); ++i) {
// Average, kernel: [1, 1, 1].
average[i] = curr[i] + prev1[i] + prev2[i];
// First derivative, kernel: [1, 0, - 1].
@ -178,7 +179,7 @@ void SpectralFeaturesExtractor::ComputeNormalizedCepstralCorrelation(
reference_frame_fft_->GetConstView(), lagged_frame_fft_->GetConstView(),
bands_cross_corr_);
// Normalize.
for (size_t i = 0; i < bands_cross_corr_.size(); ++i) {
for (int i = 0; rtc::SafeLt(i, bands_cross_corr_.size()); ++i) {
bands_cross_corr_[i] =
bands_cross_corr_[i] /
std::sqrt(0.001f + reference_frame_bands_energy_[i] *
@ -194,9 +195,9 @@ void SpectralFeaturesExtractor::ComputeNormalizedCepstralCorrelation(
float SpectralFeaturesExtractor::ComputeVariability() const {
// Compute cepstral variability score.
float variability = 0.f;
for (size_t delay1 = 0; delay1 < kCepstralCoeffsHistorySize; ++delay1) {
for (int delay1 = 0; delay1 < kCepstralCoeffsHistorySize; ++delay1) {
float min_dist = std::numeric_limits<float>::max();
for (size_t delay2 = 0; delay2 < kCepstralCoeffsHistorySize; ++delay2) {
for (int delay2 = 0; delay2 < kCepstralCoeffsHistorySize; ++delay2) {
if (delay1 == delay2) // The distance would be 0.
continue;
min_dist =

View File

@ -15,6 +15,7 @@
#include <cstddef>
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace rnn_vad {
@ -105,9 +106,9 @@ void SpectralCorrelator::ComputeCrossCorrelation(
RTC_DCHECK_EQ(x[1], 0.f) << "The Nyquist coefficient must be zeroed.";
RTC_DCHECK_EQ(y[1], 0.f) << "The Nyquist coefficient must be zeroed.";
constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms();
size_t k = 0; // Next Fourier coefficient index.
int k = 0; // Next Fourier coefficient index.
cross_corr[0] = 0.f;
for (size_t i = 0; i < kOpusBands24kHz - 1; ++i) {
for (int i = 0; i < kOpusBands24kHz - 1; ++i) {
cross_corr[i + 1] = 0.f;
for (int j = 0; j < kOpusScaleNumBins24kHz20ms[i]; ++j) { // Band size.
const float v = x[2 * k] * y[2 * k] + x[2 * k + 1] * y[2 * k + 1];
@ -137,11 +138,11 @@ void ComputeSmoothedLogMagnitudeSpectrum(
return x;
};
// Smoothing over the bands for which the band energy is defined.
for (size_t i = 0; i < bands_energy.size(); ++i) {
for (int i = 0; rtc::SafeLt(i, bands_energy.size()); ++i) {
log_bands_energy[i] = smooth(std::log10(kOneByHundred + bands_energy[i]));
}
// Smoothing over the remaining bands (zero energy).
for (size_t i = bands_energy.size(); i < kNumBands; ++i) {
for (int i = bands_energy.size(); i < kNumBands; ++i) {
log_bands_energy[i] = smooth(kLogOneByHundred);
}
}
@ -149,8 +150,8 @@ void ComputeSmoothedLogMagnitudeSpectrum(
std::array<float, kNumBands * kNumBands> ComputeDctTable() {
std::array<float, kNumBands * kNumBands> dct_table;
const double k = std::sqrt(0.5);
for (size_t i = 0; i < kNumBands; ++i) {
for (size_t j = 0; j < kNumBands; ++j)
for (int i = 0; i < kNumBands; ++i) {
for (int j = 0; j < kNumBands; ++j)
dct_table[i * kNumBands + j] = std::cos((i + 0.5) * j * kPi / kNumBands);
dct_table[i * kNumBands] *= k;
}
@ -173,9 +174,9 @@ void ComputeDct(rtc::ArrayView<const float> in,
RTC_DCHECK_LE(in.size(), kNumBands);
RTC_DCHECK_LE(1, out.size());
RTC_DCHECK_LE(out.size(), in.size());
for (size_t i = 0; i < out.size(); ++i) {
for (int i = 0; rtc::SafeLt(i, out.size()); ++i) {
out[i] = 0.f;
for (size_t j = 0; j < in.size(); ++j) {
for (int j = 0; rtc::SafeLt(j, in.size()); ++j) {
out[i] += in[j] * dct_table[j * kNumBands + i];
}
// TODO(bugs.webrtc.org/10480): Scaling factor in the DCT table.

View File

@ -25,7 +25,7 @@ namespace rnn_vad {
// At a sample rate of 24 kHz, the last 3 Opus bands are beyond the Nyquist
// frequency. However, band #19 gets the contributions from band #18 because
// of the symmetric triangular filter with peak response at 12 kHz.
constexpr size_t kOpusBands24kHz = 20;
constexpr int kOpusBands24kHz = 20;
static_assert(kOpusBands24kHz < kNumBands,
"The number of bands at 24 kHz must be less than those defined "
"in the Opus scale at 48 kHz.");

View File

@ -19,6 +19,7 @@
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include "modules/audio_processing/utility/pffft_wrapper.h"
#include "rtc_base/numerics/safe_compare.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// #include "test/fpe_observer.h"
#include "test/gtest.h"
@ -34,13 +35,13 @@ namespace {
std::vector<float> ComputeTriangularFiltersWeights() {
constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms();
const auto& v = kOpusScaleNumBins24kHz20ms; // Alias.
const size_t num_weights = std::accumulate(
kOpusScaleNumBins24kHz20ms.begin(), kOpusScaleNumBins24kHz20ms.end(), 0);
const int num_weights = std::accumulate(kOpusScaleNumBins24kHz20ms.begin(),
kOpusScaleNumBins24kHz20ms.end(), 0);
std::vector<float> weights(num_weights);
size_t next_fft_coeff_index = 0;
for (size_t band = 0; band < v.size(); ++band) {
const size_t band_size = v[band];
for (size_t j = 0; j < band_size; ++j) {
int next_fft_coeff_index = 0;
for (int band = 0; rtc::SafeLt(band, v.size()); ++band) {
const int band_size = v[band];
for (int j = 0; rtc::SafeLt(j, band_size); ++j) {
weights[next_fft_coeff_index + j] = static_cast<float>(j) / band_size;
}
next_fft_coeff_index += band_size;
@ -58,7 +59,7 @@ TEST(RnnVadTest, TestOpusScaleBoundaries) {
3200, 4000, 4800, 5600, 6800, 8000, 9600, 12000, 15600, 20000};
constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms();
int prev = 0;
for (size_t i = 0; i < kOpusScaleNumBins24kHz20ms.size(); ++i) {
for (int i = 0; rtc::SafeLt(i, kOpusScaleNumBins24kHz20ms.size()); ++i) {
int boundary =
kBandFrequencyBoundariesHz[i] * kFrameSize20ms24kHz / kSampleRate24kHz;
EXPECT_EQ(kOpusScaleNumBins24kHz20ms[i], boundary - prev);
@ -72,8 +73,8 @@ TEST(RnnVadTest, TestOpusScaleBoundaries) {
// is updated accordingly.
TEST(RnnVadTest, DISABLED_TestOpusScaleWeights) {
auto weights = ComputeTriangularFiltersWeights();
size_t i = 0;
for (size_t band_size : GetOpusScaleNumBins24kHz20ms()) {
int i = 0;
for (int band_size : GetOpusScaleNumBins24kHz20ms()) {
SCOPED_TRACE(band_size);
rtc::ArrayView<float> band_weights(weights.data() + i, band_size);
float prev = -1.f;
@ -98,7 +99,7 @@ TEST(RnnVadTest, SpectralCorrelatorValidOutput) {
// Compute and check output.
SpectralCorrelator e;
e.ComputeAutoCorrelation(in_view, out);
for (size_t i = 0; i < kOpusBands24kHz; ++i) {
for (int i = 0; i < kOpusBands24kHz; ++i) {
SCOPED_TRACE(i);
EXPECT_GT(out[i], 0.f);
}

View File

@ -14,6 +14,7 @@
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// #include "test/fpe_observer.h"
#include "test/gtest.h"
@ -23,11 +24,11 @@ namespace rnn_vad {
namespace test {
namespace {
constexpr size_t kTestFeatureVectorSize = kNumBands + 3 * kNumLowerBands + 1;
constexpr int kTestFeatureVectorSize = kNumBands + 3 * kNumLowerBands + 1;
// Writes non-zero sample values.
void WriteTestData(rtc::ArrayView<float> samples) {
for (size_t i = 0; i < samples.size(); ++i) {
for (int i = 0; rtc::SafeLt(i, samples.size()); ++i) {
samples[i] = i % 100;
}
}
@ -124,7 +125,7 @@ TEST(RnnVadTest, CepstralFeaturesConstantAverageZeroDerivative) {
// Fill the spectral features with test data.
std::array<float, kTestFeatureVectorSize> feature_vector;
for (size_t i = 0; i < kCepstralCoeffsHistorySize; ++i) {
for (int i = 0; i < kCepstralCoeffsHistorySize; ++i) {
is_silence = sfe.CheckSilenceComputeFeatures(
samples_view, samples_view, GetHigherBandsSpectrum(&feature_vector),
GetAverage(&feature_vector), GetFirstDerivative(&feature_vector),

View File

@ -18,6 +18,7 @@
#include "api/array_view.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace rnn_vad {
@ -29,7 +30,7 @@ namespace rnn_vad {
// removed when one of the two corresponding items that have been compared is
// removed from the ring buffer. It is assumed that the comparison is symmetric
// and that comparing an item with itself is not needed.
template <typename T, size_t S>
template <typename T, int S>
class SymmetricMatrixBuffer {
static_assert(S > 2, "");
@ -55,9 +56,9 @@ class SymmetricMatrixBuffer {
// column left.
std::memmove(buf_.data(), buf_.data() + S, (buf_.size() - S) * sizeof(T));
// Copy new values in the last column in the right order.
for (size_t i = 0; i < values.size(); ++i) {
const size_t index = (S - 1 - i) * (S - 1) - 1;
RTC_DCHECK_LE(static_cast<size_t>(0), index);
for (int i = 0; rtc::SafeLt(i, values.size()); ++i) {
const int index = (S - 1 - i) * (S - 1) - 1;
RTC_DCHECK_GE(index, 0);
RTC_DCHECK_LT(index, buf_.size());
buf_[index] = values[i];
}
@ -65,9 +66,9 @@ class SymmetricMatrixBuffer {
// Reads the value that corresponds to comparison of two items in the ring
// buffer having delay |delay1| and |delay2|. The two arguments must not be
// equal and both must be in {0, ..., S - 1}.
T GetValue(size_t delay1, size_t delay2) const {
int row = S - 1 - static_cast<int>(delay1);
int col = S - 1 - static_cast<int>(delay2);
T GetValue(int delay1, int delay2) const {
int row = S - 1 - delay1;
int col = S - 1 - delay2;
RTC_DCHECK_NE(row, col) << "The diagonal cannot be accessed.";
if (row > col)
std::swap(row, col); // Swap to access the upper-right triangular part.

View File

@ -18,10 +18,10 @@ namespace rnn_vad {
namespace test {
namespace {
template <typename T, size_t S>
template <typename T, int S>
void CheckSymmetry(const SymmetricMatrixBuffer<T, S>* sym_matrix_buf) {
for (size_t row = 0; row < S - 1; ++row)
for (size_t col = row + 1; col < S; ++col)
for (int row = 0; row < S - 1; ++row)
for (int col = row + 1; col < S; ++col)
EXPECT_EQ(sym_matrix_buf->GetValue(row, col),
sym_matrix_buf->GetValue(col, row));
}
@ -30,12 +30,12 @@ using PairType = std::pair<int, int>;
// Checks that the symmetric matrix buffer contains any pair with a value equal
// to the given one.
template <size_t S>
template <int S>
bool CheckPairsWithValueExist(
const SymmetricMatrixBuffer<PairType, S>* sym_matrix_buf,
const int value) {
for (size_t row = 0; row < S - 1; ++row) {
for (size_t col = row + 1; col < S; ++col) {
for (int row = 0; row < S - 1; ++row) {
for (int col = row + 1; col < S; ++col) {
auto p = sym_matrix_buf->GetValue(row, col);
if (p.first == value || p.second == value)
return true;
@ -52,7 +52,7 @@ bool CheckPairsWithValueExist(
TEST(RnnVadTest, SymmetricMatrixBufferUseCase) {
// Instance a ring buffer which will be fed with a series of integer values.
constexpr int kRingBufSize = 10;
RingBuffer<int, 1, static_cast<size_t>(kRingBufSize)> ring_buf;
RingBuffer<int, 1, kRingBufSize> ring_buf;
// Instance a symmetric matrix buffer for the ring buffer above. It stores
// pairs of integers with which this test can easily check that the evolution
// of RingBuffer and SymmetricMatrixBuffer match.
@ -81,8 +81,8 @@ TEST(RnnVadTest, SymmetricMatrixBufferUseCase) {
CheckSymmetry(&sym_matrix_buf);
// Check that the pairs resulting from the content in the ring buffer are
// in the right position.
for (size_t delay1 = 0; delay1 < kRingBufSize - 1; ++delay1) {
for (size_t delay2 = delay1 + 1; delay2 < kRingBufSize; ++delay2) {
for (int delay1 = 0; delay1 < kRingBufSize - 1; ++delay1) {
for (int delay2 = delay1 + 1; delay2 < kRingBufSize; ++delay2) {
const auto t1 = ring_buf.GetArrayView(delay1)[0];
const auto t2 = ring_buf.GetArrayView(delay2)[0];
ASSERT_LE(t2, t1);
@ -93,7 +93,7 @@ TEST(RnnVadTest, SymmetricMatrixBufferUseCase) {
}
// Check that every older element in the ring buffer still has a
// corresponding pair in the symmetric matrix buffer.
for (size_t delay = 1; delay < kRingBufSize; ++delay) {
for (int delay = 1; delay < kRingBufSize; ++delay) {
const auto t_prev = ring_buf.GetArrayView(delay)[0];
EXPECT_TRUE(CheckPairsWithValueExist(&sym_matrix_buf, t_prev));
}

View File

@ -13,6 +13,7 @@
#include <memory>
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
#include "rtc_base/system/arch.h"
#include "system_wrappers/include/cpu_features_wrapper.h"
#include "test/gtest.h"
@ -24,7 +25,7 @@ namespace test {
namespace {
using ReaderPairType =
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>;
std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>;
} // namespace
@ -33,7 +34,7 @@ using webrtc::test::ResourcePath;
void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed) {
ASSERT_EQ(expected.size(), computed.size());
for (size_t i = 0; i < expected.size(); ++i) {
for (int i = 0; rtc::SafeLt(i, expected.size()); ++i) {
SCOPED_TRACE(i);
EXPECT_FLOAT_EQ(expected[i], computed[i]);
}
@ -43,14 +44,14 @@ void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed,
float tolerance) {
ASSERT_EQ(expected.size(), computed.size());
for (size_t i = 0; i < expected.size(); ++i) {
for (int i = 0; rtc::SafeLt(i, expected.size()); ++i) {
SCOPED_TRACE(i);
EXPECT_NEAR(expected[i], computed[i], tolerance);
}
}
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t>
CreatePcmSamplesReader(const size_t frame_length) {
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const int>
CreatePcmSamplesReader(const int frame_length) {
auto ptr = std::make_unique<BinaryFileReader<int16_t, float>>(
test::ResourcePath("audio_processing/agc2/rnn_vad/samples", "pcm"),
frame_length);
@ -59,14 +60,14 @@ CreatePcmSamplesReader(const size_t frame_length) {
}
ReaderPairType CreatePitchBuffer24kHzReader() {
constexpr size_t cols = 864;
constexpr int cols = 864;
auto ptr = std::make_unique<BinaryFileReader<float>>(
ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), cols);
return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), cols)};
}
ReaderPairType CreateLpResidualAndPitchPeriodGainReader() {
constexpr size_t num_lp_residual_coeffs = 864;
constexpr int num_lp_residual_coeffs = 864;
auto ptr = std::make_unique<BinaryFileReader<float>>(
ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"),
num_lp_residual_coeffs);
@ -83,7 +84,7 @@ ReaderPairType CreateVadProbsReader() {
PitchTestData::PitchTestData() {
BinaryFileReader<float> test_data_reader(
ResourcePath("audio_processing/agc2/rnn_vad/pitch_search_int", "dat"),
static_cast<size_t>(1396));
1396);
test_data_reader.ReadChunk(test_data_);
}

View File

@ -24,6 +24,7 @@
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace rnn_vad {
@ -47,7 +48,7 @@ void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
template <typename T, typename D = T>
class BinaryFileReader {
public:
explicit BinaryFileReader(const std::string& file_path, size_t chunk_size = 0)
BinaryFileReader(const std::string& file_path, int chunk_size = 0)
: is_(file_path, std::ios::binary | std::ios::ate),
data_length_(is_.tellg() / sizeof(T)),
chunk_size_(chunk_size) {
@ -58,7 +59,7 @@ class BinaryFileReader {
BinaryFileReader(const BinaryFileReader&) = delete;
BinaryFileReader& operator=(const BinaryFileReader&) = delete;
~BinaryFileReader() = default;
size_t data_length() const { return data_length_; }
int data_length() const { return data_length_; }
bool ReadValue(D* dst) {
if (std::is_same<T, D>::value) {
is_.read(reinterpret_cast<char*>(dst), sizeof(T));
@ -72,7 +73,7 @@ class BinaryFileReader {
// If |chunk_size| was specified in the ctor, it will check that the size of
// |dst| equals |chunk_size|.
bool ReadChunk(rtc::ArrayView<D> dst) {
RTC_DCHECK((chunk_size_ == 0) || (chunk_size_ == dst.size()));
RTC_DCHECK((chunk_size_ == 0) || rtc::SafeEq(chunk_size_, dst.size()));
const std::streamsize bytes_to_read = dst.size() * sizeof(T);
if (std::is_same<T, D>::value) {
is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
@ -83,13 +84,13 @@ class BinaryFileReader {
}
return is_.gcount() == bytes_to_read;
}
void SeekForward(size_t items) { is_.seekg(items * sizeof(T), is_.cur); }
void SeekForward(int items) { is_.seekg(items * sizeof(T), is_.cur); }
void SeekBeginning() { is_.seekg(0, is_.beg); }
private:
std::ifstream is_;
const size_t data_length_;
const size_t chunk_size_;
const int data_length_;
const int chunk_size_;
std::vector<T> buf_;
};
@ -117,22 +118,22 @@ class BinaryFileWriter {
// pointer and the second the number of chunks that can be read from the file.
// Creates a reader for the PCM samples that casts from S16 to float and reads
// chunks with length |frame_length|.
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const size_t>
CreatePcmSamplesReader(const size_t frame_length);
std::pair<std::unique_ptr<BinaryFileReader<int16_t, float>>, const int>
CreatePcmSamplesReader(const int frame_length);
// Creates a reader for the pitch buffer content at 24 kHz.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>
CreatePitchBuffer24kHzReader();
// Creates a reader for the the LP residual coefficients and the pitch period
// and gain values.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>
CreateLpResidualAndPitchPeriodGainReader();
// Creates a reader for the VAD probabilities.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
std::pair<std::unique_ptr<BinaryFileReader<float>>, const int>
CreateVadProbsReader();
constexpr size_t kNumPitchBufAutoCorrCoeffs = 147;
constexpr size_t kNumPitchBufSquareEnergies = 385;
constexpr size_t kPitchTestDataSize =
constexpr int kNumPitchBufAutoCorrCoeffs = 147;
constexpr int kNumPitchBufSquareEnergies = 385;
constexpr int kPitchTestDataSize =
kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs;
// Class to retrieve a test pitch buffer content and the expected output for the