diff --git a/modules/audio_processing/agc2/rnn_vad/BUILD.gn b/modules/audio_processing/agc2/rnn_vad/BUILD.gn index 852abd88bf..f4613b19e3 100644 --- a/modules/audio_processing/agc2/rnn_vad/BUILD.gn +++ b/modules/audio_processing/agc2/rnn_vad/BUILD.gn @@ -44,6 +44,7 @@ rtc_library("rnn_vad") { deps = [ "..:biquad_filter", "../../../../api:array_view", + "../../../../api:function_view", "../../../../rtc_base:checks", "../../../../rtc_base:rtc_base_approved", "../../../../rtc_base/system:arch", @@ -65,6 +66,8 @@ if (rtc_include_tests) { "../../../../api:array_view", "../../../../api:scoped_refptr", "../../../../rtc_base:checks", + "../../../../rtc_base/system:arch", + "../../../../system_wrappers:cpu_features_api", "../../../../test:fileutils", "../../../../test:test_support", ] @@ -113,8 +116,10 @@ if (rtc_include_tests) { "../../../../common_audio/", "../../../../rtc_base:checks", "../../../../rtc_base:logging", + "../../../../rtc_base/system:arch", "../../../../test:test_support", "../../utility:pffft_wrapper", + "//third_party/abseil-cpp/absl/memory", "//third_party/rnnoise:rnn_vad", ] data = unittest_resources diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.cc b/modules/audio_processing/agc2/rnn_vad/rnn.cc index e6ef2f3a41..a5f7b4b4ab 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include "rtc_base/checks.h" #include "third_party/rnnoise/src/rnn_activations.h" @@ -29,6 +30,7 @@ namespace webrtc { namespace rnn_vad { +namespace { using rnnoise::kWeightsScale; @@ -56,8 +58,6 @@ static_assert(kOutputLayerOutputSize <= kFullyConnectedLayersMaxUnits, using rnnoise::SigmoidApproximated; using rnnoise::TansigApproximated; -namespace { - inline float RectifiedLinearUnit(float x) { return x < 0.f ? 0.f : x; } @@ -71,6 +71,83 @@ std::vector GetScaledParams(rtc::ArrayView params) { return scaled_params; } +// Casts and scales |weights| and re-arranges the layout. +std::vector GetPreprocessedWeights(rtc::ArrayView weights, + const size_t output_size) { + if (output_size == 1) { + return GetScaledParams(weights); + } + // Transpose, scale and cast. + const size_t input_size = rtc::CheckedDivExact(weights.size(), output_size); + std::vector w(weights.size()); + for (size_t o = 0; o < output_size; ++o) { + for (size_t i = 0; i < input_size; ++i) { + w[o * input_size + i] = rnnoise::kWeightsScale * + static_cast(weights[i * output_size + o]); + } + } + return w; +} + +// Fully connected layer un-optimized implementation. +void ComputeFullyConnectedLayerOutput( + size_t input_size, + size_t output_size, + rtc::ArrayView input, + rtc::ArrayView bias, + rtc::ArrayView weights, + rtc::FunctionView activation_function, + rtc::ArrayView output) { + RTC_DCHECK_EQ(input.size(), input_size); + RTC_DCHECK_EQ(bias.size(), output_size); + RTC_DCHECK_EQ(weights.size(), input_size * output_size); + for (size_t o = 0; o < output_size; ++o) { + output[o] = bias[o]; + // TODO(bugs.chromium.org/9076): Benchmark how different layouts for + // |weights_| change the performance across different platforms. + for (size_t i = 0; i < input_size; ++i) { + output[o] += input[i] * weights[o * input_size + i]; + } + output[o] = activation_function(output[o]); + } +} + +#if defined(WEBRTC_ARCH_X86_FAMILY) +// Fully connected layer SSE2 implementation. +void ComputeFullyConnectedLayerOutputSse2( + size_t input_size, + size_t output_size, + rtc::ArrayView input, + rtc::ArrayView bias, + rtc::ArrayView weights, + rtc::FunctionView activation_function, + rtc::ArrayView output) { + RTC_DCHECK_EQ(input.size(), input_size); + RTC_DCHECK_EQ(bias.size(), output_size); + RTC_DCHECK_EQ(weights.size(), input_size * output_size); + const size_t input_size_by_4 = input_size >> 2; + const size_t offset = input_size & ~3; + __m128 sum_wx_128; + const float* v = reinterpret_cast(&sum_wx_128); + for (size_t o = 0; o < output_size; ++o) { + // Perform 128 bit vector operations. + sum_wx_128 = _mm_set1_ps(0); + const float* x_p = input.data(); + const float* w_p = weights.data() + o * input_size; + for (size_t i = 0; i < input_size_by_4; ++i, x_p += 4, w_p += 4) { + sum_wx_128 = _mm_add_ps(sum_wx_128, + _mm_mul_ps(_mm_loadu_ps(x_p), _mm_loadu_ps(w_p))); + } + // Perform non-vector operations for any remaining items, sum up bias term + // and results from the vectorized code, and apply the activation function. + output[o] = activation_function( + std::inner_product(input.begin() + offset, input.end(), + weights.begin() + o * input_size + offset, + bias[o] + v[0] + v[1] + v[2] + v[3])); + } +} +#endif + } // namespace FullyConnectedLayer::FullyConnectedLayer( @@ -78,12 +155,12 @@ FullyConnectedLayer::FullyConnectedLayer( const size_t output_size, const rtc::ArrayView bias, const rtc::ArrayView weights, - float (*const activation_function)(float), + rtc::FunctionView activation_function, Optimization optimization) : input_size_(input_size), output_size_(output_size), bias_(GetScaledParams(bias)), - weights_(GetScaledParams(weights)), + weights_(GetPreprocessedWeights(weights, output_size)), activation_function_(activation_function), optimization_(optimization) { RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits) @@ -105,31 +182,21 @@ void FullyConnectedLayer::ComputeOutput(rtc::ArrayView input) { switch (optimization_) { #if defined(WEBRTC_ARCH_X86_FAMILY) case Optimization::kSse2: - // TODO(bugs.chromium.org/10480): Handle Optimization::kSse2. - ComputeOutput_NONE(input); + ComputeFullyConnectedLayerOutputSse2(input_size_, output_size_, input, + bias_, weights_, + activation_function_, output_); break; #endif #if defined(WEBRTC_HAS_NEON) case Optimization::kNeon: // TODO(bugs.chromium.org/10480): Handle Optimization::kNeon. - ComputeOutput_NONE(input); + ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_, + weights_, activation_function_, output_); break; #endif default: - ComputeOutput_NONE(input); - } -} - -void FullyConnectedLayer::ComputeOutput_NONE( - rtc::ArrayView input) { - for (size_t o = 0; o < output_size_; ++o) { - output_[o] = bias_[o]; - // TODO(bugs.chromium.org/9076): Benchmark how different layouts for - // |weights_| change the performance across different platforms. - for (size_t i = 0; i < input_size_; ++i) { - output_[o] += input[i] * weights_[i * output_size_ + o]; - } - output_[o] = (*activation_function_)(output_[o]); + ComputeFullyConnectedLayerOutput(input_size_, output_size_, input, bias_, + weights_, activation_function_, output_); } } diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.h b/modules/audio_processing/agc2/rnn_vad/rnn.h index f53a09379d..29ee20744b 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn.h +++ b/modules/audio_processing/agc2/rnn_vad/rnn.h @@ -18,7 +18,9 @@ #include #include "api/array_view.h" +#include "api/function_view.h" #include "modules/audio_processing/agc2/rnn_vad/common.h" +#include "rtc_base/system/arch.h" namespace webrtc { namespace rnn_vad { @@ -42,30 +44,28 @@ class FullyConnectedLayer { size_t output_size, rtc::ArrayView bias, rtc::ArrayView weights, - float (*const activation_function)(float), + rtc::FunctionView activation_function, Optimization optimization); FullyConnectedLayer(const FullyConnectedLayer&) = delete; FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete; ~FullyConnectedLayer(); size_t input_size() const { return input_size_; } size_t output_size() const { return output_size_; } + Optimization optimization() const { return optimization_; } rtc::ArrayView GetOutput() const; // Computes the fully-connected layer output. void ComputeOutput(rtc::ArrayView input); private: - // No SIMD optimizations. - void ComputeOutput_NONE(rtc::ArrayView input); - const size_t input_size_; const size_t output_size_; const std::vector bias_; const std::vector weights_; - float (*const activation_function_)(float); - const Optimization optimization_; + rtc::FunctionView activation_function_; // The output vector of a recurrent layer has length equal to |output_size_|. // However, for efficiency, over-allocation is used. std::array output_; + const Optimization optimization_; }; // Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as @@ -83,6 +83,7 @@ class GatedRecurrentLayer { ~GatedRecurrentLayer(); size_t input_size() const { return input_size_; } size_t output_size() const { return output_size_; } + Optimization optimization() const { return optimization_; } rtc::ArrayView GetOutput() const; void Reset(); // Computes the recurrent layer output and updates the status. diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc index 97ede1811a..74974164a1 100644 --- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc +++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc @@ -11,10 +11,14 @@ #include "modules/audio_processing/agc2/rnn_vad/rnn.h" #include +#include +#include #include "modules/audio_processing/agc2/rnn_vad/test_utils.h" +#include "modules/audio_processing/test/performance_timer.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" +#include "rtc_base/system/arch.h" #include "test/gtest.h" #include "third_party/rnnoise/src/rnn_activations.h" #include "third_party/rnnoise/src/rnn_vad_weights.h" @@ -23,18 +27,14 @@ namespace webrtc { namespace rnn_vad { namespace test { -using rnnoise::RectifiedLinearUnit; -using rnnoise::SigmoidApproximated; - namespace { void TestFullyConnectedLayer(FullyConnectedLayer* fc, rtc::ArrayView input_vector, - const float expected_output) { + rtc::ArrayView expected_output) { RTC_CHECK(fc); fc->ComputeOutput(input_vector); - const auto output = fc->GetOutput(); - EXPECT_NEAR(expected_output, output[0], 3e-6f); + ExpectNearAbsolute(expected_output, fc->GetOutput(), 1e-5f); } void TestGatedRecurrentLayer( @@ -62,32 +62,19 @@ void TestGatedRecurrentLayer( } // Fully connected layer test data. -constexpr size_t kFullyConnectedInputSize = 24; -constexpr size_t kFullyConnectedOutputSize = 1; -constexpr std::array kFullyConnectedBias = {-50}; -constexpr std::array kFullyConnectedWeights = { - 127, 127, 127, 127, 127, 20, 127, -126, -126, -54, 14, 125, - -126, -126, 127, -125, -126, 127, -127, -127, -57, -30, 127, 80}; -constexpr std::array kFullyConnectedInputVectors = { - // Input 1. - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.215833917f, 0.290601075f, 0.238759011f, - 0.244751841f, 0.f, 0.0461241305f, 0.106401242f, 0.223070428f, 0.630603909f, - 0.690453172f, 0.f, 0.387645692f, 0.166913897f, 0.f, 0.0327451192f, 0.f, - 0.136149868f, 0.446351469f, - // Input 2. - 0.592162728f, 0.529089332f, 1.18205106f, 1.21736848f, 0.f, 0.470851123f, - 0.130675942f, 0.320903003f, 0.305496395f, 0.0571633279f, 1.57001138f, - 0.0182026215f, 0.0977443159f, 0.347477973f, 0.493206412f, 0.9688586f, - 0.0320267938f, 0.244722098f, 0.312745273f, 0.f, 0.00650715502f, - 0.312553257f, 1.62619662f, 0.782880902f, - // Input 3. - 0.395022154f, 0.333681047f, 0.76302278f, 0.965480626f, 0.f, 0.941198349f, - 0.0892967582f, 0.745046318f, 0.635769248f, 0.238564298f, 0.970656633f, - 0.014159563f, 0.094203949f, 0.446816623f, 0.640755892f, 1.20532358f, - 0.0254284926f, 0.283327013f, 0.726210058f, 0.0550272502f, 0.000344108557f, - 0.369803518f, 1.56680179f, 0.997883797f}; -constexpr std::array kFullyConnectedExpectedOutputs = { - 0.436567038f, 0.874741316f, 0.672785878f}; +constexpr std::array kFullyConnectedInputVector = { + -1.00131f, -0.627069f, -7.81097f, 7.86285f, -2.87145f, 3.32365f, + -0.653161f, 0.529839f, -0.425307f, 0.25583f, 0.235094f, 0.230527f, + -0.144687f, 0.182785f, 0.57102f, 0.125039f, 0.479482f, -0.0255439f, + -0.0073141f, -0.147346f, -0.217106f, -0.0846906f, -8.34943f, 3.09065f, + 1.42628f, -0.85235f, -0.220207f, -0.811163f, 2.09032f, -2.01425f, + -0.690268f, -0.925327f, -0.541354f, 0.58455f, -0.606726f, -0.0372358f, + 0.565991f, 0.435854f, 0.420812f, 0.162198f, -2.13f, 10.0089f}; +constexpr std::array kFullyConnectedExpectedOutput = { + -0.623293f, -0.988299f, 0.999378f, 0.967168f, 0.103087f, -0.978545f, + -0.856347f, 0.346675f, 1.f, -0.717442f, -0.544176f, 0.960363f, + 0.983443f, 0.999991f, -0.824335f, 0.984742f, 0.990208f, 0.938179f, + 0.875092f, 0.999846f, 0.997707f, -0.999382f, 0.973153f, -0.966605f}; // Gated recurrent units layer test data. constexpr size_t kGruInputSize = 5; @@ -117,47 +104,94 @@ constexpr std::array kGruExpectedOutputSequence = { 0.00781069f, 0.75267816f, 0.f, 0.02579715f, 0.00471378f, 0.59162533f, 0.11087593f, 0.01334511f}; -} // namespace +std::string GetOptimizationName(Optimization optimization) { + switch (optimization) { + case Optimization::kSse2: + return "SSE2"; + case Optimization::kNeon: + return "NEON"; + case Optimization::kNone: + return "none"; + } +} -class OptimizationTest : public ::testing::Test, - public ::testing::WithParamInterface {}; +} // namespace // Checks that the output of a fully connected layer is within tolerance given // test input data. -TEST_P(OptimizationTest, CheckFullyConnectedLayerOutput) { - const Optimization optimization = GetParam(); - RTC_LOG(LS_VERBOSE) << optimization; - FullyConnectedLayer fc(kFullyConnectedInputSize, kFullyConnectedOutputSize, - kFullyConnectedBias, kFullyConnectedWeights, - SigmoidApproximated, optimization); - // Test on different inputs. - static_assert( - kFullyConnectedInputVectors.size() % kFullyConnectedInputSize == 0, ""); - constexpr size_t kNumInputVectors = - kFullyConnectedInputVectors.size() / kFullyConnectedInputSize; - static_assert(kFullyConnectedExpectedOutputs.size() == kNumInputVectors, ""); - for (size_t i = 0; i < kNumInputVectors; ++i) { - rtc::ArrayView input( - kFullyConnectedInputVectors.data() + kFullyConnectedInputSize * i, - kFullyConnectedInputSize); - TestFullyConnectedLayer(&fc, input, kFullyConnectedExpectedOutputs[i]); - } +TEST(RnnVadTest, CheckFullyConnectedLayerOutput) { + FullyConnectedLayer fc(rnnoise::kInputLayerInputSize, + rnnoise::kInputLayerOutputSize, + rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights, + rnnoise::TansigApproximated, Optimization::kNone); + TestFullyConnectedLayer(&fc, kFullyConnectedInputVector, + kFullyConnectedExpectedOutput); } // Checks that the output of a GRU layer is within tolerance given test input // data. -TEST_P(OptimizationTest, CheckGatedRecurrentLayer) { - const Optimization optimization = GetParam(); - RTC_LOG(LS_VERBOSE) << optimization; +TEST(RnnVadTest, CheckGatedRecurrentLayer) { GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights, - kGruRecurrentWeights, optimization); + kGruRecurrentWeights, Optimization::kNone); TestGatedRecurrentLayer(&gru, kGruInputSequence, kGruExpectedOutputSequence); } -INSTANTIATE_TEST_SUITE_P(RnnVadTest, - OptimizationTest, - ::testing::Values(Optimization::kNone, - DetectOptimization())); +#if defined(WEBRTC_ARCH_X86_FAMILY) + +// Like CheckFullyConnectedLayerOutput, but testing the SSE2 implementation. +TEST(RnnVadTest, CheckFullyConnectedLayerOutputSse2) { + if (!IsOptimizationAvailable(Optimization::kSse2)) { + return; + } + + FullyConnectedLayer fc(rnnoise::kInputLayerInputSize, + rnnoise::kInputLayerOutputSize, + rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights, + rnnoise::TansigApproximated, Optimization::kSse2); + TestFullyConnectedLayer(&fc, kFullyConnectedInputVector, + kFullyConnectedExpectedOutput); +} + +#endif // WEBRTC_ARCH_X86_FAMILY + +TEST(RnnVadTest, DISABLED_BenchmarkFullyConnectedLayer) { + std::vector> implementations; + implementations.emplace_back(std::make_unique( + rnnoise::kInputLayerInputSize, rnnoise::kInputLayerOutputSize, + rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights, + rnnoise::TansigApproximated, Optimization::kNone)); + if (IsOptimizationAvailable(Optimization::kSse2)) { + implementations.emplace_back(std::make_unique( + rnnoise::kInputLayerInputSize, rnnoise::kInputLayerOutputSize, + rnnoise::kInputDenseBias, rnnoise::kInputDenseWeights, + rnnoise::TansigApproximated, Optimization::kSse2)); + } + + struct Result { + Optimization optimization; + double average_us; + double std_dev_us; + }; + std::vector results; + + constexpr size_t number_of_tests = 10000; + for (auto& fc : implementations) { + ::webrtc::test::PerformanceTimer perf_timer(number_of_tests); + for (size_t k = 0; k < number_of_tests; ++k) { + perf_timer.StartTimer(); + fc->ComputeOutput(kFullyConnectedInputVector); + perf_timer.StopTimer(); + } + results.push_back({fc->optimization(), perf_timer.GetDurationAverage(), + perf_timer.GetDurationStandardDeviation()}); + } + + for (const auto& result : results) { + RTC_LOG(LS_INFO) << GetOptimizationName(result.optimization) << ": " + << (result.average_us / 1e3) << " +/- " + << (result.std_dev_us / 1e3) << " ms"; + } +} } // namespace test } // namespace rnn_vad diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.cc b/modules/audio_processing/agc2/rnn_vad/test_utils.cc index 6e0eb5b122..1a8e1a2eeb 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.cc +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.cc @@ -13,6 +13,8 @@ #include #include "rtc_base/checks.h" +#include "rtc_base/system/arch.h" +#include "system_wrappers/include/cpu_features_wrapper.h" #include "test/gtest.h" #include "test/testsupport/file_utils.h" @@ -103,6 +105,25 @@ PitchTestData::GetPitchBufAutoCorrCoeffsView() const { kNumPitchBufAutoCorrCoeffs}; } +bool IsOptimizationAvailable(Optimization optimization) { + switch (optimization) { + case Optimization::kSse2: +#if defined(WEBRTC_ARCH_X86_FAMILY) + return WebRtc_GetCPUInfo(kSSE2) != 0; +#else + return false; +#endif + case Optimization::kNeon: +#if defined(WEBRTC_HAS_NEON) + return true; +#else + return false; +#endif + case Optimization::kNone: + return true; + } +} + } // namespace test } // namespace rnn_vad } // namespace webrtc diff --git a/modules/audio_processing/agc2/rnn_vad/test_utils.h b/modules/audio_processing/agc2/rnn_vad/test_utils.h index fbb270faf8..db155e6a75 100644 --- a/modules/audio_processing/agc2/rnn_vad/test_utils.h +++ b/modules/audio_processing/agc2/rnn_vad/test_utils.h @@ -151,6 +151,9 @@ class PitchTestData { std::array test_data_; }; +// Returns true if the given optimization is available. +bool IsOptimizationAvailable(Optimization optimization); + } // namespace test } // namespace rnn_vad } // namespace webrtc