diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.cc b/modules/audio_processing/agc2/rnn_vad/rnn.cc
index a5f7b4b4ab..1cd8ae7dbc 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.cc
@@ -25,6 +25,7 @@
 #include
 
 #include "rtc_base/checks.h"
+#include "rtc_base/logging.h"
 #include "third_party/rnnoise/src/rnn_activations.h"
 #include "third_party/rnnoise/src/rnn_vad_weights.h"
 
@@ -71,9 +72,12 @@ std::vector<float> GetScaledParams(rtc::ArrayView<const int8_t> params) {
   return scaled_params;
 }
 
+// TODO(bugs.chromium.org/10480): Hard-code optimized layout and remove this
+// function to improve setup time.
 // Casts and scales |weights| and re-arranges the layout.
-std::vector<float> GetPreprocessedWeights(rtc::ArrayView<const int8_t> weights,
-                                          const size_t output_size) {
+std::vector<float> GetPreprocessedFcWeights(
+    rtc::ArrayView<const int8_t> weights,
+    size_t output_size) {
   if (output_size == 1) {
     return GetScaledParams(weights);
   }
@@ -89,6 +93,117 @@ std::vector<float> GetPreprocessedWeights(rtc::ArrayView<const int8_t> weights,
   return w;
 }
 
+constexpr size_t kNumGruGates = 3;  // Update, reset, output.
+
+// TODO(bugs.chromium.org/10480): Hard-code the optimized layout and remove
+// this function to improve setup time.
+// Casts and scales |tensor_src| for a GRU layer and re-arranges the layout.
+// It works for weights, recurrent weights and bias alike.
+std::vector<float> GetPreprocessedGruTensor(
+    rtc::ArrayView<const int8_t> tensor_src,
+    size_t output_size) {
+  // Transpose, cast and scale.
+  // |n| is the size of the first dimension of the 3-dim tensor |tensor_src|.
+  const size_t n =
+      rtc::CheckedDivExact(tensor_src.size(), output_size * kNumGruGates);
+  const size_t stride_src = kNumGruGates * output_size;
+  const size_t stride_dst = n * output_size;
+  std::vector<float> tensor_dst(tensor_src.size());
+  for (size_t g = 0; g < kNumGruGates; ++g) {
+    for (size_t o = 0; o < output_size; ++o) {
+      for (size_t i = 0; i < n; ++i) {
+        tensor_dst[g * stride_dst + o * n + i] =
+            rnnoise::kWeightsScale *
+            static_cast<float>(
+                tensor_src[i * stride_src + g * output_size + o]);
+      }
+    }
+  }
+  return tensor_dst;
+}
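// The transposition above converts the rnnoise tensor layout, in which the
// coefficients of one input row are stored gate by gate, into a gate-major
// layout in which each gate reads its weights contiguously. A standalone
// sketch of the index mapping (hypothetical helper, not part of this CL;
// the literal 1.f / 256 stands in for rnnoise::kWeightsScale):

#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> TransposeGruTensorSketch(const std::vector<int8_t>& src,
                                            size_t input_size,
                                            size_t output_size) {
  constexpr size_t kGates = 3;  // Update, reset, output.
  std::vector<float> dst(src.size());
  for (size_t g = 0; g < kGates; ++g) {
    for (size_t o = 0; o < output_size; ++o) {
      for (size_t i = 0; i < input_size; ++i) {
        // Source layout: [input i][gate g][output o].
        // Destination layout: [gate g][output o][input i].
        dst[(g * output_size + o) * input_size + i] =
            (1.f / 256) * src[(i * kGates + g) * output_size + o];
      }
    }
  }
  return dst;
}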
+
+void ComputeGruUpdateResetGates(size_t input_size,
+                                size_t output_size,
+                                rtc::ArrayView<const float> weights,
+                                rtc::ArrayView<const float> recurrent_weights,
+                                rtc::ArrayView<const float> bias,
+                                rtc::ArrayView<const float> input,
+                                rtc::ArrayView<const float> state,
+                                rtc::ArrayView<float> gate) {
+  for (size_t o = 0; o < output_size; ++o) {
+    gate[o] = bias[o];
+    for (size_t i = 0; i < input_size; ++i) {
+      gate[o] += input[i] * weights[o * input_size + i];
+    }
+    for (size_t s = 0; s < output_size; ++s) {
+      gate[o] += state[s] * recurrent_weights[o * output_size + s];
+    }
+    gate[o] = SigmoidApproximated(gate[o]);
+  }
+}
+
+void ComputeGruOutputGate(size_t input_size,
+                          size_t output_size,
+                          rtc::ArrayView<const float> weights,
+                          rtc::ArrayView<const float> recurrent_weights,
+                          rtc::ArrayView<const float> bias,
+                          rtc::ArrayView<const float> input,
+                          rtc::ArrayView<const float> state,
+                          rtc::ArrayView<const float> reset,
+                          rtc::ArrayView<float> gate) {
+  for (size_t o = 0; o < output_size; ++o) {
+    gate[o] = bias[o];
+    for (size_t i = 0; i < input_size; ++i) {
+      gate[o] += input[i] * weights[o * input_size + i];
+    }
+    for (size_t s = 0; s < output_size; ++s) {
+      gate[o] += state[s] * recurrent_weights[o * output_size + s] * reset[s];
+    }
+    gate[o] = RectifiedLinearUnit(gate[o]);
+  }
+}
+
+// Gated recurrent unit (GRU) layer un-optimized implementation.
+void ComputeGruLayerOutput(size_t input_size,
+                           size_t output_size,
+                           rtc::ArrayView<const float> input,
+                           rtc::ArrayView<const float> weights,
+                           rtc::ArrayView<const float> recurrent_weights,
+                           rtc::ArrayView<const float> bias,
+                           rtc::ArrayView<float> state) {
+  RTC_DCHECK_EQ(input_size, input.size());
+  // Strides used to read the parameter arrays.
+  const size_t stride_in = input_size * output_size;
+  const size_t stride_out = output_size * output_size;
+
+  // Update gate.
+  std::array<float, kRecurrentLayersMaxUnits> update;
+  ComputeGruUpdateResetGates(
+      input_size, output_size, weights.subview(0, stride_in),
+      recurrent_weights.subview(0, stride_out), bias.subview(0, output_size),
+      input, state, update);
+
+  // Reset gate.
+  std::array<float, kRecurrentLayersMaxUnits> reset;
+  ComputeGruUpdateResetGates(
+      input_size, output_size, weights.subview(stride_in, stride_in),
+      recurrent_weights.subview(stride_out, stride_out),
+      bias.subview(output_size, output_size), input, state, reset);
+
+  // Output gate.
+  std::array<float, kRecurrentLayersMaxUnits> output;
+  ComputeGruOutputGate(
+      input_size, output_size, weights.subview(2 * stride_in, stride_in),
+      recurrent_weights.subview(2 * stride_out, stride_out),
+      bias.subview(2 * output_size, output_size), input, state, reset, output);
+
+  // Combine the output with the previous state through the update gates and
+  // update the state.
+  for (size_t o = 0; o < output_size; ++o) {
+    output[o] = update[o] * state[o] + (1.f - update[o]) * output[o];
+    state[o] = output[o];
+  }
+}
+
 // Fully connected layer un-optimized implementation.
 void ComputeFullyConnectedLayerOutput(
     size_t input_size,
@@ -160,7 +275,7 @@ FullyConnectedLayer::FullyConnectedLayer(
     : input_size_(input_size),
       output_size_(output_size),
       bias_(GetScaledParams(bias)),
-      weights_(GetPreprocessedWeights(weights, output_size)),
+      weights_(GetPreprocessedFcWeights(weights, output_size)),
       activation_function_(activation_function),
       optimization_(optimization) {
   RTC_DCHECK_LE(output_size_, kFullyConnectedLayersMaxUnits)
@@ -209,18 +324,20 @@ GatedRecurrentLayer::GatedRecurrentLayer(
     Optimization optimization)
     : input_size_(input_size),
       output_size_(output_size),
-      bias_(GetScaledParams(bias)),
-      weights_(GetScaledParams(weights)),
-      recurrent_weights_(GetScaledParams(recurrent_weights)),
+      bias_(GetPreprocessedGruTensor(bias, output_size)),
+      weights_(GetPreprocessedGruTensor(weights, output_size)),
+      recurrent_weights_(
+          GetPreprocessedGruTensor(recurrent_weights, output_size)),
       optimization_(optimization) {
   RTC_DCHECK_LE(output_size_, kRecurrentLayersMaxUnits)
       << "Static over-allocation of recurrent layers state vectors is not "
       << "sufficient.";
-  RTC_DCHECK_EQ(3 * output_size_, bias_.size())
+  RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
       << "Mismatching output size and bias terms array size.";
-  RTC_DCHECK_EQ(3 * input_size_ * output_size_, weights_.size())
+  RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
       << "Mismatching input-output size and weight coefficients array size.";
-  RTC_DCHECK_EQ(3 * input_size_ * output_size_, recurrent_weights_.size())
+  RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
+                recurrent_weights_.size())
       << "Mismatching input-output size and recurrent weight coefficients array"
       << " size.";
   Reset();
 }
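For reference, ComputeGruLayerOutput implements the conventional GRU recurrence with a rectified linear unit in place of the usual tanh candidate activation. A single-unit sketch of the arithmetic (hypothetical scalar parameters; the production code uses SigmoidApproximated instead of the exact sigmoid):

#include <algorithm>
#include <cmath>

// One GRU step for a single unit: u and r are sigmoid gates, the candidate c
// uses ReLU, and the new state blends the previous state h with c through u.
float GruStepSketch(float x, float h,
                    float wu, float ru, float bu,    // Update gate parameters.
                    float wr, float rr, float br,    // Reset gate parameters.
                    float wc, float rc, float bc) {  // Output gate parameters.
  const auto sigmoid = [](float v) { return 1.f / (1.f + std::exp(-v)); };
  const float u = sigmoid(wu * x + ru * h + bu);              // Update gate.
  const float r = sigmoid(wr * x + rr * h + br);              // Reset gate.
  const float c = std::max(0.f, wc * x + rc * (r * h) + bc);  // Candidate.
  return u * h + (1.f - u) * c;
}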
@@ -241,81 +358,23 @@ void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
 #if defined(WEBRTC_ARCH_X86_FAMILY)
     case Optimization::kSse2:
       // TODO(bugs.chromium.org/10480): Handle Optimization::kSse2.
-      ComputeOutput_NONE(input);
+      ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
+                            recurrent_weights_, bias_, state_);
       break;
 #endif
 #if defined(WEBRTC_HAS_NEON)
     case Optimization::kNeon:
       // TODO(bugs.chromium.org/10480): Handle Optimization::kNeon.
-      ComputeOutput_NONE(input);
+      ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
+                            recurrent_weights_, bias_, state_);
       break;
 #endif
     default:
-      ComputeOutput_NONE(input);
+      ComputeGruLayerOutput(input_size_, output_size_, input, weights_,
+                            recurrent_weights_, bias_, state_);
   }
 }
 
-void GatedRecurrentLayer::ComputeOutput_NONE(
-    rtc::ArrayView<const float> input) {
-  // TODO(bugs.chromium.org/9076): Optimize using SSE/AVX fused multiply-add
-  // operations.
-  // Stride and offset used to read parameter arrays.
-  const size_t stride = 3 * output_size_;
-  size_t offset = 0;
-
-  // Compute update gates.
-  std::array<float, kRecurrentLayersMaxUnits> update;
-  for (size_t o = 0; o < output_size_; ++o) {
-    update[o] = bias_[o];
-    // TODO(bugs.chromium.org/9076): Benchmark how different layouts for
-    // |weights_| and |recurrent_weights_| change the performance across
-    // different platforms.
-    for (size_t i = 0; i < input_size_; ++i) {  // Add input.
-      update[o] += input[i] * weights_[i * stride + o];
-    }
-    for (size_t s = 0; s < output_size_; ++s) {  // Add state.
-      update[o] += state_[s] * recurrent_weights_[s * stride + o];
-    }
-    update[o] = SigmoidApproximated(update[o]);
-  }
-
-  // Compute reset gates.
-  offset += output_size_;
-  std::array<float, kRecurrentLayersMaxUnits> reset;
-  for (size_t o = 0; o < output_size_; ++o) {
-    reset[o] = bias_[offset + o];
-    for (size_t i = 0; i < input_size_; ++i) {  // Add input.
-      reset[o] += input[i] * weights_[offset + i * stride + o];
-    }
-    for (size_t s = 0; s < output_size_; ++s) {  // Add state.
-      reset[o] += state_[s] * recurrent_weights_[offset + s * stride + o];
-    }
-    reset[o] = SigmoidApproximated(reset[o]);
-  }
-
-  // Compute output.
-  offset += output_size_;
-  std::array<float, kRecurrentLayersMaxUnits> output;
-  for (size_t o = 0; o < output_size_; ++o) {
-    output[o] = bias_[offset + o];
-    for (size_t i = 0; i < input_size_; ++i) {  // Add input.
-      output[o] += input[i] * weights_[offset + i * stride + o];
-    }
-    for (size_t s = 0; s < output_size_;
-         ++s) {  // Add state through reset gates.
-      output[o] +=
-          state_[s] * recurrent_weights_[offset + s * stride + o] * reset[s];
-    }
-    output[o] = RectifiedLinearUnit(output[o]);
-    // Update output through the update gates.
-    output[o] = update[o] * state_[o] + (1.f - update[o]) * output[o];
-  }
-
-  // Update the state. Not done in the previous loop since that would pollute
-  // the current state and lead to incorrect output values.
-  std::copy(output.begin(), output.end(), state_.begin());
-}
-
 RnnBasedVad::RnnBasedVad()
     : input_layer_(kInputLayerInputSize,
                    kInputLayerOutputSize,
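The kSse2 and kNeon branches still fall through to the portable ComputeGruLayerOutput while the TODOs are open. With the new gate-major layout the inner dot products are contiguous, so a future SSE2 path could vectorize them roughly as follows (a sketch under that assumption, not the pending WebRTC implementation; it uses unaligned loads and a scalar tail):

#include <emmintrin.h>  // SSE2 intrinsics.

#include <cstddef>

// Dot product of two float arrays of length |n|, four lanes at a time.
float DotProductSse2Sketch(const float* a, const float* b, size_t n) {
  __m128 acc = _mm_setzero_ps();
  size_t i = 0;
  for (; i + 4 <= n; i += 4) {
    acc = _mm_add_ps(acc,
                     _mm_mul_ps(_mm_loadu_ps(a + i), _mm_loadu_ps(b + i)));
  }
  float lanes[4];
  _mm_storeu_ps(lanes, acc);  // Horizontal sum of the partial accumulators.
  float sum = lanes[0] + lanes[1] + lanes[2] + lanes[3];
  for (; i < n; ++i) {  // Scalar tail for lengths not divisible by four.
    sum += a[i] * b[i];
  }
  return sum;
}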
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn.h b/modules/audio_processing/agc2/rnn_vad/rnn.h
index 29ee20744b..58274b2e1e 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn.h
+++ b/modules/audio_processing/agc2/rnn_vad/rnn.h
@@ -90,9 +90,6 @@ class GatedRecurrentLayer {
   void ComputeOutput(rtc::ArrayView<const float> input);
 
  private:
-  // No SIMD optimizations.
-  void ComputeOutput_NONE(rtc::ArrayView<const float> input);
-
   const size_t input_size_;
   const size_t output_size_;
   const std::vector<float> bias_;
diff --git a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
index 74974164a1..6e9f6f3690 100644
--- a/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
+++ b/modules/audio_processing/agc2/rnn_vad/rnn_unittest.cc
@@ -82,17 +82,45 @@ constexpr size_t kGruOutputSize = 4;
 constexpr std::array<int8_t, 12> kGruBias = {96,   -99, -81, -114, 49,  119,
                                              -118, 68,  -76, 91,   121, 125};
 constexpr std::array<int8_t, 60> kGruWeights = {
-    124, 9,    1,    116, -66, -21, -118, -110, 104,  75,  -23,  -51,
-    -72, -111, 47,   93,  77,  -98, 41,   -8,   40,   -23, -43,  -107,
-    9,   -73,  30,   -32, -2,  64,  -26,  91,   -48,  -24, -28,  -104,
-    74,  -46,  116,  15,  32,  52,  -126, -38,  -121, 12,  -16,  110,
-    -95, 66,   -103, -35, -38, 3,   -126, -61,  28,   98,  -117, -43};
-constexpr std::array<int8_t, 60> kGruRecurrentWeights = {
-    -3,  87,  50,  51,  -22,  27,  -39, 62,   31,  -83, -52,  -48,
-    -6,  83,  -19, 104, 105,  48,  23,  68,   23,  40,  7,    -120,
-    64,  -62, 117, 85,  -51,  -43, 54,  -105, 120, 56,  -128, -107,
-    39,  50,  -17, -47, -117, 14,  108, 12,   -7,  -72, 103,  -87,
-    -66, 82,  84,  100, -98,  102, -49, 44,   122, 106, -20,  -69};
+    // Input 0.
+    124, 9, 1, 116,        // Update.
+    -66, -21, -118, -110,  // Reset.
+    104, 75, -23, -51,     // Output.
+    // Input 1.
+    -72, -111, 47, 93,   // Update.
+    77, -98, 41, -8,     // Reset.
+    40, -23, -43, -107,  // Output.
+    // Input 2.
+    9, -73, 30, -32,      // Update.
+    -2, 64, -26, 91,      // Reset.
+    -48, -24, -28, -104,  // Output.
+    // Input 3.
+    74, -46, 116, 15,    // Update.
+    32, 52, -126, -38,   // Reset.
+    -121, 12, -16, 110,  // Output.
+    // Input 4.
+    -95, 66, -103, -35,  // Update.
+    -38, 3, -126, -61,   // Reset.
+    28, 98, -117, -43    // Output.
+};
+constexpr std::array<int8_t, 48> kGruRecurrentWeights = {
+    // Output 0.
+    -3, 87, 50, 51,     // Update.
+    -22, 27, -39, 62,   // Reset.
+    31, -83, -52, -48,  // Output.
+    // Output 1.
+    -6, 83, -19, 104,  // Update.
+    105, 48, 23, 68,   // Reset.
+    23, 40, 7, -120,   // Output.
+    // Output 2.
+    64, -62, 117, 85,     // Update.
+    51, -43, 54, -105,    // Reset.
+    120, 56, -128, -107,  // Output.
+    // Output 3.
+    39, 50, -17, -47,   // Update.
+    -117, 14, 108, 12,  // Reset.
+    -7, -72, 103, -87,  // Output.
+};
 constexpr std::array<float, 20> kGruInputSequence = {
     0.89395463f, 0.93224651f, 0.55788344f, 0.32341808f, 0.93355054f,
     0.13475326f, 0.97370994f, 0.14253306f, 0.93710381f, 0.76093364f,
@@ -115,6 +143,12 @@ std::string GetOptimizationName(Optimization optimization) {
   }
 }
 
+struct Result {
+  Optimization optimization;
+  double average_us;
+  double std_dev_us;
+};
+
 }  // namespace
 
 // Checks that the output of a fully connected layer is within tolerance given
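The re-grouped initializers make the tensor shapes explicit: kGruWeights holds kGruInputSize x kNumGruGates x kGruOutputSize = 5 x 3 x 4 = 60 coefficients, and kGruRecurrentWeights shrinks from 60 to kGruOutputSize x kNumGruGates x kGruOutputSize = 4 x 3 x 4 = 48 entries, matching the corrected RTC_DCHECK_EQ in the constructor. A compile-time guard along these lines could document the shapes in the test itself (a hypothetical addition, placed after the arrays; kNumGruGates is local to rnn.cc, hence re-declared):

// Hypothetical shape guards for the GRU test vectors (not part of this CL).
constexpr size_t kNumGruGatesInTest = 3;  // Update, reset, output.
static_assert(kGruBias.size() == kNumGruGatesInTest * kGruOutputSize,
              "GRU bias must have gates x output entries.");
static_assert(
    kGruWeights.size() == kGruInputSize * kNumGruGatesInTest * kGruOutputSize,
    "GRU weights must have input x gates x output entries.");
static_assert(kGruRecurrentWeights.size() ==
                  kGruOutputSize * kNumGruGatesInTest * kGruOutputSize,
              "GRU recurrent weights must have output x gates x output "
              "entries.");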
@@ -152,6 +186,17 @@ TEST(RnnVadTest, CheckFullyConnectedLayerOutputSse2) {
                                 kFullyConnectedExpectedOutput);
 }
 
+// Like CheckGatedRecurrentLayer, but testing the SSE2 implementation.
+TEST(RnnVadTest, CheckGatedRecurrentLayerSse2) {
+  if (!IsOptimizationAvailable(Optimization::kSse2)) {
+    return;
+  }
+
+  GatedRecurrentLayer gru(kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
+                          kGruRecurrentWeights, Optimization::kSse2);
+  TestGatedRecurrentLayer(&gru, kGruInputSequence, kGruExpectedOutputSequence);
+}
+
 #endif  // WEBRTC_ARCH_X86_FAMILY
 
 TEST(RnnVadTest, DISABLED_BenchmarkFullyConnectedLayer) {
@@ -167,13 +212,7 @@ TEST(RnnVadTest, DISABLED_BenchmarkFullyConnectedLayer) {
                         rnnoise::TansigApproximated, Optimization::kSse2));
   }
 
-  struct Result {
-    Optimization optimization;
-    double average_us;
-    double std_dev_us;
-  };
   std::vector<Result> results;
-
   constexpr size_t number_of_tests = 10000;
   for (auto& fc : implementations) {
     ::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
@@ -193,6 +232,41 @@ TEST(RnnVadTest, DISABLED_BenchmarkFullyConnectedLayer) {
   }
 }
 
+TEST(RnnVadTest, DISABLED_BenchmarkGatedRecurrentLayer) {
+  std::vector<std::unique_ptr<GatedRecurrentLayer>> implementations;
+  implementations.emplace_back(std::make_unique<GatedRecurrentLayer>(
+      kGruInputSize, kGruOutputSize, kGruBias, kGruWeights,
+      kGruRecurrentWeights, Optimization::kNone));
+
+  rtc::ArrayView<const float> input_sequence(kGruInputSequence);
+  static_assert(kGruInputSequence.size() % kGruInputSize == 0, "");
+  constexpr size_t input_sequence_length =
+      kGruInputSequence.size() / kGruInputSize;
+
+  std::vector<Result> results;
+  constexpr size_t number_of_tests = 10000;
+  for (auto& gru : implementations) {
+    ::webrtc::test::PerformanceTimer perf_timer(number_of_tests);
+    gru->Reset();
+    for (size_t k = 0; k < number_of_tests; ++k) {
+      perf_timer.StartTimer();
+      for (size_t i = 0; i < input_sequence_length; ++i) {
+        gru->ComputeOutput(
+            input_sequence.subview(i * gru->input_size(), gru->input_size()));
+      }
+      perf_timer.StopTimer();
+    }
+    results.push_back({gru->optimization(), perf_timer.GetDurationAverage(),
+                       perf_timer.GetDurationStandardDeviation()});
+  }
+
+  for (const auto& result : results) {
+    RTC_LOG(LS_INFO) << GetOptimizationName(result.optimization) << ": "
+                     << (result.average_us / 1e3) << " +/- "
+                     << (result.std_dev_us / 1e3) << " ms";
+  }
+}
+
 }  // namespace test
 }  // namespace rnn_vad
 }  // namespace webrtc
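Both benchmarks carry the DISABLED_ prefix, so googletest skips them by default. Assuming a build of WebRTC's modules_unittests target (the binary name is an assumption here), the standard googletest flags run them on demand:

  ./modules_unittests --gtest_also_run_disabled_tests \
      --gtest_filter=RnnVadTest.DISABLED_BenchmarkGatedRecurrentLayer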