Reland "RNN VAD: prepare for SIMD optimization"
This reverts commit 5ab21f8853892205594ae8559a00b431f30a8a06.

Reason for revert: downstream fixed

Original change's description:
> Revert "RNN VAD: prepare for SIMD optimization"
>
> This reverts commit 7350a902374c796dec8ce583cfaf4b9697f3a525.
>
> Reason for revert: possibly breaking downstream projects
>
> Original change's description:
> > RNN VAD: prepare for SIMD optimization
> >
> > This CL adds the boilerplate for SIMD optimization of the FC and GRU
> > layers in rnn.cc. The same scheme as in AEC3 has been used. Unit tests
> > for the optimized architectures have been added (the same unoptimized
> > implementation will run).
> >
> > Minor changes:
> > - unnecessary const removed in rnn.h
> > - FC and GRU test data moved to the anonymous namespace as constexpr
> >
> > Bug: webrtc:10480
> > Change-Id: Ifae4e970326e7e7c603d49aeaf61194b5efdabd3
> > Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/141419
> > Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
> > Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org>
> > Cr-Commit-Position: refs/heads/master@{#29696}
>
> TBR=gustaf@webrtc.org,alessiob@webrtc.org,fhernqvist@webrtc.org
>
> Change-Id: I9ae82f4bd2d30797646fabfb5ad16bea378208b8
> No-Presubmit: true
> No-Tree-Checks: true
> No-Try: true
> Bug: webrtc:10480
> Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/158893
> Reviewed-by: Alessio Bazzica <alessiob@webrtc.org>
> Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
> Cr-Commit-Position: refs/heads/master@{#29699}

TBR=gustaf@webrtc.org,alessiob@webrtc.org,fhernqvist@webrtc.org

Change-Id: I33edd144f7ac795bf472aae9fa5a79c326000443
No-Presubmit: true
No-Tree-Checks: true
No-Try: true
Bug: webrtc:10480
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/159001
Reviewed-by: Alessio Bazzica <alessiob@webrtc.org>
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#29708}
Committed by: Commit Bot
Parent: 2f2049af23
Commit: 43afc09fc5
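For context, the dispatch scheme this CL introduces reduces to the sketch below. It is an illustration only, not the exact WebRTC code: the names Optimization, DetectOptimization, FullyConnectedLayer, ComputeOutput and ComputeOutput_NONE come from the diff that follows, while the hard-coded detection result, the printf body and main() are simplifications added here.

// Minimal sketch of the AEC3-style runtime dispatch added by this CL.
#include <cstdio>

enum class Optimization { kNone, kSse2, kNeon };

// In the CL this lives in common.cc and queries the CPU at run time;
// hard-coded here for illustration.
Optimization DetectOptimization() {
  return Optimization::kNone;
}

class FullyConnectedLayer {
 public:
  explicit FullyConnectedLayer(Optimization optimization)
      : optimization_(optimization) {}

  // Dispatches to the architecture-specific kernel; in this CL every
  // branch still calls the unoptimized implementation.
  void ComputeOutput() {
    switch (optimization_) {
      case Optimization::kSse2:
        // Placeholder: the real SSE2 kernel is not part of this CL.
        ComputeOutput_NONE();
        break;
      case Optimization::kNeon:
        // Placeholder: the real NEON kernel is not part of this CL.
        ComputeOutput_NONE();
        break;
      default:
        ComputeOutput_NONE();
        break;
    }
  }

 private:
  void ComputeOutput_NONE() { std::printf("unoptimized path\n"); }

  const Optimization optimization_;
};

int main() {
  // The optimization is detected once and passed to each layer.
  FullyConnectedLayer layer(DetectOptimization());
  layer.ComputeOutput();
  return 0;
}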
modules/audio_processing/agc2/rnn_vad/BUILD.gn

@@ -13,6 +13,7 @@ rtc_library("rnn_vad") {
  sources = [
    "auto_correlation.cc",
    "auto_correlation.h",
    "common.cc",
    "common.h",
    "features_extraction.cc",
    "features_extraction.h",
@@ -33,11 +34,20 @@ rtc_library("rnn_vad") {
    "spectral_features_internal.h",
    "symmetric_matrix_buffer.h",
  ]

  defines = []
  if (rtc_build_with_neon && current_cpu != "arm64") {
    suppressed_configs += [ "//build/config/compiler:compiler_arm_fpu" ]
    cflags = [ "-mfpu=neon" ]
  }

  deps = [
    "..:biquad_filter",
    "../../../../api:array_view",
    "../../../../rtc_base:checks",
    "../../../../rtc_base:rtc_base_approved",
    "../../../../rtc_base/system:arch",
    "../../../../system_wrappers:cpu_features_api",
    "../../utility:pffft_wrapper",
    "//third_party/rnnoise:rnn_vad",
  ]
modules/audio_processing/agc2/rnn_vad/common.cc (new file, 34 lines)

@@ -0,0 +1,34 @@
/*
 * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/agc2/rnn_vad/common.h"

#include "rtc_base/system/arch.h"
#include "system_wrappers/include/cpu_features_wrapper.h"

namespace webrtc {
namespace rnn_vad {

Optimization DetectOptimization() {
#if defined(WEBRTC_ARCH_X86_FAMILY)
  if (WebRtc_GetCPUInfo(kSSE2) != 0) {
    return Optimization::kSse2;
  }
#endif

#if defined(WEBRTC_HAS_NEON)
  return Optimization::kNeon;
#endif

  return Optimization::kNone;
}

}  // namespace rnn_vad
}  // namespace webrtc
modules/audio_processing/agc2/rnn_vad/common.h

@@ -11,6 +11,8 @@
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_

#include <stddef.h>

namespace webrtc {
namespace rnn_vad {

@@ -63,6 +65,11 @@ static_assert(kCepstralCoeffsHistorySize > 2,

constexpr size_t kFeatureVectorSize = 42;

enum class Optimization { kNone, kSse2, kNeon };

// Detects what kind of optimizations to use for the code.
Optimization DetectOptimization();

}  // namespace rnn_vad
}  // namespace webrtc
modules/audio_processing/agc2/rnn_vad/rnn.cc

@@ -10,6 +10,15 @@

#include "modules/audio_processing/agc2/rnn_vad/rnn.h"

// Defines WEBRTC_ARCH_X86_FAMILY, used below.
#include "rtc_base/system/arch.h"

#if defined(WEBRTC_HAS_NEON)
#include <arm_neon.h>
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
#include <emmintrin.h>
#endif
#include <algorithm>
#include <array>
#include <cmath>
@@ -69,12 +78,14 @@ FullyConnectedLayer::FullyConnectedLayer(
    const size_t output_size,
    const rtc::ArrayView<const int8_t> bias,
    const rtc::ArrayView<const int8_t> weights,
    float (*const activation_function)(float))
    float (*const activation_function)(float),
    Optimization optimization)
    : input_size_(input_size),
      output_size_(output_size),
      bias_(GetScaledParams(bias)),
      weights_(GetScaledParams(weights)),
      activation_function_(activation_function) {
      activation_function_(activation_function),
      optimization_(optimization) {
  RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits)
      << "Static over-allocation of fully-connected layers output vectors is "
         "not sufficient.";
@@ -91,8 +102,26 @@ rtc::ArrayView<const float> FullyConnectedLayer::GetOutput() const {
}

void FullyConnectedLayer::ComputeOutput(rtc::ArrayView<const float> input) {
  // TODO(bugs.chromium.org/9076): Optimize using SSE/AVX fused multiply-add
  // operations.
  switch (optimization_) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
    case Optimization::kSse2:
      // TODO(bugs.chromium.org/10480): Handle Optimization::kSse2.
      ComputeOutput_NONE(input);
      break;
#endif
#if defined(WEBRTC_HAS_NEON)
    case Optimization::kNeon:
      // TODO(bugs.chromium.org/10480): Handle Optimization::kNeon.
      ComputeOutput_NONE(input);
      break;
#endif
    default:
      ComputeOutput_NONE(input);
  }
}

void FullyConnectedLayer::ComputeOutput_NONE(
    rtc::ArrayView<const float> input) {
  for (size_t o = 0; o < output_size_; ++o) {
    output_[o] = bias_[o];
    // TODO(bugs.chromium.org/9076): Benchmark how different layouts for
@@ -109,12 +138,14 @@ GatedRecurrentLayer::GatedRecurrentLayer(
    const size_t output_size,
    const rtc::ArrayView<const int8_t> bias,
    const rtc::ArrayView<const int8_t> weights,
    const rtc::ArrayView<const int8_t> recurrent_weights)
    const rtc::ArrayView<const int8_t> recurrent_weights,
    Optimization optimization)
    : input_size_(input_size),
      output_size_(output_size),
      bias_(GetScaledParams(bias)),
      weights_(GetScaledParams(weights)),
      recurrent_weights_(GetScaledParams(recurrent_weights)) {
      recurrent_weights_(GetScaledParams(recurrent_weights)),
      optimization_(optimization) {
  RTC_DCHECK_LE(output_size_, kRecurrentLayersMaxUnits)
      << "Static over-allocation of recurrent layers state vectors is not "
      << "sufficient.";
@@ -139,6 +170,26 @@ void GatedRecurrentLayer::Reset() {
}

void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
  switch (optimization_) {
#if defined(WEBRTC_ARCH_X86_FAMILY)
    case Optimization::kSse2:
      // TODO(bugs.chromium.org/10480): Handle Optimization::kSse2.
      ComputeOutput_NONE(input);
      break;
#endif
#if defined(WEBRTC_HAS_NEON)
    case Optimization::kNeon:
      // TODO(bugs.chromium.org/10480): Handle Optimization::kNeon.
      ComputeOutput_NONE(input);
      break;
#endif
    default:
      ComputeOutput_NONE(input);
  }
}

void GatedRecurrentLayer::ComputeOutput_NONE(
    rtc::ArrayView<const float> input) {
  // TODO(bugs.chromium.org/9076): Optimize using SSE/AVX fused multiply-add
  // operations.
  // Stride and offset used to read parameter arrays.
@@ -203,17 +254,20 @@ RnnBasedVad::RnnBasedVad()
                   kInputLayerOutputSize,
                   kInputDenseBias,
                   kInputDenseWeights,
                   TansigApproximated),
                   TansigApproximated,
                   DetectOptimization()),
      hidden_layer_(kInputLayerOutputSize,
                    kHiddenLayerOutputSize,
                    kHiddenGruBias,
                    kHiddenGruWeights,
                    kHiddenGruRecurrentWeights),
                    kHiddenGruRecurrentWeights,
                    DetectOptimization()),
      output_layer_(kHiddenLayerOutputSize,
                    kOutputLayerOutputSize,
                    kOutputDenseBias,
                    kOutputDenseWeights,
                    SigmoidApproximated) {
                    SigmoidApproximated,
                    DetectOptimization()) {
  // Input-output chaining size checks.
  RTC_DCHECK_EQ(input_layer_.output_size(), hidden_layer_.input_size())
      << "The input and the hidden layers sizes do not match.";
modules/audio_processing/agc2/rnn_vad/rnn.h

@@ -38,11 +38,12 @@ constexpr size_t kRecurrentLayersMaxUnits = 24;
// Fully-connected layer.
class FullyConnectedLayer {
 public:
  FullyConnectedLayer(const size_t input_size,
                      const size_t output_size,
                      const rtc::ArrayView<const int8_t> bias,
                      const rtc::ArrayView<const int8_t> weights,
                      float (*const activation_function)(float));
  FullyConnectedLayer(size_t input_size,
                      size_t output_size,
                      rtc::ArrayView<const int8_t> bias,
                      rtc::ArrayView<const int8_t> weights,
                      float (*const activation_function)(float),
                      Optimization optimization);
  FullyConnectedLayer(const FullyConnectedLayer&) = delete;
  FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
  ~FullyConnectedLayer();
@@ -53,11 +54,15 @@ class FullyConnectedLayer {
  void ComputeOutput(rtc::ArrayView<const float> input);

 private:
  // No SIMD optimizations.
  void ComputeOutput_NONE(rtc::ArrayView<const float> input);

  const size_t input_size_;
  const size_t output_size_;
  const std::vector<float> bias_;
  const std::vector<float> weights_;
  float (*const activation_function_)(float);
  const Optimization optimization_;
  // The output vector of a recurrent layer has length equal to |output_size_|.
  // However, for efficiency, over-allocation is used.
  std::array<float, kFullyConnectedLayersMaxUnits> output_;
@@ -67,11 +72,12 @@ class FullyConnectedLayer {
// activation functions for the update/reset and output gates respectively.
class GatedRecurrentLayer {
 public:
  GatedRecurrentLayer(const size_t input_size,
                      const size_t output_size,
                      const rtc::ArrayView<const int8_t> bias,
                      const rtc::ArrayView<const int8_t> weights,
                      const rtc::ArrayView<const int8_t> recurrent_weights);
  GatedRecurrentLayer(size_t input_size,
                      size_t output_size,
                      rtc::ArrayView<const int8_t> bias,
                      rtc::ArrayView<const int8_t> weights,
                      rtc::ArrayView<const int8_t> recurrent_weights,
                      Optimization optimization);
  GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
  GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
  ~GatedRecurrentLayer();
@@ -83,6 +89,9 @@ class GatedRecurrentLayer {
  void ComputeOutput(rtc::ArrayView<const float> input);

 private:
  // No SIMD optimizations.
  void ComputeOutput_NONE(rtc::ArrayView<const float> input);

  const size_t input_size_;
  const size_t output_size_;
  const std::vector<float> bias_;
@@ -91,6 +100,7 @@ class GatedRecurrentLayer {
  // The state vector of a recurrent layer has length equal to |output_size_|.
  // However, to avoid dynamic allocation, over-allocation is used.
  std::array<float, kRecurrentLayersMaxUnits> state_;
  const Optimization optimization_;
};

// Recurrent network based VAD.
modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc

@@ -14,6 +14,7 @@

#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "test/gtest.h"
#include "third_party/rnnoise/src/rnn_activations.h"
#include "third_party/rnnoise/src/rnn_vad_weights.h"
@@ -60,86 +61,104 @@ void TestGatedRecurrentLayer(
  }
}

}  // namespace

// Checks that the output of a fully connected layer is within tolerance given
// test input data.
TEST(RnnVadTest, CheckFullyConnectedLayerOutput) {
  const std::array<int8_t, 1> bias = {-50};
  const std::array<int8_t, 24> weights = {
// Fully connected layer test data.
constexpr size_t kFullyConnectedInputSize = 24;
constexpr size_t kFullyConnectedOutputSize = 1;
constexpr std::array<int8_t, 1> kFullyConnectedBias = {-50};
constexpr std::array<int8_t, 24> kFullyConnectedWeights = {
    127, 127, 127, 127, 127, 20, 127, -126, -126, -54, 14, 125,
    -126, -126, 127, -125, -126, 127, -127, -127, -57, -30, 127, 80};
  FullyConnectedLayer fc(24, 1, bias, weights, SigmoidApproximated);
  // Test on different inputs.
  {
    const std::array<float, 24> input_vector = {
        0.f, 0.f, 0.f, 0.f, 0.f,
        0.f, 0.215833917f, 0.290601075f, 0.238759011f, 0.244751841f,
        0.f, 0.0461241305f, 0.106401242f, 0.223070428f, 0.630603909f,
        0.690453172f, 0.f, 0.387645692f, 0.166913897f, 0.f,
        0.0327451192f, 0.f, 0.136149868f, 0.446351469f};
    TestFullyConnectedLayer(&fc, input_vector, 0.436567038f);
  }
  {
    const std::array<float, 24> input_vector = {
        0.592162728f, 0.529089332f, 1.18205106f,
        1.21736848f, 0.f, 0.470851123f,
        0.130675942f, 0.320903003f, 0.305496395f,
        0.0571633279f, 1.57001138f, 0.0182026215f,
        0.0977443159f, 0.347477973f, 0.493206412f,
        0.9688586f, 0.0320267938f, 0.244722098f,
        0.312745273f, 0.f, 0.00650715502f,
        0.312553257f, 1.62619662f, 0.782880902f};
    TestFullyConnectedLayer(&fc, input_vector, 0.874741316f);
  }
  {
    const std::array<float, 24> input_vector = {
        0.395022154f, 0.333681047f, 0.76302278f,
        0.965480626f, 0.f, 0.941198349f,
        0.0892967582f, 0.745046318f, 0.635769248f,
        0.238564298f, 0.970656633f, 0.014159563f,
        0.094203949f, 0.446816623f, 0.640755892f,
        1.20532358f, 0.0254284926f, 0.283327013f,
        0.726210058f, 0.0550272502f, 0.000344108557f,
constexpr std::array<float, 24 * 3> kFullyConnectedInputVectors = {
    // Input 1.
    0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.215833917f, 0.290601075f, 0.238759011f,
    0.244751841f, 0.f, 0.0461241305f, 0.106401242f, 0.223070428f, 0.630603909f,
    0.690453172f, 0.f, 0.387645692f, 0.166913897f, 0.f, 0.0327451192f, 0.f,
    0.136149868f, 0.446351469f,
    // Input 2.
    0.592162728f, 0.529089332f, 1.18205106f, 1.21736848f, 0.f, 0.470851123f,
    0.130675942f, 0.320903003f, 0.305496395f, 0.0571633279f, 1.57001138f,
    0.0182026215f, 0.0977443159f, 0.347477973f, 0.493206412f, 0.9688586f,
    0.0320267938f, 0.244722098f, 0.312745273f, 0.f, 0.00650715502f,
    0.312553257f, 1.62619662f, 0.782880902f,
    // Input 3.
    0.395022154f, 0.333681047f, 0.76302278f, 0.965480626f, 0.f, 0.941198349f,
    0.0892967582f, 0.745046318f, 0.635769248f, 0.238564298f, 0.970656633f,
    0.014159563f, 0.094203949f, 0.446816623f, 0.640755892f, 1.20532358f,
    0.0254284926f, 0.283327013f, 0.726210058f, 0.0550272502f, 0.000344108557f,
    0.369803518f, 1.56680179f, 0.997883797f};
    TestFullyConnectedLayer(&fc, input_vector, 0.672785878f);
  }
}
constexpr std::array<float, 3> kFullyConnectedExpectedOutputs = {
    0.436567038f, 0.874741316f, 0.672785878f};

// Checks that the output of a GRU layer is within tolerance given test input
// data.
TEST(RnnVadTest, CheckGatedRecurrentLayer) {
  const std::array<int8_t, 12> bias = {96, -99, -81, -114, 49, 119,
// Gated recurrent units layer test data.
constexpr size_t kGruInputSize = 5;
constexpr size_t kGruOutputSize = 4;
constexpr std::array<int8_t, 12> kGruBias = {96, -99, -81, -114, 49, 119,
                                             -118, 68, -76, 91, 121, 125};
  const std::array<int8_t, 60> weights = {
constexpr std::array<int8_t, 60> kGruWeights = {
    124, 9, 1, 116, -66, -21, -118, -110, 104, 75, -23, -51,
    -72, -111, 47, 93, 77, -98, 41, -8, 40, -23, -43, -107,
    9, -73, 30, -32, -2, 64, -26, 91, -48, -24, -28, -104,
    74, -46, 116, 15, 32, 52, -126, -38, -121, 12, -16, 110,
    -95, 66, -103, -35, -38, 3, -126, -61, 28, 98, -117, -43};
  const std::array<int8_t, 60> recurrent_weights = {
constexpr std::array<int8_t, 60> kGruRecurrentWeights = {
    -3, 87, 50, 51, -22, 27, -39, 62, 31, -83, -52, -48,
    -6, 83, -19, 104, 105, 48, 23, 68, 23, 40, 7, -120,
    64, -62, 117, 85, -51, -43, 54, -105, 120, 56, -128, -107,
    39, 50, -17, -47, -117, 14, 108, 12, -7, -72, 103, -87,
    -66, 82, 84, 100, -98, 102, -49, 44, 122, 106, -20, -69};
  GatedRecurrentLayer gru(5, 4, bias, weights, recurrent_weights);
  // Test on different inputs.
  {
    const std::array<float, 20> input_sequence = {
constexpr std::array<float, 20> kGruInputSequence = {
    0.89395463f, 0.93224651f, 0.55788344f, 0.32341808f, 0.93355054f,
    0.13475326f, 0.97370994f, 0.14253306f, 0.93710381f, 0.76093364f,
    0.65780413f, 0.41657975f, 0.49403164f, 0.46843281f, 0.75138855f,
    0.24517593f, 0.47657707f, 0.57064998f, 0.435184f, 0.19319285f};
    const std::array<float, 16> expected_output_sequence = {
constexpr std::array<float, 16> kGruExpectedOutputSequence = {
    0.0239123f, 0.5773077f, 0.f, 0.f,
    0.01282811f, 0.64330572f, 0.f, 0.04863098f,
    0.00781069f, 0.75267816f, 0.f, 0.02579715f,
    0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f};
    TestGatedRecurrentLayer(&gru, input_sequence, expected_output_sequence);

}  // namespace

class OptimizationTest : public ::testing::Test,
                         public ::testing::WithParamInterface<Optimization> {};

// Checks that the output of a fully connected layer is within tolerance given
// test input data.
TEST_P(OptimizationTest, CheckFullyConnectedLayerOutput) {
  const Optimization optimization = GetParam();
  RTC_LOG(LS_VERBOSE) << optimization;
  FullyConnectedLayer fc(kFullyConnectedInputSize, kFullyConnectedOutputSize,
                         kFullyConnectedBias, kFullyConnectedWeights,
                         SigmoidApproximated, optimization);
  // Test on different inputs.
  static_assert(
      kFullyConnectedInputVectors.size() % kFullyConnectedInputSize == 0, "");
  constexpr size_t kNumInputVectors =
      kFullyConnectedInputVectors.size() / kFullyConnectedInputSize;
  static_assert(kFullyConnectedExpectedOutputs.size() == kNumInputVectors, "");
  for (size_t i = 0; i < kNumInputVectors; ++i) {
    rtc::ArrayView<const float> input(
        kFullyConnectedInputVectors.data() + kFullyConnectedInputSize * i,
        kFullyConnectedInputSize);
    TestFullyConnectedLayer(&fc, input, kFullyConnectedExpectedOutputs[i]);
  }
}

// Checks that the output of a GRU layer is within tolerance given test input
// data.
TEST_P(OptimizationTest, CheckGatedRecurrentLayer) {
  const Optimization optimization = GetParam();
  RTC_LOG(LS_VERBOSE) << optimization;
  GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
                          kGruRecurrentWeights, optimization);
  TestGatedRecurrentLayer(&gru, kGruInputSequence, kGruExpectedOutputSequence);
}

INSTANTIATE_TEST_SUITE_P(RnnVadTest,
                         OptimizationTest,
                         ::testing::Values(Optimization::kNone,
                                           DetectOptimization()));

}  // namespace test
}  // namespace rnn_vad
}  // namespace webrtc