RNN VAD: FC layer with SSE2 impl
This CL adds the SSE2 optimized implementation for fully connected (FC) layers. The change includes a weights re-alignment op done once at construction time. It is required in order to optimize the load op to fill 128 bit registers. This CL also includes unit test adaptations and a benchmark test (disabled by default). Bug: webrtc:10480 Change-Id: I5ed87f0a629faaaf4c8bffbce1cea5557518f8c8 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/141862 Commit-Queue: Alessio Bazzica <alessiob@webrtc.org> Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org> Cr-Commit-Position: refs/heads/master@{#29712}
This commit is contained in:
committed by
Commit Bot
parent
33cff37c60
commit
d58fdbedcf
@ -44,6 +44,7 @@ rtc_library("rnn_vad") {
|
|||||||
deps = [
|
deps = [
|
||||||
"..:biquad_filter",
|
"..:biquad_filter",
|
||||||
"../../../../api:array_view",
|
"../../../../api:array_view",
|
||||||
|
"../../../../api:function_view",
|
||||||
"../../../../rtc_base:checks",
|
"../../../../rtc_base:checks",
|
||||||
"../../../../rtc_base:rtc_base_approved",
|
"../../../../rtc_base:rtc_base_approved",
|
||||||
"../../../../rtc_base/system:arch",
|
"../../../../rtc_base/system:arch",
|
||||||
@ -65,6 +66,8 @@ if (rtc_include_tests) {
|
|||||||
"../../../../api:array_view",
|
"../../../../api:array_view",
|
||||||
"../../../../api:scoped_refptr",
|
"../../../../api:scoped_refptr",
|
||||||
"../../../../rtc_base:checks",
|
"../../../../rtc_base:checks",
|
||||||
|
"../../../../rtc_base/system:arch",
|
||||||
|
"../../../../system_wrappers:cpu_features_api",
|
||||||
"../../../../test:fileutils",
|
"../../../../test:fileutils",
|
||||||
"../../../../test:test_support",
|
"../../../../test:test_support",
|
||||||
]
|
]
|
||||||
@ -113,8 +116,10 @@ if (rtc_include_tests) {
|
|||||||
"../../../../common_audio/",
|
"../../../../common_audio/",
|
||||||
"../../../../rtc_base:checks",
|
"../../../../rtc_base:checks",
|
||||||
"../../../../rtc_base:logging",
|
"../../../../rtc_base:logging",
|
||||||
|
"../../../../rtc_base/system:arch",
|
||||||
"../../../../test:test_support",
|
"../../../../test:test_support",
|
||||||
"../../utility:pffft_wrapper",
|
"../../utility:pffft_wrapper",
|
||||||
|
"//third_party/abseil-cpp/absl/memory",
|
||||||
"//third_party/rnnoise:rnn_vad",
|
"//third_party/rnnoise:rnn_vad",
|
||||||
]
|
]
|
||||||
data = unittest_resources
|
data = unittest_resources
|
||||||
|
|||||||
@ -22,6 +22,7 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
#include "rtc_base/checks.h"
|
#include "rtc_base/checks.h"
|
||||||
#include "third_party/rnnoise/src/rnn_activations.h"
|
#include "third_party/rnnoise/src/rnn_activations.h"
|
||||||
@ -29,6 +30,7 @@
|
|||||||
|
|
||||||
namespace webrtc {
|
namespace webrtc {
|
||||||
namespace rnn_vad {
|
namespace rnn_vad {
|
||||||
|
namespace {
|
||||||
|
|
||||||
using rnnoise::kWeightsScale;
|
using rnnoise::kWeightsScale;
|
||||||
|
|
||||||
@ -56,8 +58,6 @@ static_assert(kOutputLayerOutputSize <= kFullyConnectedLayersMaxUnits,
|
|||||||
using rnnoise::SigmoidApproximated;
|
using rnnoise::SigmoidApproximated;
|
||||||
using rnnoise::TansigApproximated;
|
using rnnoise::TansigApproximated;
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
inline float RectifiedLinearUnit(float x) {
|
inline float RectifiedLinearUnit(float x) {
|
||||||
return x < 0.f ? 0.f : x;
|
return x < 0.f ? 0.f : x;
|
||||||
}
|
}
|
||||||
@ -71,6 +71,83 @@ std::vector<float> GetScaledParams(rtc::ArrayView<const int8_t> params) {
|
|||||||
return scaled_params;
|
return scaled_params;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Casts and scales |weights| and re-arranges the layout.
|
||||||
|
std::vector<float> GetPreprocessedWeights(rtc::ArrayView<const int8_t> weights,
|
||||||
|
const size_t output_size) {
|
||||||
|
if (output_size == 1) {
|
||||||
|
return GetScaledParams(weights);
|
||||||
|
}
|
||||||
|
// Transpose, scale and cast.
|
||||||
|
const size_t input_size = rtc::CheckedDivExact(weights.size(), output_size);
|
||||||
|
std::vector<float> w(weights.size());
|
||||||
|
for (size_t o = 0; o < output_size; ++o) {
|
||||||
|
for (size_t i = 0; i < input_size; ++i) {
|
||||||
|
w[o * input_size + i] = rnnoise::kWeightsScale *
|
||||||
|
static_cast<float>(weights[i * output_size + o]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return w;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fully connected layer un-optimized implementation.
|
||||||
|
void ComputeFullyConnectedLayerOutput(
|
||||||
|
size_t input_size,
|
||||||
|
size_t output_size,
|
||||||
|
rtc::ArrayView<const float> input,
|
||||||
|
rtc::ArrayView<const float> bias,
|
||||||
|
rtc::ArrayView<const float> weights,
|
||||||
|
rtc::FunctionView<float(float)> activation_function,
|
||||||
|
rtc::ArrayView<float> output) {
|
||||||
|
RTC_DCHECK_EQ(input.size(), input_size);
|
||||||
|
RTC_DCHECK_EQ(bias.size(), output_size);
|
||||||
|
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
|
||||||
|
for (size_t o = 0; o < output_size; ++o) {
|
||||||
|
output[o] = bias[o];
|
||||||
|
// TODO(bugs.chromium.org/9076): Benchmark how different layouts for
|
||||||
|
// |weights_| change the performance across different platforms.
|
||||||
|
for (size_t i = 0; i < input_size; ++i) {
|
||||||
|
output[o] += input[i] * weights[o * input_size + i];
|
||||||
|
}
|
||||||
|
output[o] = activation_function(output[o]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||||
|
// Fully connected layer SSE2 implementation.
|
||||||
|
void ComputeFullyConnectedLayerOutputSse2(
|
||||||
|
size_t input_size,
|
||||||
|
size_t output_size,
|
||||||
|
rtc::ArrayView<const float> input,
|
||||||
|
rtc::ArrayView<const float> bias,
|
||||||
|
rtc::ArrayView<const float> weights,
|
||||||
|
rtc::FunctionView<float(float)> activation_function,
|
||||||
|
rtc::ArrayView<float> output) {
|
||||||
|
RTC_DCHECK_EQ(input.size(), input_size);
|
||||||
|
RTC_DCHECK_EQ(bias.size(), output_size);
|
||||||
|
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
|
||||||
|
const size_t input_size_by_4 = input_size >> 2;
|
||||||
|
const size_t offset = input_size & ~3;
|
||||||
|
__m128 sum_wx_128;
|
||||||
|
const float* v = reinterpret_cast<const float*>(&sum_wx_128);
|
||||||
|
for (size_t o = 0; o < output_size; ++o) {
|
||||||
|
// Perform 128 bit vector operations.
|
||||||
|
sum_wx_128 = _mm_set1_ps(0);
|
||||||
|
const float* x_p = input.data();
|
||||||
|
const float* w_p = weights.data() + o * input_size;
|
||||||
|
for (size_t i = 0; i < input_size_by_4; ++i, x_p += 4, w_p += 4) {
|
||||||
|
sum_wx_128 = _mm_add_ps(sum_wx_128,
|
||||||
|
_mm_mul_ps(_mm_loadu_ps(x_p), _mm_loadu_ps(w_p)));
|
||||||
|
}
|
||||||
|
// Perform non-vector operations for any remaining items, sum up bias term
|
||||||
|
// and results from the vectorized code, and apply the activation function.
|
||||||
|
output[o] = activation_function(
|
||||||
|
std::inner_product(input.begin() + offset, input.end(),
|
||||||
|
weights.begin() + o * input_size + offset,
|
||||||
|
bias[o] + v[0] + v[1] + v[2] + v[3]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
FullyConnectedLayer::FullyConnectedLayer(
|
FullyConnectedLayer::FullyConnectedLayer(
|
||||||
@ -78,12 +155,12 @@ FullyConnectedLayer::FullyConnectedLayer(
|
|||||||
const size_t output_size,
|
const size_t output_size,
|
||||||
const rtc::ArrayView<const int8_t> bias,
|
const rtc::ArrayView<const int8_t> bias,
|
||||||
const rtc::ArrayView<const int8_t> weights,
|
const rtc::ArrayView<const int8_t> weights,
|
||||||
float (*const activation_function)(float),
|
rtc::FunctionView<float(float)> activation_function,
|
||||||
Optimization optimization)
|
Optimization optimization)
|
||||||
: input_size_(input_size),
|
: input_size_(input_size),
|
||||||
output_size_(output_size),
|
output_size_(output_size),
|
||||||
bias_(GetScaledParams(bias)),
|
bias_(GetScaledParams(bias)),
|
||||||
weights_(GetScaledParams(weights)),
|
weights_(GetPreprocessedWeights(weights, output_size)),
|
||||||
activation_function_(activation_function),
|
activation_function_(activation_function),
|
||||||
optimization_(optimization) {
|
optimization_(optimization) {
|
||||||
RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits)
|
RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits)
|
||||||
@ -105,31 +182,21 @@ void FullyConnectedLayer::ComputeOutput(rtc::ArrayView<const float> input) {
|
|||||||
switch (optimization_) {
|
switch (optimization_) {
|
||||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||||
case Optimization::kSse2:
|
case Optimization::kSse2:
|
||||||
// TODO(bugs.chromium.org/10480): Handle Optimization::kSse2.
|
ComputeFullyConnectedLayerOutputSse2(input_size_, output_size_, input,
|
||||||
ComputeOutput_NONE(input);
|
bias_, weights_,
|
||||||
|
activation_function_, output_);
|
||||||
break;
|
break;
|
||||||
#endif
|
#endif
|
||||||
#if defined(WEBRTC_HAS_NEON)
|
#if defined(WEBRTC_HAS_NEON)
|
||||||
case Optimization::kNeon:
|
case Optimization::kNeon:
|
||||||
// TODO(bugs.chromium.org/10480): Handle Optimization::kNeon.
|
// TODO(bugs.chromium.org/10480): Handle Optimization::kNeon.
|
||||||
ComputeOutput_NONE(input);
|
ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_,
|
||||||
|
weights_, activation_function_, output_);
|
||||||
break;
|
break;
|
||||||
#endif
|
#endif
|
||||||
default:
|
default:
|
||||||
ComputeOutput_NONE(input);
|
ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_,
|
||||||
}
|
weights_, activation_function_, output_);
|
||||||
}
|
|
||||||
|
|
||||||
void FullyConnectedLayer::ComputeOutput_NONE(
|
|
||||||
rtc::ArrayView<const float> input) {
|
|
||||||
for (size_t o = 0; o < output_size_; ++o) {
|
|
||||||
output_[o] = bias_[o];
|
|
||||||
// TODO(bugs.chromium.org/9076): Benchmark how different layouts for
|
|
||||||
// |weights_| change the performance across different platforms.
|
|
||||||
for (size_t i = 0; i < input_size_; ++i) {
|
|
||||||
output_[o] += input[i] * weights_[i * output_size_ + o];
|
|
||||||
}
|
|
||||||
output_[o] = (*activation_function_)(output_[o]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,9 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "api/array_view.h"
|
#include "api/array_view.h"
|
||||||
|
#include "api/function_view.h"
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||||
|
#include "rtc_base/system/arch.h"
|
||||||
|
|
||||||
namespace webrtc {
|
namespace webrtc {
|
||||||
namespace rnn_vad {
|
namespace rnn_vad {
|
||||||
@ -42,30 +44,28 @@ class FullyConnectedLayer {
|
|||||||
size_t output_size,
|
size_t output_size,
|
||||||
rtc::ArrayView<const int8_t> bias,
|
rtc::ArrayView<const int8_t> bias,
|
||||||
rtc::ArrayView<const int8_t> weights,
|
rtc::ArrayView<const int8_t> weights,
|
||||||
float (*const activation_function)(float),
|
rtc::FunctionView<float(float)> activation_function,
|
||||||
Optimization optimization);
|
Optimization optimization);
|
||||||
FullyConnectedLayer(const FullyConnectedLayer&) = delete;
|
FullyConnectedLayer(const FullyConnectedLayer&) = delete;
|
||||||
FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
|
FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
|
||||||
~FullyConnectedLayer();
|
~FullyConnectedLayer();
|
||||||
size_t input_size() const { return input_size_; }
|
size_t input_size() const { return input_size_; }
|
||||||
size_t output_size() const { return output_size_; }
|
size_t output_size() const { return output_size_; }
|
||||||
|
Optimization optimization() const { return optimization_; }
|
||||||
rtc::ArrayView<const float> GetOutput() const;
|
rtc::ArrayView<const float> GetOutput() const;
|
||||||
// Computes the fully-connected layer output.
|
// Computes the fully-connected layer output.
|
||||||
void ComputeOutput(rtc::ArrayView<const float> input);
|
void ComputeOutput(rtc::ArrayView<const float> input);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// No SIMD optimizations.
|
|
||||||
void ComputeOutput_NONE(rtc::ArrayView<const float> input);
|
|
||||||
|
|
||||||
const size_t input_size_;
|
const size_t input_size_;
|
||||||
const size_t output_size_;
|
const size_t output_size_;
|
||||||
const std::vector<float> bias_;
|
const std::vector<float> bias_;
|
||||||
const std::vector<float> weights_;
|
const std::vector<float> weights_;
|
||||||
float (*const activation_function_)(float);
|
rtc::FunctionView<float(float)> activation_function_;
|
||||||
const Optimization optimization_;
|
|
||||||
// The output vector of a recurrent layer has length equal to |output_size_|.
|
// The output vector of a recurrent layer has length equal to |output_size_|.
|
||||||
// However, for efficiency, over-allocation is used.
|
// However, for efficiency, over-allocation is used.
|
||||||
std::array<float, kFullyConnectedLayersMaxUnits> output_;
|
std::array<float, kFullyConnectedLayersMaxUnits> output_;
|
||||||
|
const Optimization optimization_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
|
// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
|
||||||
@ -83,6 +83,7 @@ class GatedRecurrentLayer {
|
|||||||
~GatedRecurrentLayer();
|
~GatedRecurrentLayer();
|
||||||
size_t input_size() const { return input_size_; }
|
size_t input_size() const { return input_size_; }
|
||||||
size_t output_size() const { return output_size_; }
|
size_t output_size() const { return output_size_; }
|
||||||
|
Optimization optimization() const { return optimization_; }
|
||||||
rtc::ArrayView<const float> GetOutput() const;
|
rtc::ArrayView<const float> GetOutput() const;
|
||||||
void Reset();
|
void Reset();
|
||||||
// Computes the recurrent layer output and updates the status.
|
// Computes the recurrent layer output and updates the status.
|
||||||
|
|||||||
@ -11,10 +11,14 @@
|
|||||||
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
|
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
||||||
|
#include "modules/audio_processing/test/performance_timer.h"
|
||||||
#include "rtc_base/checks.h"
|
#include "rtc_base/checks.h"
|
||||||
#include "rtc_base/logging.h"
|
#include "rtc_base/logging.h"
|
||||||
|
#include "rtc_base/system/arch.h"
|
||||||
#include "test/gtest.h"
|
#include "test/gtest.h"
|
||||||
#include "third_party/rnnoise/src/rnn_activations.h"
|
#include "third_party/rnnoise/src/rnn_activations.h"
|
||||||
#include "third_party/rnnoise/src/rnn_vad_weights.h"
|
#include "third_party/rnnoise/src/rnn_vad_weights.h"
|
||||||
@ -23,18 +27,14 @@ namespace webrtc {
|
|||||||
namespace rnn_vad {
|
namespace rnn_vad {
|
||||||
namespace test {
|
namespace test {
|
||||||
|
|
||||||
using rnnoise::RectifiedLinearUnit;
|
|
||||||
using rnnoise::SigmoidApproximated;
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void TestFullyConnectedLayer(FullyConnectedLayer* fc,
|
void TestFullyConnectedLayer(FullyConnectedLayer* fc,
|
||||||
rtc::ArrayView<const float> input_vector,
|
rtc::ArrayView<const float> input_vector,
|
||||||
const float expected_output) {
|
rtc::ArrayView<const float> expected_output) {
|
||||||
RTC_CHECK(fc);
|
RTC_CHECK(fc);
|
||||||
fc->ComputeOutput(input_vector);
|
fc->ComputeOutput(input_vector);
|
||||||
const auto output = fc->GetOutput();
|
ExpectNearAbsolute(expected_output, fc->GetOutput(), 1e-5f);
|
||||||
EXPECT_NEAR(expected_output, output[0], 3e-6f);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestGatedRecurrentLayer(
|
void TestGatedRecurrentLayer(
|
||||||
@ -62,32 +62,19 @@ void TestGatedRecurrentLayer(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fully connected layer test data.
|
// Fully connected layer test data.
|
||||||
constexpr size_t kFullyConnectedInputSize = 24;
|
constexpr std::array<float, 42> kFullyConnectedInputVector = {
|
||||||
constexpr size_t kFullyConnectedOutputSize = 1;
|
-1.00131f, -0.627069f, -7.81097f, 7.86285f, -2.87145f, 3.32365f,
|
||||||
constexpr std::array<int8_t, 1> kFullyConnectedBias = {-50};
|
-0.653161f, 0.529839f, -0.425307f, 0.25583f, 0.235094f, 0.230527f,
|
||||||
constexpr std::array<int8_t, 24> kFullyConnectedWeights = {
|
-0.144687f, 0.182785f, 0.57102f, 0.125039f, 0.479482f, -0.0255439f,
|
||||||
127, 127, 127, 127, 127, 20, 127, -126, -126, -54, 14, 125,
|
-0.0073141f, -0.147346f, -0.217106f, -0.0846906f, -8.34943f, 3.09065f,
|
||||||
-126, -126, 127, -125, -126, 127, -127, -127, -57, -30, 127, 80};
|
1.42628f, -0.85235f, -0.220207f, -0.811163f, 2.09032f, -2.01425f,
|
||||||
constexpr std::array<float, 24 * 3> kFullyConnectedInputVectors = {
|
-0.690268f, -0.925327f, -0.541354f, 0.58455f, -0.606726f, -0.0372358f,
|
||||||
// Input 1.
|
0.565991f, 0.435854f, 0.420812f, 0.162198f, -2.13f, 10.0089f};
|
||||||
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.215833917f, 0.290601075f, 0.238759011f,
|
constexpr std::array<float, 24> kFullyConnectedExpectedOutput = {
|
||||||
0.244751841f, 0.f, 0.0461241305f, 0.106401242f, 0.223070428f, 0.630603909f,
|
-0.623293f, -0.988299f, 0.999378f, 0.967168f, 0.103087f, -0.978545f,
|
||||||
0.690453172f, 0.f, 0.387645692f, 0.166913897f, 0.f, 0.0327451192f, 0.f,
|
-0.856347f, 0.346675f, 1.f, -0.717442f, -0.544176f, 0.960363f,
|
||||||
0.136149868f, 0.446351469f,
|
0.983443f, 0.999991f, -0.824335f, 0.984742f, 0.990208f, 0.938179f,
|
||||||
// Input 2.
|
0.875092f, 0.999846f, 0.997707f, -0.999382f, 0.973153f, -0.966605f};
|
||||||
0.592162728f, 0.529089332f, 1.18205106f, 1.21736848f, 0.f, 0.470851123f,
|
|
||||||
0.130675942f, 0.320903003f, 0.305496395f, 0.0571633279f, 1.57001138f,
|
|
||||||
0.0182026215f, 0.0977443159f, 0.347477973f, 0.493206412f, 0.9688586f,
|
|
||||||
0.0320267938f, 0.244722098f, 0.312745273f, 0.f, 0.00650715502f,
|
|
||||||
0.312553257f, 1.62619662f, 0.782880902f,
|
|
||||||
// Input 3.
|
|
||||||
0.395022154f, 0.333681047f, 0.76302278f, 0.965480626f, 0.f, 0.941198349f,
|
|
||||||
0.0892967582f, 0.745046318f, 0.635769248f, 0.238564298f, 0.970656633f,
|
|
||||||
0.014159563f, 0.094203949f, 0.446816623f, 0.640755892f, 1.20532358f,
|
|
||||||
0.0254284926f, 0.283327013f, 0.726210058f, 0.0550272502f, 0.000344108557f,
|
|
||||||
0.369803518f, 1.56680179f, 0.997883797f};
|
|
||||||
constexpr std::array<float, 3> kFullyConnectedExpectedOutputs = {
|
|
||||||
0.436567038f, 0.874741316f, 0.672785878f};
|
|
||||||
|
|
||||||
// Gated recurrent units layer test data.
|
// Gated recurrent units layer test data.
|
||||||
constexpr size_t kGruInputSize = 5;
|
constexpr size_t kGruInputSize = 5;
|
||||||
@ -117,47 +104,94 @@ constexpr std::array<float, 16> kGruExpectedOutputSequence = {
|
|||||||
0.00781069f, 0.75267816f, 0.f, 0.02579715f,
|
0.00781069f, 0.75267816f, 0.f, 0.02579715f,
|
||||||
0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f};
|
0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f};
|
||||||
|
|
||||||
} // namespace
|
// Returns a human readable name for |optimization|, used to label the
// benchmark results.
std::string GetOptimizationName(Optimization optimization) {
  // No default case on purpose, so that the compiler can flag enumerators that
  // are added later and not handled here.
  switch (optimization) {
    case Optimization::kSse2:
      return "SSE2";
    case Optimization::kNeon:
      return "NEON";
    case Optimization::kNone:
      return "none";
  }
  // Only reached if |optimization| holds a value that matches none of the
  // cases above; return a placeholder rather than flowing off the end of a
  // non-void function.
  return "unknown";
}
|
||||||
|
|
||||||
class OptimizationTest : public ::testing::Test,
|
} // namespace
|
||||||
public ::testing::WithParamInterface<Optimization> {};
|
|
||||||
|
|
||||||
// Checks that the output of a fully connected layer is within tolerance given
|
// Checks that the output of a fully connected layer is within tolerance given
|
||||||
// test input data.
|
// test input data.
|
||||||
TEST_P(OptimizationTest, CheckFullyConnectedLayerOutput) {
|
TEST(RnnVadTest, CheckFullyConnectedLayerOutput) {
|
||||||
const Optimization optimization = GetParam();
|
FullyConnectedLayer fc(rnnoise::kInputLayerInputSize,
|
||||||
RTC_LOG(LS_VERBOSE) << optimization;
|
rnnoise::kInputLayerOutputSize,
|
||||||
FullyConnectedLayer fc(kFullyConnectedInputSize, kFullyConnectedOutputSize,
|
rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights,
|
||||||
kFullyConnectedBias, kFullyConnectedWeights,
|
rnnoise::TansigApproximated, Optimization::kNone);
|
||||||
SigmoidApproximated, optimization);
|
TestFullyConnectedLayer(&fc, kFullyConnectedInputVector,
|
||||||
// Test on different inputs.
|
kFullyConnectedExpectedOutput);
|
||||||
static_assert(
|
|
||||||
kFullyConnectedInputVectors.size() % kFullyConnectedInputSize == 0, "");
|
|
||||||
constexpr size_t kNumInputVectors =
|
|
||||||
kFullyConnectedInputVectors.size() / kFullyConnectedInputSize;
|
|
||||||
static_assert(kFullyConnectedExpectedOutputs.size() == kNumInputVectors, "");
|
|
||||||
for (size_t i = 0; i < kNumInputVectors; ++i) {
|
|
||||||
rtc::ArrayView<const float> input(
|
|
||||||
kFullyConnectedInputVectors.data() + kFullyConnectedInputSize * i,
|
|
||||||
kFullyConnectedInputSize);
|
|
||||||
TestFullyConnectedLayer(&fc, input, kFullyConnectedExpectedOutputs[i]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Checks that the output of a GRU layer is within tolerance given test input
|
// Checks that the output of a GRU layer is within tolerance given test input
|
||||||
// data.
|
// data.
|
||||||
TEST_P(OptimizationTest, CheckGatedRecurrentLayer) {
|
TEST(RnnVadTest, CheckGatedRecurrentLayer) {
|
||||||
const Optimization optimization = GetParam();
|
|
||||||
RTC_LOG(LS_VERBOSE) << optimization;
|
|
||||||
GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
|
GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
|
||||||
kGruRecurrentWeights, optimization);
|
kGruRecurrentWeights, Optimization::kNone);
|
||||||
TestGatedRecurrentLayer(&gru, kGruInputSequence, kGruExpectedOutputSequence);
|
TestGatedRecurrentLayer(&gru, kGruInputSequence, kGruExpectedOutputSequence);
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(RnnVadTest,
|
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||||
OptimizationTest,
|
|
||||||
::testing::Values(Optimization::kNone,
|
// Like CheckFullyConnectedLayerOutput, but testing the SSE2 implementation.
|
||||||
DetectOptimization()));
|
TEST(RnnVadTest, CheckFullyConnectedLayerOutputSse2) {
|
||||||
|
if (!IsOptimizationAvailable(Optimization::kSse2)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
FullyConnectedLayer fc(rnnoise::kInputLayerInputSize,
|
||||||
|
rnnoise::kInputLayerOutputSize,
|
||||||
|
rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights,
|
||||||
|
rnnoise::TansigApproximated, Optimization::kSse2);
|
||||||
|
TestFullyConnectedLayer(&fc, kFullyConnectedInputVector,
|
||||||
|
kFullyConnectedExpectedOutput);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // WEBRTC_ARCH_X86_FAMILY
|
||||||
|
|
||||||
|
// Benchmark (disabled by default): times FullyConnectedLayer::ComputeOutput()
// for the un-optimized implementation and, when available, for the SSE2 one.
// Both layers are built from the rnnoise input dense layer parameters and fed
// with |kFullyConnectedInputVector|. The results are only logged; nothing is
// asserted.
TEST(RnnVadTest, DISABLED_BenchmarkFullyConnectedLayer) {
  // Builds the layer under test with the requested optimization.
  const auto create_layer = [](Optimization optimization) {
    return std::make_unique<FullyConnectedLayer>(
        rnnoise::kInputLayerInputSize, rnnoise::kInputLayerOutputSize,
        rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights,
        rnnoise::TansigApproximated, optimization);
  };

  std::vector<std::unique_ptr<FullyConnectedLayer>> layers;
  layers.push_back(create_layer(Optimization::kNone));
  if (IsOptimizationAvailable(Optimization::kSse2)) {
    layers.push_back(create_layer(Optimization::kSse2));
  }

  // Timing statistics for one implementation.
  struct Result {
    Optimization optimization;
    double average_us;
    double std_dev_us;
  };
  std::vector<Result> results;

  // Number of timed ComputeOutput() calls per implementation.
  constexpr size_t number_of_tests = 10000;
  for (auto& layer : layers) {
    ::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
    size_t remaining = number_of_tests;
    while (remaining-- > 0) {
      perf_timer.StartTimer();
      layer->ComputeOutput(kFullyConnectedInputVector);
      perf_timer.StopTimer();
    }
    const Result result = {layer->optimization(),
                           perf_timer.GetDurationAverage(),
                           perf_timer.GetDurationStandardDeviation()};
    results.push_back(result);
  }

  // Log once all the measurements are done, so that logging does not
  // interleave with the timed code.
  for (const Result& result : results) {
    const double average_ms = result.average_us / 1e3;
    const double std_dev_ms = result.std_dev_us / 1e3;
    RTC_LOG(LS_INFO) << GetOptimizationName(result.optimization) << ": "
                     << average_ms << " +/- " << std_dev_ms << " ms";
  }
}
|
||||||
|
|
||||||
} // namespace test
|
} // namespace test
|
||||||
} // namespace rnn_vad
|
} // namespace rnn_vad
|
||||||
|
|||||||
@ -13,6 +13,8 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "rtc_base/checks.h"
|
#include "rtc_base/checks.h"
|
||||||
|
#include "rtc_base/system/arch.h"
|
||||||
|
#include "system_wrappers/include/cpu_features_wrapper.h"
|
||||||
#include "test/gtest.h"
|
#include "test/gtest.h"
|
||||||
#include "test/testsupport/file_utils.h"
|
#include "test/testsupport/file_utils.h"
|
||||||
|
|
||||||
@ -103,6 +105,25 @@ PitchTestData::GetPitchBufAutoCorrCoeffsView() const {
|
|||||||
kNumPitchBufAutoCorrCoeffs};
|
kNumPitchBufAutoCorrCoeffs};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns true if |optimization| can be used by the tests.
// - kSse2: on x86 family builds, true if WebRtc_GetCPUInfo(kSSE2) reports a
//   non-zero value at run time; false on every other build.
// - kNeon: decided at compile time, true only if WEBRTC_HAS_NEON is defined.
// - kNone: always true.
bool IsOptimizationAvailable(Optimization optimization) {
  // No default case on purpose, so that the compiler can flag enumerators that
  // are added later and not handled here.
  switch (optimization) {
    case Optimization::kSse2:
#if defined(WEBRTC_ARCH_X86_FAMILY)
      return WebRtc_GetCPUInfo(kSSE2) != 0;
#else
      return false;
#endif
    case Optimization::kNeon:
#if defined(WEBRTC_HAS_NEON)
      return true;
#else
      return false;
#endif
    case Optimization::kNone:
      return true;
  }
  // Only reached if |optimization| holds a value that matches none of the
  // cases above; report it as unavailable rather than flowing off the end of a
  // non-void function.
  return false;
}
|
||||||
|
|
||||||
} // namespace test
|
} // namespace test
|
||||||
} // namespace rnn_vad
|
} // namespace rnn_vad
|
||||||
} // namespace webrtc
|
} // namespace webrtc
|
||||||
|
|||||||
@ -151,6 +151,9 @@ class PitchTestData {
|
|||||||
std::array<float, kPitchTestDataSize> test_data_;
|
std::array<float, kPitchTestDataSize> test_data_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Returns true if the given optimization is available.
|
||||||
|
bool IsOptimizationAvailable(Optimization optimization);
|
||||||
|
|
||||||
} // namespace test
|
} // namespace test
|
||||||
} // namespace rnn_vad
|
} // namespace rnn_vad
|
||||||
} // namespace webrtc
|
} // namespace webrtc
|
||||||
|
|||||||
Reference in New Issue
Block a user