AGC2 RNN VAD: Spectral features internal API.

This CL adds helper functions to be used for the spectral features
computation. Namely, it includes the following:
- band boundaries (frequency to FFT coefficient index)
- band energy coefficients
- log band energy coefficients
- fixed size DCT table and computation

Bug: webrtc:9076
Change-Id: I03a8799b226d986bc1e37cefd0c3039f94b5592a
Reviewed-on: https://webrtc-review.googlesource.com/73687
Reviewed-by: Alex Loiko <aleloi@webrtc.org>
Reviewed-by: Minyue Li <minyue@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#23170}
This commit is contained in:
Alessio Bazzica
2018-05-08 11:10:45 +02:00
committed by Commit Bot
parent 496caa9095
commit 0bd0a3fe4c
9 changed files with 385 additions and 9 deletions

View File

@ -30,11 +30,14 @@ source_set("lib") {
"rnn.cc",
"rnn.h",
"sequence_buffer.h",
"spectral_features_internal.cc",
"spectral_features_internal.h",
"symmetric_matrix_buffer.h",
]
deps = [
"../../../../api:array_view",
"../../../../rtc_base:checks",
"../../../../rtc_base:rtc_base_approved",
"//third_party/rnnoise:kiss_fft",
"//third_party/rnnoise:rnn_vad",
]
@ -57,6 +60,8 @@ if (rtc_include_tests) {
}
unittest_resources = [
"../../../../resources/audio_processing/agc2/rnn_vad/band_energies.dat",
"../../../../resources/audio_processing/agc2/rnn_vad/fft.dat",
"../../../../resources/audio_processing/agc2/rnn_vad/pitch_buf_24k.dat",
"../../../../resources/audio_processing/agc2/rnn_vad/pitch_lp_res.dat",
"../../../../resources/audio_processing/agc2/rnn_vad/sil_features.dat",
@ -83,6 +88,7 @@ if (rtc_include_tests) {
"ring_buffer_unittest.cc",
"rnn_unittest.cc",
"sequence_buffer_unittest.cc",
"spectral_features_internal_unittest.cc",
"symmetric_matrix_buffer_unittest.cc",
]
deps = [

View File

@ -45,6 +45,12 @@ constexpr size_t kMaxPitch12kHz = kMaxPitch24kHz / 2;
constexpr size_t kMinPitch48kHz = kMinPitch24kHz * 2;
constexpr size_t kMaxPitch48kHz = kMaxPitch24kHz * 2;
// Sub-band frequency boundaries. |kBandFrequencyBoundaries| holds |kNumBands|
// values in increasing order. They are expressed in Hz: each value is
// multiplied by a frame size and divided by a sample rate in Hz when it is
// converted into an FFT coefficient index.
constexpr size_t kNumBands = 22;
constexpr int kBandFrequencyBoundaries[kNumBands] = {
0, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 2000, 2400,
2800, 3200, 4000, 4800, 5600, 6800, 8000, 9600, 12000, 15600, 20000};
constexpr size_t kFeatureVectorSize = 42;
} // namespace rnn_vad

View File

@ -0,0 +1,134 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h"
#include <algorithm>
#include <cmath>
#include "rtc_base/checks.h"
#include "rtc_base/function_view.h"
namespace webrtc {
namespace rnn_vad {
namespace {
// DCT scaling factor, namely sqrt(2 / kNumBands). The value is hard-coded
// because std::sqrt() is not constexpr: computing it at namespace scope would
// require a static initializer. The static_assert forces an update of the
// literal whenever |kNumBands| changes.
static_assert(
    kNumBands == 22,
    "kNumBands changed! Please update the value of kDctScalingFactor");
constexpr float kDctScalingFactor = 0.301511345f;  // sqrt(2 / kNumBands)
// Computes one coefficient per frequency band by accumulating the values that
// |functor| returns for the FFT coefficients of the band. |functor| is called
// with the index of an FFT coefficient and that index never exceeds
// |max_freq_bin_index|. Every band is triangular: a value at offset k in a band
// made of N FFT coefficients is split between the band (weight 1 - k / N) and
// the following one (weight k / N), hence the peak response is at the band
// boundary. |coefficients| is zeroed before the accumulation starts.
void ComputeBandCoefficients(
    rtc::FunctionView<float(size_t)> functor,
    rtc::ArrayView<const size_t, kNumBands> band_boundaries,
    size_t max_freq_bin_index,
    rtc::ArrayView<float, kNumBands> coefficients) {
  const size_t last_band = coefficients.size() - 1;
  std::fill(coefficients.begin(), coefficients.end(), 0.f);
  for (size_t i = 0; i < last_band; ++i) {
    RTC_DCHECK_EQ(0.f, coefficients[i + 1]);
    RTC_DCHECK_GT(band_boundaries[i + 1], band_boundaries[i]);
    const size_t band_begin = band_boundaries[i];
    // Last FFT coefficient of the band before clamping to the available ones.
    const size_t nominal_band_end =
        band_begin + band_boundaries[i + 1] - band_boundaries[i] - 1;
    const size_t band_end = std::min(max_freq_bin_index, nominal_band_end);
    // Depending on the sample rate, the highest bands can have no FFT
    // coefficients. Stop the iteration when coming across the first empty band.
    if (band_begin >= band_end) {
      break;
    }
    const size_t num_bins = band_end - band_begin + 1;
    // Split each value between band |i| and band |i| + 1 using triangular
    // weights with peak response at the band boundary.
    for (size_t bin = band_begin; bin <= band_end; ++bin) {
      const float w = static_cast<float>(bin - band_begin) / num_bins;
      const float value = functor(bin);
      coefficients[i] += (1.f - w) * value;
      coefficients[i + 1] += w * value;
    }
  }
  // The first and the last bands in the loop above only got half contribution.
  coefficients[0] *= 2.f;
  coefficients[last_band] *= 2.f;
  // TODO(bugs.webrtc.org/9076): Replace the line above with
  // "coefficients[i] *= 2.f" (*) since we now assume that the last band is
  // always |kNumBands| - 1.
  // (*): "size_t i" must be declared before the main loop.
}
} // namespace
std::array<size_t, kNumBands> ComputeBandBoundaryIndexes(
size_t sample_rate_hz,
size_t frame_size_samples) {
std::array<size_t, kNumBands> indexes;
for (size_t i = 0; i < kNumBands; ++i) {
indexes[i] =
kBandFrequencyBoundaries[i] * frame_size_samples / sample_rate_hz;
}
return indexes;
}
// Computes the band energy coefficients for the FFT coefficients in
// |fft_coeffs| and writes them into |band_energies|. The energy of an FFT
// coefficient is its squared magnitude (std::norm) and the per-band
// accumulation is delegated to ComputeBandCoefficients(). |band_boundaries|
// holds the FFT coefficient index of each band boundary. |fft_coeffs| must not
// be empty.
void ComputeBandEnergies(
    rtc::ArrayView<const std::complex<float>> fft_coeffs,
    rtc::ArrayView<const size_t, kNumBands> band_boundaries,
    rtc::ArrayView<float, kNumBands> band_energies) {
  RTC_DCHECK_EQ(band_boundaries.size(), band_energies.size());
  // The highest valid index is computed below as |fft_coeffs.size()| - 1 using
  // unsigned arithmetic: with an empty view it would wrap around and disable
  // the clamping of the band boundaries.
  RTC_DCHECK_LE(1, fft_coeffs.size());
  auto functor = [fft_coeffs](const size_t freq_bin_index) -> float {
    return std::norm(fft_coeffs[freq_bin_index]);
  };
  ComputeBandCoefficients(functor, band_boundaries, fft_coeffs.size() - 1,
                          band_energies);
}
// Converts band energies into log band energies. For each band, the base-10
// logarithm of the energy (offset by 1e-2) is computed and then bounded from
// below by two values tracked across bands: the running maximum minus 7 and a
// follower that drops by 1.5 per band.
void ComputeLogBandEnergiesCoefficients(
    rtc::ArrayView<const float, kNumBands> band_energy_coeffs,
    rtc::ArrayView<float, kNumBands> log_band_energy_coeffs) {
  float running_max = -2.f;
  float follower = -2.f;
  const size_t num_bands = band_energy_coeffs.size();
  for (size_t band = 0; band < num_bands; ++band) {
    const float log_energy = std::log10(1e-2f + band_energy_coeffs[band]);
    // Smoothing across frequency bands.
    const float smoothed =
        std::max(running_max - 7.f, std::max(follower - 1.5f, log_energy));
    log_band_energy_coeffs[band] = smoothed;
    // Update the trackers with the smoothed value.
    running_max = std::max(running_max, smoothed);
    follower = std::max(follower - 1.5f, smoothed);
  }
}
std::array<float, kNumBands * kNumBands> ComputeDctTable() {
std::array<float, kNumBands * kNumBands> dct_table;
const double k = std::sqrt(0.5);
for (size_t i = 0; i < kNumBands; ++i) {
for (size_t j = 0; j < kNumBands; ++j)
dct_table[i * kNumBands + j] = std::cos((i + 0.5) * j * kPi / kNumBands);
dct_table[i * kNumBands] *= k;
}
return dct_table;
}
// Computes the first |out.size()| DCT coefficients of |in| using the
// pre-computed |dct_table|. Each output is the sum over n of
// in[n] * dct_table[n * in.size() + k], scaled by |kDctScalingFactor|.
// |out| must not share its buffer with |in|, must hold at least one element and
// must not be larger than |in|.
void ComputeDct(rtc::ArrayView<const float, kNumBands> in,
                rtc::ArrayView<const float, kNumBands * kNumBands> dct_table,
                rtc::ArrayView<float> out) {
  RTC_DCHECK_NE(in.data(), out.data()) << "In-place DCT is not supported.";
  RTC_DCHECK_LE(1, out.size());
  RTC_DCHECK_LE(out.size(), in.size());
  const size_t num_inputs = in.size();
  const size_t num_outputs = out.size();
  std::fill(out.begin(), out.end(), 0.f);
  for (size_t k = 0; k < num_outputs; ++k) {
    float& dct_coeff = out[k];
    // Walk the k-th column of the table.
    for (size_t n = 0; n < num_inputs; ++n) {
      dct_coeff += in[n] * dct_table[n * num_inputs + k];
    }
    dct_coeff *= kDctScalingFactor;
  }
}
} // namespace rnn_vad
} // namespace webrtc

View File

@ -0,0 +1,53 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_INTERNAL_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_INTERNAL_H_
#include <array>
#include <complex>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
namespace webrtc {
namespace rnn_vad {
// Computes the FFT coefficient indexes corresponding to the sub-band
// boundaries in |kBandFrequencyBoundaries|. Each index is the boundary
// frequency times |frame_size_samples| divided by |sample_rate_hz|, rounded
// down since integer arithmetic is used.
std::array<size_t, kNumBands> ComputeBandBoundaryIndexes(
size_t sample_rate_hz,
size_t frame_size_samples);
// Given an array of FFT coefficients and a vector of band boundary indexes,
// computes band energy coefficients and writes them into |band_energies|. The
// energy of an FFT coefficient is its squared magnitude. |fft_coeffs| is
// expected to be non-empty since its size minus one is used as the highest
// valid FFT coefficient index.
void ComputeBandEnergies(
rtc::ArrayView<const std::complex<float>> fft_coeffs,
rtc::ArrayView<const size_t, kNumBands> band_boundaries,
rtc::ArrayView<float, kNumBands> band_energies);
// Computes log band energy coefficients, namely the base-10 logarithm of each
// band energy in |band_energy_coeffs| (offset by 1e-2) smoothed across
// frequency bands. The result is written into |log_band_energy_coeffs|.
void ComputeLogBandEnergiesCoefficients(
rtc::ArrayView<const float, kNumBands> band_energy_coeffs,
rtc::ArrayView<float, kNumBands> log_band_energy_coeffs);
// Creates a DCT table for arrays having size equal to |kNumBands|. The table
// has |kNumBands| * |kNumBands| entries; the entry at index
// i * |kNumBands| + j equals cos((i + 0.5) * j * kPi / |kNumBands|) and the
// entries with j equal to 0 are additionally scaled by sqrt(0.5).
std::array<float, kNumBands * kNumBands> ComputeDctTable();
// Computes DCT for |in| given a pre-computed DCT table. In-place computation is
// not allowed and |out| can be smaller than |in| in order to only compute the
// first DCT coefficients. |out| must have at least one element and cannot have
// more elements than |in|.
void ComputeDct(rtc::ArrayView<const float, kNumBands> in,
rtc::ArrayView<const float, kNumBands * kNumBands> dct_table,
rtc::ArrayView<float> out);
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_INTERNAL_H_

View File

@ -0,0 +1,137 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h"
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// #include "test/fpe_observer.h"
#include "test/gtest.h"
namespace webrtc {
namespace rnn_vad {
namespace test {
namespace {

// Sample rate used by the band energies test.
constexpr size_t kSampleRate48kHz = 48000;
// Number of samples in a 20 ms frame at 48 kHz.
constexpr size_t kFrameSize20ms48kHz = kSampleRate48kHz * 2 / 100;
// Number of FFT coefficients for a 20 ms frame at 48 kHz (half the frame size
// plus one).
constexpr size_t kFftNumCoeffs20ms48kHz = 1 + kFrameSize20ms48kHz / 2;

}  // namespace
// TODO(bugs.webrtc.org/9076): Remove this test before closing the issue.
// Check that when using precomputed FFT coefficients for frames at 48 kHz, the
// output of ComputeBandEnergies() is bit exact.
// For every frame, the real and the imaginary parts of the FFT coefficients are
// read as two consecutive chunks, combined into complex values and the computed
// band energies are compared to the expected ones read from a resource file.
TEST(RnnVadTest, ComputeBandEnergies48kHzBitExactness) {
  // Initialize input data reader and buffers.
  auto fft_coeffs_reader = CreateFftCoeffsReader();
  const size_t num_frames = fft_coeffs_reader.second;
  ASSERT_EQ(
      kFftNumCoeffs20ms48kHz,
      rtc::CheckedDivExact(fft_coeffs_reader.first->data_length(), num_frames) /
          2);
  std::array<float, kFftNumCoeffs20ms48kHz> fft_coeffs_real;
  std::array<float, kFftNumCoeffs20ms48kHz> fft_coeffs_imag;
  std::array<std::complex<float>, kFftNumCoeffs20ms48kHz> fft_coeffs;
  // Init expected output reader and buffer.
  auto band_energies_reader = CreateBandEnergyCoeffsReader();
  ASSERT_EQ(num_frames, band_energies_reader.second);
  std::array<float, kNumBands> expected_band_energies;
  // Init band energies coefficients computation.
  const auto band_boundary_indexes =
      ComputeBandBoundaryIndexes(kSampleRate48kHz, kFrameSize20ms48kHz);
  std::array<float, kNumBands> computed_band_energies;
  // Check output for every frame.
  {
    // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
    // FloatingPointExceptionObserver fpe_observer;
    for (size_t i = 0; i < num_frames; ++i) {
      SCOPED_TRACE(i);
      // Read input.
      fft_coeffs_reader.first->ReadChunk(
          {fft_coeffs_real.data(), fft_coeffs_real.size()});
      fft_coeffs_reader.first->ReadChunk(
          {fft_coeffs_imag.data(), fft_coeffs_imag.size()});
      // Use a distinct index name to avoid shadowing the frame index |i|.
      for (size_t j = 0; j < kFftNumCoeffs20ms48kHz; ++j) {
        fft_coeffs[j].real(fft_coeffs_real[j]);
        fft_coeffs[j].imag(fft_coeffs_imag[j]);
      }
      band_energies_reader.first->ReadChunk(
          {expected_band_energies.data(), expected_band_energies.size()});
      // Compute band energy coefficients and check output.
      ComputeBandEnergies(
          {fft_coeffs.data(), fft_coeffs.size()},
          {band_boundary_indexes.data(), band_boundary_indexes.size()},
          {computed_band_energies.data(), computed_band_energies.size()});
      ExpectEqualFloatArray(expected_band_energies, computed_band_energies);
    }
  }
}
// Checks the output of ComputeLogBandEnergiesCoefficients() for one vector of
// band energies against reference values using an absolute tolerance of 1e-5.
TEST(RnnVadTest, ComputeLogBandEnergiesCoefficientsBitExactness) {
  constexpr std::array<float, kNumBands> kBandEnergies = {
      {86.060539245605f, 275.668334960938f, 43.406528472900f, 6.541896820068f,
       17.964015960693f, 8.090919494629f, 1.261920094490f, 1.212702631950f,
       1.619154453278f, 0.508935272694f, 0.346316039562f, 0.237035423517f,
       0.172424271703f, 0.271657168865f, 0.126088857651f, 0.139967113733f,
       0.207200810313f, 0.155893072486f, 0.091090843081f, 0.033391401172f,
       0.013879744336f, 0.011973354965f}};
  constexpr std::array<float, kNumBands> kExpectedLogBandEnergies = {
      {1.934854507446f, 2.440402746201f, 1.637655138969f, 0.816367030144f,
       1.254645109177f, 0.908534288406f, 0.104459829628f, 0.087320849299f,
       0.211962252855f, -0.284886807203f, -0.448164641857f, -0.607240796089f,
       -0.738917350769f, -0.550279200077f, -0.866177439690f, -0.824003994465f,
       -0.663138568401f, -0.780171751976f, -0.995288193226f, -1.362596273422f,
       -1.621970295906f, -1.658103585243f}};
  std::array<float, kNumBands> log_band_energies;
  {
    // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
    // FloatingPointExceptionObserver fpe_observer;
    ComputeLogBandEnergiesCoefficients(
        {kBandEnergies.data(), kBandEnergies.size()},
        {log_band_energies.data(), log_band_energies.size()});
    ExpectNearAbsolute(kExpectedLogBandEnergies, log_band_energies, 1e-5f);
  }
}
// Checks the output of ComputeDct(), computed with the table returned by
// ComputeDctTable(), against reference values using an absolute tolerance of
// 1e-5. The output has the same size as the input, so every DCT coefficient is
// checked.
TEST(RnnVadTest, ComputeDctBitExactness) {
  constexpr std::array<float, kNumBands> kDctInput = {
      {0.232155621052f, 0.678957760334f, 0.220818966627f, -0.077363930643f,
       -0.559227049351f, 0.432545185089f, 0.353900641203f, 0.398993015289f,
       0.409774333239f, 0.454977899790f, 0.300520688295f, -0.010286616161f,
       0.272525429726f, 0.098067551851f, 0.083649002016f, 0.046226885170f,
       -0.033228103071f, 0.144773483276f, -0.117661058903f, -0.005628800020f,
       -0.009547689930f, -0.045382082462f}};
  constexpr std::array<float, kNumBands> kExpectedDctOutput = {
      {0.697072803974f, 0.442710995674f, -0.293156713247f, -0.060711503029f,
       0.292050391436f, 0.489301353693f, 0.402255415916f, 0.134404733777f,
       -0.086305990815f, -0.199605688453f, -0.234511867166f, -0.413774639368f,
       -0.388507157564f, -0.032798115164f, 0.044605545700f, 0.112466648221f,
       -0.050096966326f, 0.045971218497f, -0.029815061018f, -0.410366982222f,
       -0.209233760834f, -0.128037497401f}};
  const auto table = ComputeDctTable();
  std::array<float, kNumBands> dct_output;
  {
    // TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
    // FloatingPointExceptionObserver fpe_observer;
    ComputeDct({kDctInput.data(), kDctInput.size()},
               {table.data(), table.size()},
               {dct_output.data(), dct_output.size()});
    ExpectNearAbsolute(kExpectedDctOutput, dct_output, 1e-5f);
  }
}
} // namespace test
} // namespace rnn_vad
} // namespace webrtc

View File

@ -27,6 +27,15 @@ using ReaderPairType =
using webrtc::test::ResourcePath;
// Checks that |expected| and |computed| have the same size and then compares
// them element by element using EXPECT_FLOAT_EQ.
void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed) {
// The size check is fatal: on mismatch no element is compared.
ASSERT_EQ(expected.size(), computed.size());
for (size_t i = 0; i < expected.size(); ++i) {
// Include the element index in the failure messages.
SCOPED_TRACE(i);
EXPECT_FLOAT_EQ(expected[i], computed[i]);
}
}
void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed,
float tolerance) {
@ -38,10 +47,10 @@ void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
}
ReaderPairType CreatePitchBuffer24kHzReader() {
constexpr size_t cols = 864;
auto ptr = rtc::MakeUnique<BinaryFileReader<float>>(
ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), 864);
return {std::move(ptr),
rtc::CheckedDivExact(ptr->data_length(), static_cast<size_t>(864))};
ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"), cols);
return {std::move(ptr), rtc::CheckedDivExact(ptr->data_length(), cols)};
}
ReaderPairType CreateLpResidualAndPitchPeriodGainReader() {
@ -53,13 +62,31 @@ ReaderPairType CreateLpResidualAndPitchPeriodGainReader() {
rtc::CheckedDivExact(ptr->data_length(), 2 + num_lp_residual_coeffs)};
}
// Creates a reader for the FFT coefficients resource file. The reader is
// created with a chunk size of 481 values, whereas the returned count is the
// data length divided by 2 * 481 since each row holds both the real and the
// imaginary values.
ReaderPairType CreateFftCoeffsReader() {
  constexpr size_t kNumFftPoints = 481;
  // Real and imaginary values.
  constexpr size_t kRowSize = 2 * kNumFftPoints;
  auto reader = rtc::MakeUnique<BinaryFileReader<float>>(
      test::ResourcePath("audio_processing/agc2/rnn_vad/fft", "dat"),
      kNumFftPoints);
  // Query the data length before handing over the ownership of the reader.
  const size_t num_rows = rtc::CheckedDivExact(reader->data_length(), kRowSize);
  return {std::move(reader), num_rows};
}
// Creates a reader for the band energy coefficients resource file. The reader
// is created with a chunk size of 22 values and the returned count is the data
// length divided by 22.
ReaderPairType CreateBandEnergyCoeffsReader() {
  constexpr size_t kBandsPerRow = 22;
  auto reader = rtc::MakeUnique<BinaryFileReader<float>>(
      test::ResourcePath("audio_processing/agc2/rnn_vad/band_energies", "dat"),
      kBandsPerRow);
  // Query the data length before handing over the ownership of the reader.
  const size_t num_rows =
      rtc::CheckedDivExact(reader->data_length(), kBandsPerRow);
  return {std::move(reader), num_rows};
}
ReaderPairType CreateSilenceFlagsFeatureMatrixReader() {
constexpr size_t feature_vector_size = 42;
auto ptr = rtc::MakeUnique<BinaryFileReader<float>>(
test::ResourcePath("audio_processing/agc2/rnn_vad/sil_features", "dat"),
42);
// Features (42) and silence flag.
feature_vector_size);
// Features and silence flag.
return {std::move(ptr),
rtc::CheckedDivExact(ptr->data_length(), static_cast<size_t>(43))};
rtc::CheckedDivExact(ptr->data_length(), feature_vector_size + 1)};
}
ReaderPairType CreateVadProbsReader() {

View File

@ -28,7 +28,12 @@ namespace test {
constexpr float kFloatMin = std::numeric_limits<float>::min();
// Fail for every pair from two equally sized rtc::ArrayView<float> views such
// Fails for every pair from two equally sized rtc::ArrayView<float> views such
// that the values in the pair do not match.
void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed);
// Fails for every pair from two equally sized rtc::ArrayView<float> views such
// that their absolute error is above a given threshold.
void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed,
@ -95,10 +100,16 @@ CreatePitchBuffer24kHzReader();
// and gain values.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreateLpResidualAndPitchPeriodGainReader();
// Instance a reader for the silence flags and the feature matrix.
// Creates a reader for the FFT coefficients.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreateFftCoeffsReader();
// Creates a reader for the band energy coefficients.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreateBandEnergyCoeffsReader();
// Creates a reader for the silence flags and the feature matrix.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreateSilenceFlagsFeatureMatrixReader();
// Instance a reader for the VAD probabilities.
// Creates a reader for the VAD probabilities.
std::pair<std::unique_ptr<BinaryFileReader<float>>, const size_t>
CreateVadProbsReader();

View File

@ -0,0 +1 @@
52fad3366911c238929585daad02c8db11f99eae

View File

@ -0,0 +1 @@
e62364d35abd123663bfc800fa233071d6d7fffd