Revert "RNN VAD: Replace Ooura with PFFFT for the pitch auto correlation."
This reverts commit 8fcd6537f242ffd74154a62dad410e573e2efc4b.

Reason for revert: broke internal projects.

Original change's description:
> RNN VAD: Replace Ooura with PFFFT for the pitch auto correlation.
>
> Bug: webrtc:9577, webrtc:10480
> Change-Id: I6d58866d48b8eaaa4102551b88d4f55133d1915c
> Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/130482
> Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
> Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
> Cr-Commit-Position: refs/heads/master@{#27387}

TBR=gustaf@webrtc.org,alessiob@webrtc.org

Change-Id: Ia05057326ebc277f334b13db0bfec9d4442903c2
No-Presubmit: true
No-Tree-Checks: true
No-Try: true
Bug: webrtc:9577, webrtc:10480
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/130369
Reviewed-by: Qingsi Wang <qingsi@webrtc.org>
Commit-Queue: Qingsi Wang <qingsi@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#27405}
This commit is contained in:
@ -11,8 +11,6 @@ import("../../../../webrtc.gni")
|
|||||||
rtc_source_set("rnn_vad") {
|
rtc_source_set("rnn_vad") {
|
||||||
visibility = [ "../*" ]
|
visibility = [ "../*" ]
|
||||||
sources = [
|
sources = [
|
||||||
"auto_correlation.cc",
|
|
||||||
"auto_correlation.h",
|
|
||||||
"common.h",
|
"common.h",
|
||||||
"features_extraction.cc",
|
"features_extraction.cc",
|
||||||
"features_extraction.h",
|
"features_extraction.h",
|
||||||
@ -39,9 +37,9 @@ rtc_source_set("rnn_vad") {
|
|||||||
"..:biquad_filter",
|
"..:biquad_filter",
|
||||||
"../../../../api:array_view",
|
"../../../../api:array_view",
|
||||||
"../../../../api:function_view",
|
"../../../../api:function_view",
|
||||||
|
"../../../../common_audio/",
|
||||||
"../../../../rtc_base:checks",
|
"../../../../rtc_base:checks",
|
||||||
"../../../../rtc_base:rtc_base_approved",
|
"../../../../rtc_base:rtc_base_approved",
|
||||||
"../../utility:pffft_wrapper",
|
|
||||||
"//third_party/rnnoise:kiss_fft",
|
"//third_party/rnnoise:kiss_fft",
|
||||||
"//third_party/rnnoise:rnn_vad",
|
"//third_party/rnnoise:rnn_vad",
|
||||||
]
|
]
|
||||||
@ -55,7 +53,6 @@ if (rtc_include_tests) {
|
|||||||
"test_utils.h",
|
"test_utils.h",
|
||||||
]
|
]
|
||||||
deps = [
|
deps = [
|
||||||
":rnn_vad",
|
|
||||||
"../../../../api:array_view",
|
"../../../../api:array_view",
|
||||||
"../../../../api:scoped_refptr",
|
"../../../../api:scoped_refptr",
|
||||||
"../../../../rtc_base:checks",
|
"../../../../rtc_base:checks",
|
||||||
@ -89,7 +86,6 @@ if (rtc_include_tests) {
|
|||||||
rtc_source_set("unittests") {
|
rtc_source_set("unittests") {
|
||||||
testonly = true
|
testonly = true
|
||||||
sources = [
|
sources = [
|
||||||
"auto_correlation_unittest.cc",
|
|
||||||
"features_extraction_unittest.cc",
|
"features_extraction_unittest.cc",
|
||||||
"fft_util_unittest.cc",
|
"fft_util_unittest.cc",
|
||||||
"lp_residual_unittest.cc",
|
"lp_residual_unittest.cc",
|
||||||
|
@ -1,92 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
|
|
||||||
*
|
|
||||||
* Use of this source code is governed by a BSD-style license
|
|
||||||
* that can be found in the LICENSE file in the root of the source
|
|
||||||
* tree. An additional intellectual property rights grant can be found
|
|
||||||
* in the file PATENTS. All contributing project authors may
|
|
||||||
* be found in the AUTHORS file in the root of the source tree.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
#include "rtc_base/checks.h"
|
|
||||||
|
|
||||||
namespace webrtc {
|
|
||||||
namespace rnn_vad {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
|
|
||||||
static_assert(1 << kAutoCorrelationFftOrder >
|
|
||||||
kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
|
|
||||||
"");
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
AutoCorrelationCalculator::AutoCorrelationCalculator()
|
|
||||||
: fft_(1 << kAutoCorrelationFftOrder, Pffft::FftType::kReal),
|
|
||||||
tmp_(fft_.CreateBuffer()),
|
|
||||||
X_(fft_.CreateBuffer()),
|
|
||||||
H_(fft_.CreateBuffer()) {}
|
|
||||||
|
|
||||||
AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;
|
|
||||||
|
|
||||||
// The auto-correlation coefficients are computed as follows:
|
|
||||||
// |.........|...........| <- pitch buffer
|
|
||||||
// [ x (fixed) ]
|
|
||||||
// [ y_0 ]
|
|
||||||
// [ y_{m-1} ]
|
|
||||||
// x and y are sub-arrays of equal length; x is never moved, whereas y slides.
|
|
||||||
// The cross-correlation between y_0 and x corresponds to the auto-correlation
|
|
||||||
// for the maximum pitch period. Hence, the first value in |auto_corr| has an
|
|
||||||
// inverted lag equal to 0 that corresponds to a lag equal to the maximum
|
|
||||||
// pitch period.
|
|
||||||
void AutoCorrelationCalculator::ComputeOnPitchBuffer(
|
|
||||||
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
|
|
||||||
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr) {
|
|
||||||
RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
|
|
||||||
RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
|
|
||||||
constexpr size_t kFftFrameSize = 1 << kAutoCorrelationFftOrder;
|
|
||||||
constexpr size_t kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
|
|
||||||
static_assert(kConvolutionLength == kFrameSize20ms12kHz,
|
|
||||||
"Mismatch between pitch buffer size, frame size and maximum "
|
|
||||||
"pitch period.");
|
|
||||||
static_assert(kFftFrameSize > kNumInvertedLags12kHz + kConvolutionLength,
|
|
||||||
"The FFT length is not sufficiently big to avoid cyclic "
|
|
||||||
"convolution errors.");
|
|
||||||
auto tmp = tmp_->GetView();
|
|
||||||
|
|
||||||
// Compute the FFT for the reversed reference frame - i.e.,
|
|
||||||
// pitch_buf[-kConvolutionLength:].
|
|
||||||
std::reverse_copy(pitch_buf.end() - kConvolutionLength, pitch_buf.end(),
|
|
||||||
tmp.begin());
|
|
||||||
std::fill(tmp.begin() + kConvolutionLength, tmp.end(), 0.f);
|
|
||||||
fft_.ForwardTransform(*tmp_, H_.get(), /*ordered=*/false);
|
|
||||||
|
|
||||||
// Compute the FFT for the sliding frames chunk. The sliding frames are
|
|
||||||
// defined as pitch_buf[i:i+kConvolutionLength] where i in
|
|
||||||
// [0, kNumInvertedLags12kHz). The chunk includes all of them, hence it is
|
|
||||||
// defined as pitch_buf[:kNumInvertedLags12kHz+kConvolutionLength].
|
|
||||||
std::copy(pitch_buf.begin(),
|
|
||||||
pitch_buf.begin() + kConvolutionLength + kNumInvertedLags12kHz,
|
|
||||||
tmp.begin());
|
|
||||||
std::fill(tmp.begin() + kNumInvertedLags12kHz + kConvolutionLength, tmp.end(),
|
|
||||||
0.f);
|
|
||||||
fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false);
|
|
||||||
|
|
||||||
// Convolve in the frequency domain.
|
|
||||||
constexpr float kScalingFactor = 1.f / static_cast<float>(kFftFrameSize);
|
|
||||||
std::fill(tmp.begin(), tmp.end(), 0.f);
|
|
||||||
fft_.FrequencyDomainConvolve(*X_, *H_, tmp_.get(), kScalingFactor);
|
|
||||||
fft_.BackwardTransform(*tmp_, tmp_.get(), /*ordered=*/false);
|
|
||||||
|
|
||||||
// Extract the auto-correlation coefficients.
|
|
||||||
std::copy(tmp.begin() + kConvolutionLength - 1,
|
|
||||||
tmp.begin() + kConvolutionLength + kNumInvertedLags12kHz - 1,
|
|
||||||
auto_corr.begin());
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace rnn_vad
|
|
||||||
} // namespace webrtc
|
|
@ -1,49 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
|
|
||||||
*
|
|
||||||
* Use of this source code is governed by a BSD-style license
|
|
||||||
* that can be found in the LICENSE file in the root of the source
|
|
||||||
* tree. An additional intellectual property rights grant can be found
|
|
||||||
* in the file PATENTS. All contributing project authors may
|
|
||||||
* be found in the AUTHORS file in the root of the source tree.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_
|
|
||||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include "api/array_view.h"
|
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
|
||||||
#include "modules/audio_processing/utility/pffft_wrapper.h"
|
|
||||||
|
|
||||||
namespace webrtc {
|
|
||||||
namespace rnn_vad {
|
|
||||||
|
|
||||||
// Class to compute the auto correlation on the pitch buffer for a target pitch
|
|
||||||
// interval.
|
|
||||||
class AutoCorrelationCalculator {
|
|
||||||
public:
|
|
||||||
AutoCorrelationCalculator();
|
|
||||||
AutoCorrelationCalculator(const AutoCorrelationCalculator&) = delete;
|
|
||||||
AutoCorrelationCalculator& operator=(const AutoCorrelationCalculator&) =
|
|
||||||
delete;
|
|
||||||
~AutoCorrelationCalculator();
|
|
||||||
|
|
||||||
// Computes the auto-correlation coefficients for a target pitch interval.
|
|
||||||
// |auto_corr| indexes are inverted lags.
|
|
||||||
void ComputeOnPitchBuffer(
|
|
||||||
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
|
|
||||||
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr);
|
|
||||||
|
|
||||||
private:
|
|
||||||
Pffft fft_;
|
|
||||||
std::unique_ptr<Pffft::FloatBuffer> tmp_;
|
|
||||||
std::unique_ptr<Pffft::FloatBuffer> X_;
|
|
||||||
std::unique_ptr<Pffft::FloatBuffer> H_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace rnn_vad
|
|
||||||
} // namespace webrtc
|
|
||||||
|
|
||||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_
|
|
@ -1,62 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
|
|
||||||
*
|
|
||||||
* Use of this source code is governed by a BSD-style license
|
|
||||||
* that can be found in the LICENSE file in the root of the source
|
|
||||||
* tree. An additional intellectual property rights grant can be found
|
|
||||||
* in the file PATENTS. All contributing project authors may
|
|
||||||
* be found in the AUTHORS file in the root of the source tree.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
|
|
||||||
|
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
|
||||||
#include "test/gtest.h"
|
|
||||||
|
|
||||||
namespace webrtc {
|
|
||||||
namespace rnn_vad {
|
|
||||||
namespace test {
|
|
||||||
|
|
||||||
TEST(RnnVadTest, PitchBufferAutoCorrelationWithinTolerance) {
|
|
||||||
PitchTestData test_data;
|
|
||||||
std::array<float, kBufSize12kHz> pitch_buf_decimated;
|
|
||||||
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
|
|
||||||
std::array<float, kNumPitchBufAutoCorrCoeffs> computed_output;
|
|
||||||
{
|
|
||||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
|
||||||
// FloatingPointExceptionObserver fpe_observer;
|
|
||||||
AutoCorrelationCalculator auto_corr_calculator;
|
|
||||||
auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated,
|
|
||||||
computed_output);
|
|
||||||
}
|
|
||||||
auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
|
|
||||||
ExpectNearAbsolute({auto_corr_view.data(), auto_corr_view.size()},
|
|
||||||
computed_output, 3e-3f);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that the auto correlation function computes the right thing for a
|
|
||||||
// simple use case.
|
|
||||||
TEST(RnnVadTest, CheckAutoCorrelationOnConstantPitchBuffer) {
|
|
||||||
// Create constant signal with no pitch.
|
|
||||||
std::array<float, kBufSize12kHz> pitch_buf_decimated;
|
|
||||||
std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f);
|
|
||||||
std::array<float, kNumPitchBufAutoCorrCoeffs> computed_output;
|
|
||||||
{
|
|
||||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
|
||||||
// FloatingPointExceptionObserver fpe_observer;
|
|
||||||
AutoCorrelationCalculator auto_corr_calculator;
|
|
||||||
auto_corr_calculator.ComputeOnPitchBuffer(pitch_buf_decimated,
|
|
||||||
computed_output);
|
|
||||||
}
|
|
||||||
// The expected output is constantly the length of the fixed 'x'
|
|
||||||
// array in ComputePitchAutoCorrelation.
|
|
||||||
std::array<float, kNumPitchBufAutoCorrCoeffs> expected_output;
|
|
||||||
std::fill(expected_output.begin(), expected_output.end(),
|
|
||||||
kBufSize12kHz - kMaxPitch12kHz);
|
|
||||||
ExpectNearAbsolute(expected_output, computed_output, 4e-5f);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace test
|
|
||||||
} // namespace rnn_vad
|
|
||||||
} // namespace webrtc
|
|
@ -20,21 +20,18 @@ constexpr size_t kSampleRate24kHz = 24000;
|
|||||||
constexpr size_t kFrameSize10ms24kHz = kSampleRate24kHz / 100;
|
constexpr size_t kFrameSize10ms24kHz = kSampleRate24kHz / 100;
|
||||||
constexpr size_t kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2;
|
constexpr size_t kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2;
|
||||||
|
|
||||||
// Pitch buffer.
|
// Pitch analysis params.
|
||||||
constexpr size_t kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s.
|
constexpr size_t kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s.
|
||||||
constexpr size_t kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s.
|
constexpr size_t kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s.
|
||||||
constexpr size_t kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz;
|
constexpr size_t kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz;
|
||||||
static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even.");
|
static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even.");
|
||||||
|
|
||||||
// 24 kHz analysis.
|
|
||||||
// Define a higher minimum pitch period for the initial search. This is used to
|
// Define a higher minimum pitch period for the initial search. This is used to
|
||||||
// avoid searching for very short periods, for which a refinement step is
|
// avoid searching for very short periods, for which a refinement step is
|
||||||
// responsible.
|
// responsible.
|
||||||
constexpr size_t kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
|
constexpr size_t kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
|
||||||
static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, "");
|
static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, "");
|
||||||
static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, "");
|
static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, "");
|
||||||
static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
|
|
||||||
constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
|
|
||||||
|
|
||||||
// 12 kHz analysis.
|
// 12 kHz analysis.
|
||||||
constexpr size_t kSampleRate12kHz = 12000;
|
constexpr size_t kSampleRate12kHz = 12000;
|
||||||
@ -43,10 +40,6 @@ constexpr size_t kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2;
|
|||||||
constexpr size_t kBufSize12kHz = kBufSize24kHz / 2;
|
constexpr size_t kBufSize12kHz = kBufSize24kHz / 2;
|
||||||
constexpr size_t kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
|
constexpr size_t kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
|
||||||
constexpr size_t kMaxPitch12kHz = kMaxPitch24kHz / 2;
|
constexpr size_t kMaxPitch12kHz = kMaxPitch24kHz / 2;
|
||||||
static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
|
|
||||||
// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|,
|
|
||||||
// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|].
|
|
||||||
constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
|
|
||||||
|
|
||||||
// 48 kHz constants.
|
// 48 kHz constants.
|
||||||
constexpr size_t kMinPitch48kHz = kMinPitch24kHz * 2;
|
constexpr size_t kMinPitch48kHz = kMinPitch24kHz * 2;
|
||||||
|
@ -19,7 +19,8 @@ namespace webrtc {
|
|||||||
namespace rnn_vad {
|
namespace rnn_vad {
|
||||||
|
|
||||||
PitchEstimator::PitchEstimator()
|
PitchEstimator::PitchEstimator()
|
||||||
: pitch_buf_decimated_(kBufSize12kHz),
|
: fft_(RealFourier::Create(kAutoCorrelationFftOrder)),
|
||||||
|
pitch_buf_decimated_(kBufSize12kHz),
|
||||||
pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz),
|
pitch_buf_decimated_view_(pitch_buf_decimated_.data(), kBufSize12kHz),
|
||||||
auto_corr_(kNumInvertedLags12kHz),
|
auto_corr_(kNumInvertedLags12kHz),
|
||||||
auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) {
|
auto_corr_view_(auto_corr_.data(), kNumInvertedLags12kHz) {
|
||||||
@ -33,16 +34,20 @@ PitchInfo PitchEstimator::Estimate(
|
|||||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
|
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf) {
|
||||||
// Perform the initial pitch search at 12 kHz.
|
// Perform the initial pitch search at 12 kHz.
|
||||||
Decimate2x(pitch_buf, pitch_buf_decimated_view_);
|
Decimate2x(pitch_buf, pitch_buf_decimated_view_);
|
||||||
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buf_decimated_view_,
|
// Compute auto-correlation terms.
|
||||||
auto_corr_view_);
|
ComputePitchAutoCorrelation(pitch_buf_decimated_view_, kMaxPitch12kHz,
|
||||||
|
auto_corr_view_, fft_.get());
|
||||||
|
|
||||||
|
// Search for pitch at 12 kHz.
|
||||||
std::array<size_t, 2> pitch_candidates_inv_lags = FindBestPitchPeriods(
|
std::array<size_t, 2> pitch_candidates_inv_lags = FindBestPitchPeriods(
|
||||||
auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz);
|
auto_corr_view_, pitch_buf_decimated_view_, kMaxPitch12kHz);
|
||||||
|
|
||||||
// Refine the pitch period estimation.
|
// Refine the pitch period estimation.
|
||||||
// The refinement is done using the pitch buffer that contains 24 kHz samples.
|
// The refinement is done using the pitch buffer that contains 24 kHz samples.
|
||||||
// Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12
|
// Therefore, adapt the inverted lags in |pitch_candidates_inv_lags| from 12
|
||||||
// to 24 kHz.
|
// to 24 kHz.
|
||||||
pitch_candidates_inv_lags[0] *= 2;
|
for (size_t i = 0; i < pitch_candidates_inv_lags.size(); ++i)
|
||||||
pitch_candidates_inv_lags[1] *= 2;
|
pitch_candidates_inv_lags[i] *= 2;
|
||||||
size_t pitch_inv_lag_48kHz =
|
size_t pitch_inv_lag_48kHz =
|
||||||
RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inv_lags);
|
RefinePitchPeriod48kHz(pitch_buf, pitch_candidates_inv_lags);
|
||||||
// Look for stronger harmonics to find the final pitch period and its gain.
|
// Look for stronger harmonics to find the final pitch period and its gain.
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "api/array_view.h"
|
#include "api/array_view.h"
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
|
#include "common_audio/real_fourier.h"
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
|
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
||||||
@ -36,7 +36,7 @@ class PitchEstimator {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
PitchInfo last_pitch_48kHz_;
|
PitchInfo last_pitch_48kHz_;
|
||||||
AutoCorrelationCalculator auto_corr_calculator_;
|
std::unique_ptr<RealFourier> fft_;
|
||||||
std::vector<float> pitch_buf_decimated_;
|
std::vector<float> pitch_buf_decimated_;
|
||||||
rtc::ArrayView<float, kBufSize12kHz> pitch_buf_decimated_view_;
|
rtc::ArrayView<float, kBufSize12kHz> pitch_buf_decimated_view_;
|
||||||
std::vector<float> auto_corr_;
|
std::vector<float> auto_corr_;
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <complex>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
@ -212,6 +213,64 @@ void ComputeSlidingFrameSquareEnergies(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ComputePitchAutoCorrelation(
|
||||||
|
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
|
||||||
|
size_t max_pitch_period,
|
||||||
|
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr,
|
||||||
|
webrtc::RealFourier* fft) {
|
||||||
|
RTC_DCHECK_GT(max_pitch_period, auto_corr.size());
|
||||||
|
RTC_DCHECK_LT(max_pitch_period, pitch_buf.size());
|
||||||
|
RTC_DCHECK(fft);
|
||||||
|
|
||||||
|
constexpr size_t time_domain_fft_length = 1 << kAutoCorrelationFftOrder;
|
||||||
|
constexpr size_t freq_domain_fft_length = time_domain_fft_length / 2 + 1;
|
||||||
|
|
||||||
|
RTC_DCHECK_EQ(RealFourier::FftLength(fft->order()), time_domain_fft_length);
|
||||||
|
RTC_DCHECK_EQ(RealFourier::ComplexLength(fft->order()),
|
||||||
|
freq_domain_fft_length);
|
||||||
|
|
||||||
|
// Cross-correlation of y_i=pitch_buf[i:i+convolution_length] and
|
||||||
|
// x=pitch_buf[-convolution_length:] is equivalent to convolution of
|
||||||
|
// y_i and reversed(x). New notation: h=reversed(x), x=y.
|
||||||
|
std::array<float, time_domain_fft_length> h{};
|
||||||
|
std::array<float, time_domain_fft_length> x{};
|
||||||
|
|
||||||
|
const size_t convolution_length = kBufSize12kHz - max_pitch_period;
|
||||||
|
// Check that the FFT-length is big enough to avoid cyclic
|
||||||
|
// convolution errors.
|
||||||
|
RTC_DCHECK_GT(time_domain_fft_length,
|
||||||
|
kNumInvertedLags12kHz + convolution_length);
|
||||||
|
|
||||||
|
// h[0:convolution_length] is reversed pitch_buf[-convolution_length:].
|
||||||
|
std::reverse_copy(pitch_buf.end() - convolution_length, pitch_buf.end(),
|
||||||
|
h.begin());
|
||||||
|
|
||||||
|
// x is pitch_buf[:kNumInvertedLags12kHz + convolution_length].
|
||||||
|
std::copy(pitch_buf.begin(),
|
||||||
|
pitch_buf.begin() + kNumInvertedLags12kHz + convolution_length,
|
||||||
|
x.begin());
|
||||||
|
|
||||||
|
// Shift to frequency domain.
|
||||||
|
std::array<std::complex<float>, freq_domain_fft_length> X{};
|
||||||
|
std::array<std::complex<float>, freq_domain_fft_length> H{};
|
||||||
|
fft->Forward(&x[0], &X[0]);
|
||||||
|
fft->Forward(&h[0], &H[0]);
|
||||||
|
|
||||||
|
// Convolve in frequency domain.
|
||||||
|
for (size_t i = 0; i < X.size(); ++i) {
|
||||||
|
X[i] *= H[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shift back to time domain.
|
||||||
|
std::array<float, time_domain_fft_length> x_conv_h;
|
||||||
|
fft->Inverse(&X[0], &x_conv_h[0]);
|
||||||
|
|
||||||
|
// Collect the result.
|
||||||
|
std::copy(x_conv_h.begin() + convolution_length - 1,
|
||||||
|
x_conv_h.begin() + convolution_length + kNumInvertedLags12kHz - 1,
|
||||||
|
auto_corr.begin());
|
||||||
|
}
|
||||||
|
|
||||||
std::array<size_t, 2> FindBestPitchPeriods(
|
std::array<size_t, 2> FindBestPitchPeriods(
|
||||||
rtc::ArrayView<const float> auto_corr,
|
rtc::ArrayView<const float> auto_corr,
|
||||||
rtc::ArrayView<const float> pitch_buf,
|
rtc::ArrayView<const float> pitch_buf,
|
||||||
|
@ -15,12 +15,25 @@
|
|||||||
#include <array>
|
#include <array>
|
||||||
|
|
||||||
#include "api/array_view.h"
|
#include "api/array_view.h"
|
||||||
|
#include "common_audio/real_fourier.h"
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
|
#include "modules/audio_processing/agc2/rnn_vad/pitch_info.h"
|
||||||
|
|
||||||
namespace webrtc {
|
namespace webrtc {
|
||||||
namespace rnn_vad {
|
namespace rnn_vad {
|
||||||
|
|
||||||
|
// The inverted lags for the pitch interval [|kInitialMinPitch12kHz|,
|
||||||
|
// |kMaxPitch12kHz|] are in the range [0, |kNumInvertedLags12kHz|].
|
||||||
|
static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
|
||||||
|
static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
|
||||||
|
constexpr size_t kNumInvertedLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
|
||||||
|
constexpr size_t kNumInvertedLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
|
||||||
|
constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
|
||||||
|
|
||||||
|
static_assert(1 << kAutoCorrelationFftOrder >
|
||||||
|
kNumInvertedLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
|
||||||
|
"");
|
||||||
|
|
||||||
// Performs 2x decimation without any anti-aliasing filter.
|
// Performs 2x decimation without any anti-aliasing filter.
|
||||||
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
|
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
|
||||||
rtc::ArrayView<float, kBufSize12kHz> dst);
|
rtc::ArrayView<float, kBufSize12kHz> dst);
|
||||||
@ -48,6 +61,25 @@ void ComputeSlidingFrameSquareEnergies(
|
|||||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf,
|
||||||
rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values);
|
rtc::ArrayView<float, kMaxPitch24kHz + 1> yy_values);
|
||||||
|
|
||||||
|
// Computes the auto-correlation coefficients for a given pitch interval.
|
||||||
|
// |auto_corr| indexes are inverted lags.
|
||||||
|
//
|
||||||
|
// The auto-correlation coefficients are computed as follows:
|
||||||
|
// |.........|...........| <- pitch buffer
|
||||||
|
// [ x (fixed) ]
|
||||||
|
// [ y_0 ]
|
||||||
|
// [ y_{m-1} ]
|
||||||
|
// x and y are sub-arrays of equal length; x is never moved, whereas y slides.
|
||||||
|
// The cross-correlation between y_0 and x corresponds to the auto-correlation
|
||||||
|
// for the maximum pitch period. Hence, the first value in |auto_corr| has an
|
||||||
|
// inverted lag equal to 0 that corresponds to a lag equal to the maximum pitch
|
||||||
|
// period.
|
||||||
|
void ComputePitchAutoCorrelation(
|
||||||
|
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
|
||||||
|
size_t max_pitch_period,
|
||||||
|
rtc::ArrayView<float, kNumInvertedLags12kHz> auto_corr,
|
||||||
|
webrtc::RealFourier* fft);
|
||||||
|
|
||||||
// Given the auto-correlation coefficients stored according to
|
// Given the auto-correlation coefficients stored according to
|
||||||
// ComputePitchAutoCorrelation() (i.e., using inverted lags), returns the best
|
// ComputePitchAutoCorrelation() (i.e., using inverted lags), returns the best
|
||||||
// and the second best pitch periods.
|
// and the second best pitch periods.
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
||||||
|
#include "common_audio/real_fourier.h"
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
@ -29,6 +30,34 @@ constexpr std::array<int, 2> kTestPitchPeriods = {
|
|||||||
};
|
};
|
||||||
constexpr std::array<float, 2> kTestPitchGains = {0.35f, 0.75f};
|
constexpr std::array<float, 2> kTestPitchGains = {0.35f, 0.75f};
|
||||||
|
|
||||||
|
constexpr size_t kNumPitchBufSquareEnergies = 385;
|
||||||
|
constexpr size_t kNumPitchBufAutoCorrCoeffs = 147;
|
||||||
|
constexpr size_t kTestDataSize =
|
||||||
|
kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs;
|
||||||
|
|
||||||
|
class TestData {
|
||||||
|
public:
|
||||||
|
TestData() {
|
||||||
|
auto test_data_reader = CreatePitchSearchTestDataReader();
|
||||||
|
test_data_reader->ReadChunk(test_data_);
|
||||||
|
}
|
||||||
|
rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView() {
|
||||||
|
return {test_data_.data(), kBufSize24kHz};
|
||||||
|
}
|
||||||
|
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
|
||||||
|
GetPitchBufSquareEnergiesView() {
|
||||||
|
return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies};
|
||||||
|
}
|
||||||
|
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
|
||||||
|
GetPitchBufAutoCorrCoeffsView() {
|
||||||
|
return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies,
|
||||||
|
kNumPitchBufAutoCorrCoeffs};
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::array<float, kTestDataSize> test_data_;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
class ComputePitchGainThresholdTest
|
class ComputePitchGainThresholdTest
|
||||||
@ -78,7 +107,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||||||
std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f)));
|
std::make_tuple(78, 2, 156, 0.72750503f, 153, 0.85069299f, 0.618379f)));
|
||||||
|
|
||||||
TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) {
|
TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) {
|
||||||
PitchTestData test_data;
|
TestData test_data;
|
||||||
std::array<float, kNumPitchBufSquareEnergies> computed_output;
|
std::array<float, kNumPitchBufSquareEnergies> computed_output;
|
||||||
{
|
{
|
||||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||||
@ -91,8 +120,51 @@ TEST(RnnVadTest, ComputeSlidingFrameSquareEnergiesBitExactness) {
|
|||||||
computed_output, 3e-2f);
|
computed_output, 3e-2f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(RnnVadTest, ComputePitchAutoCorrelationBitExactness) {
|
||||||
|
TestData test_data;
|
||||||
|
std::array<float, kBufSize12kHz> pitch_buf_decimated;
|
||||||
|
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
|
||||||
|
std::array<float, kNumPitchBufAutoCorrCoeffs> computed_output;
|
||||||
|
{
|
||||||
|
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||||
|
// FloatingPointExceptionObserver fpe_observer;
|
||||||
|
std::unique_ptr<RealFourier> fft =
|
||||||
|
RealFourier::Create(kAutoCorrelationFftOrder);
|
||||||
|
ComputePitchAutoCorrelation(pitch_buf_decimated, kMaxPitch12kHz,
|
||||||
|
computed_output, fft.get());
|
||||||
|
}
|
||||||
|
auto auto_corr_view = test_data.GetPitchBufAutoCorrCoeffsView();
|
||||||
|
ExpectNearAbsolute({auto_corr_view.data(), auto_corr_view.size()},
|
||||||
|
computed_output, 3e-3f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the auto correlation function computes the right thing for a
|
||||||
|
// simple use case.
|
||||||
|
TEST(RnnVadTest, ComputePitchAutoCorrelationConstantBuffer) {
|
||||||
|
// Create constant signal with no pitch.
|
||||||
|
std::array<float, kBufSize12kHz> pitch_buf_decimated;
|
||||||
|
std::fill(pitch_buf_decimated.begin(), pitch_buf_decimated.end(), 1.f);
|
||||||
|
|
||||||
|
std::array<float, kNumPitchBufAutoCorrCoeffs> computed_output;
|
||||||
|
{
|
||||||
|
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||||
|
// FloatingPointExceptionObserver fpe_observer;
|
||||||
|
std::unique_ptr<RealFourier> fft =
|
||||||
|
RealFourier::Create(kAutoCorrelationFftOrder);
|
||||||
|
ComputePitchAutoCorrelation(pitch_buf_decimated, kMaxPitch12kHz,
|
||||||
|
computed_output, fft.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
// The expected output is constantly the length of the fixed 'x'
|
||||||
|
// array in ComputePitchAutoCorrelation.
|
||||||
|
std::array<float, kNumPitchBufAutoCorrCoeffs> expected_output;
|
||||||
|
std::fill(expected_output.begin(), expected_output.end(),
|
||||||
|
kBufSize12kHz - kMaxPitch12kHz);
|
||||||
|
ExpectNearAbsolute(expected_output, computed_output, 4e-5f);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
|
TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
|
||||||
PitchTestData test_data;
|
TestData test_data;
|
||||||
std::array<float, kBufSize12kHz> pitch_buf_decimated;
|
std::array<float, kBufSize12kHz> pitch_buf_decimated;
|
||||||
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
|
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
|
||||||
std::array<size_t, 2> pitch_candidates_inv_lags;
|
std::array<size_t, 2> pitch_candidates_inv_lags;
|
||||||
@ -109,7 +181,7 @@ TEST(RnnVadTest, FindBestPitchPeriodsBitExactness) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
|
TEST(RnnVadTest, RefinePitchPeriod48kHzBitExactness) {
|
||||||
PitchTestData test_data;
|
TestData test_data;
|
||||||
std::array<float, kBufSize12kHz> pitch_buf_decimated;
|
std::array<float, kBufSize12kHz> pitch_buf_decimated;
|
||||||
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
|
Decimate2x(test_data.GetPitchBufView(), pitch_buf_decimated);
|
||||||
size_t pitch_inv_lag;
|
size_t pitch_inv_lag;
|
||||||
@ -135,7 +207,7 @@ TEST_P(CheckLowerPitchPeriodsAndComputePitchGainTest, BitExactness) {
|
|||||||
const float prev_pitch_gain = std::get<2>(params);
|
const float prev_pitch_gain = std::get<2>(params);
|
||||||
const int expected_pitch_period = std::get<3>(params);
|
const int expected_pitch_period = std::get<3>(params);
|
||||||
const float expected_pitch_gain = std::get<4>(params);
|
const float expected_pitch_gain = std::get<4>(params);
|
||||||
PitchTestData test_data;
|
TestData test_data;
|
||||||
{
|
{
|
||||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||||
// FloatingPointExceptionObserver fpe_observer;
|
// FloatingPointExceptionObserver fpe_observer;
|
||||||
|
@ -111,28 +111,6 @@ ReaderPairType CreateVadProbsReader() {
|
|||||||
return {std::move(ptr), ptr->data_length()};
|
return {std::move(ptr), ptr->data_length()};
|
||||||
}
|
}
|
||||||
|
|
||||||
PitchTestData::PitchTestData() {
|
|
||||||
auto test_data_reader = CreatePitchSearchTestDataReader();
|
|
||||||
test_data_reader->ReadChunk(test_data_);
|
|
||||||
}
|
|
||||||
|
|
||||||
PitchTestData::~PitchTestData() = default;
|
|
||||||
|
|
||||||
rtc::ArrayView<const float, kBufSize24kHz> PitchTestData::GetPitchBufView() {
|
|
||||||
return {test_data_.data(), kBufSize24kHz};
|
|
||||||
}
|
|
||||||
|
|
||||||
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
|
|
||||||
PitchTestData::GetPitchBufSquareEnergiesView() {
|
|
||||||
return {test_data_.data() + kBufSize24kHz, kNumPitchBufSquareEnergies};
|
|
||||||
}
|
|
||||||
|
|
||||||
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
|
|
||||||
PitchTestData::GetPitchBufAutoCorrCoeffsView() {
|
|
||||||
return {test_data_.data() + kBufSize24kHz + kNumPitchBufSquareEnergies,
|
|
||||||
kNumPitchBufAutoCorrCoeffs};
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace test
|
} // namespace test
|
||||||
} // namespace rnn_vad
|
} // namespace rnn_vad
|
||||||
} // namespace webrtc
|
} // namespace webrtc
|
||||||
|
@ -12,7 +12,6 @@
|
|||||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
|
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <array>
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
@ -21,7 +20,6 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "api/array_view.h"
|
#include "api/array_view.h"
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
|
||||||
#include "rtc_base/checks.h"
|
#include "rtc_base/checks.h"
|
||||||
|
|
||||||
namespace webrtc {
|
namespace webrtc {
|
||||||
@ -120,27 +118,6 @@ CreateSilenceFlagsFeatureMatrixReader();
|
|||||||
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
|
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
|
||||||
CreateVadProbsReader();
|
CreateVadProbsReader();
|
||||||
|
|
||||||
constexpr size_t kNumPitchBufAutoCorrCoeffs = 147;
|
|
||||||
constexpr size_t kNumPitchBufSquareEnergies = 385;
|
|
||||||
constexpr size_t kPitchTestDataSize =
|
|
||||||
kBufSize24kHz + kNumPitchBufSquareEnergies + kNumPitchBufAutoCorrCoeffs;
|
|
||||||
|
|
||||||
// Class to retrieve a test pitch buffer content and the expected output for the
|
|
||||||
// analysis steps.
|
|
||||||
class PitchTestData {
|
|
||||||
public:
|
|
||||||
PitchTestData();
|
|
||||||
~PitchTestData();
|
|
||||||
rtc::ArrayView<const float, kBufSize24kHz> GetPitchBufView();
|
|
||||||
rtc::ArrayView<const float, kNumPitchBufSquareEnergies>
|
|
||||||
GetPitchBufSquareEnergiesView();
|
|
||||||
rtc::ArrayView<const float, kNumPitchBufAutoCorrCoeffs>
|
|
||||||
GetPitchBufAutoCorrCoeffsView();
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::array<float, kPitchTestDataSize> test_data_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace test
|
} // namespace test
|
||||||
} // namespace rnn_vad
|
} // namespace rnn_vad
|
||||||
} // namespace webrtc
|
} // namespace webrtc
|
||||||
|
Reference in New Issue
Block a user