RNN VAD: VectorMath::DotProduct
with AVX2 optimization
This CL adds a new library for the RNN VAD that provides (optimized) vector math ops. The scheme is the same of the `VectorMath` class of AEC3 to ensure correct builds across different platforms. Bug: webrtc:10480 Change-Id: I96bcfbf930ca27388ab5f2d52c022ddb73acf8e6 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/194326 Commit-Queue: Alessio Bazzica <alessiob@webrtc.org> Reviewed-by: Per Åhgren <peah@webrtc.org> Reviewed-by: Gustaf Ullberg <gustaf@webrtc.org> Cr-Commit-Position: refs/heads/master@{#32741}
This commit is contained in:

committed by
Commit Bot

parent
ccfcec402d
commit
01b3e24a83
@ -78,6 +78,35 @@ rtc_library("rnn_vad_lp_residual") {
|
||||
]
|
||||
}
|
||||
|
||||
rtc_source_set("vector_math") {
|
||||
sources = [ "vector_math.h" ]
|
||||
deps = [
|
||||
"..:cpu_features",
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
"../../../../rtc_base/system:arch",
|
||||
]
|
||||
}
|
||||
|
||||
if (current_cpu == "x86" || current_cpu == "x64") {
|
||||
rtc_library("vector_math_avx2") {
|
||||
sources = [ "vector_math_avx2.cc" ]
|
||||
if (is_win) {
|
||||
cflags = [ "/arch:AVX2" ]
|
||||
} else {
|
||||
cflags = [
|
||||
"-mavx2",
|
||||
"-mfma",
|
||||
]
|
||||
}
|
||||
deps = [
|
||||
":vector_math",
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
rtc_library("rnn_vad_pitch") {
|
||||
sources = [
|
||||
"pitch_search.cc",
|
||||
@ -88,6 +117,7 @@ rtc_library("rnn_vad_pitch") {
|
||||
deps = [
|
||||
":rnn_vad_auto_correlation",
|
||||
":rnn_vad_common",
|
||||
":vector_math",
|
||||
"..:cpu_features",
|
||||
"../../../../api:array_view",
|
||||
"../../../../rtc_base:checks",
|
||||
@ -95,6 +125,9 @@ rtc_library("rnn_vad_pitch") {
|
||||
"../../../../rtc_base:safe_compare",
|
||||
"../../../../rtc_base:safe_conversions",
|
||||
]
|
||||
if (current_cpu == "x86" || current_cpu == "x64") {
|
||||
deps += [ ":vector_math_avx2" ]
|
||||
}
|
||||
}
|
||||
|
||||
rtc_source_set("rnn_vad_ring_buffer") {
|
||||
@ -191,6 +224,7 @@ if (rtc_include_tests) {
|
||||
"spectral_features_internal_unittest.cc",
|
||||
"spectral_features_unittest.cc",
|
||||
"symmetric_matrix_buffer_unittest.cc",
|
||||
"vector_math_unittest.cc",
|
||||
]
|
||||
deps = [
|
||||
":rnn_vad",
|
||||
@ -203,6 +237,7 @@ if (rtc_include_tests) {
|
||||
":rnn_vad_spectral_features",
|
||||
":rnn_vad_symmetric_matrix_buffer",
|
||||
":test_utils",
|
||||
":vector_math",
|
||||
"..:cpu_features",
|
||||
"../..:audioproc_test_utils",
|
||||
"../../../../api:array_view",
|
||||
@ -216,6 +251,9 @@ if (rtc_include_tests) {
|
||||
"../../utility:pffft_wrapper",
|
||||
"//third_party/rnnoise:rnn_vad",
|
||||
]
|
||||
if (current_cpu == "x86" || current_cpu == "x64") {
|
||||
deps += [ ":vector_math_avx2" ]
|
||||
}
|
||||
absl_deps = [ "//third_party/abseil-cpp/absl/memory" ]
|
||||
data = unittest_resources
|
||||
if (is_ios) {
|
||||
|
55
modules/audio_processing/agc2/rnn_vad/vector_math.h
Normal file
55
modules/audio_processing/agc2/rnn_vad/vector_math.h
Normal file
@ -0,0 +1,55 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Provides optimizations for mathematical operations having vectors as
|
||||
// operand(s).
|
||||
class VectorMath {
|
||||
public:
|
||||
explicit VectorMath(AvailableCpuFeatures cpu_features)
|
||||
: cpu_features_(cpu_features) {}
|
||||
|
||||
// Computes the dot product between two equally sized vectors.
|
||||
float DotProduct(rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y) const {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
if (cpu_features_.avx2) {
|
||||
return DotProductAvx2(x, y);
|
||||
}
|
||||
// TODO(bugs.webrtc.org/10480): Add SSE2 alternative implementation.
|
||||
#endif
|
||||
// TODO(bugs.webrtc.org/10480): Add NEON alternative implementation.
|
||||
RTC_DCHECK_EQ(x.size(), y.size());
|
||||
return std::inner_product(x.begin(), x.end(), y.begin(), 0.f);
|
||||
}
|
||||
|
||||
private:
|
||||
float DotProductAvx2(rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y) const;
|
||||
|
||||
const AvailableCpuFeatures cpu_features_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
|
53
modules/audio_processing/agc2/rnn_vad/vector_math_avx2.cc
Normal file
53
modules/audio_processing/agc2/rnn_vad/vector_math_avx2.cc
Normal file
@ -0,0 +1,53 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
|
||||
|
||||
#include <immintrin.h>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
float VectorMath::DotProductAvx2(rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y) const {
|
||||
RTC_DCHECK(cpu_features_.avx2);
|
||||
RTC_DCHECK_EQ(x.size(), y.size());
|
||||
__m256 accumulator = _mm256_setzero_ps();
|
||||
constexpr int kBlockSizeLog2 = 3;
|
||||
constexpr int kBlockSize = 1 << kBlockSizeLog2;
|
||||
const int incomplete_block_index = (x.size() >> kBlockSizeLog2)
|
||||
<< kBlockSizeLog2;
|
||||
for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
|
||||
RTC_DCHECK_LE(i + kBlockSize, x.size());
|
||||
const __m256 x_i = _mm256_loadu_ps(&x[i]);
|
||||
const __m256 y_i = _mm256_loadu_ps(&y[i]);
|
||||
accumulator = _mm256_fmadd_ps(x_i, y_i, accumulator);
|
||||
}
|
||||
// Reduce `accumulator` by addition.
|
||||
__m128 high = _mm256_extractf128_ps(accumulator, 1);
|
||||
__m128 low = _mm256_extractf128_ps(accumulator, 0);
|
||||
low = _mm_add_ps(high, low);
|
||||
high = _mm_movehl_ps(high, low);
|
||||
low = _mm_add_ps(high, low);
|
||||
high = _mm_shuffle_ps(low, low, 1);
|
||||
low = _mm_add_ss(high, low);
|
||||
float dot_product = _mm_cvtss_f32(low);
|
||||
// Add the result for the last block if incomplete.
|
||||
for (int i = incomplete_block_index; static_cast<size_t>(i) < x.size(); ++i) {
|
||||
dot_product += x[i] * y[i];
|
||||
}
|
||||
return dot_product;
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
@ -0,0 +1,67 @@
|
||||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "test/gtest.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
constexpr int kSizeOfX = 19;
|
||||
constexpr float kX[kSizeOfX] = {
|
||||
0.31593041f, 0.9350786f, -0.25252445f, -0.86956251f, -0.9673632f,
|
||||
0.54571901f, -0.72504495f, -0.79509912f, -0.25525012f, -0.73340473f,
|
||||
0.15747377f, -0.04370565f, 0.76135145f, -0.57239645f, 0.68616848f,
|
||||
0.3740298f, 0.34710799f, -0.92207423f, 0.10738454f};
|
||||
constexpr int kSizeOfXSubSpan = 16;
|
||||
static_assert(kSizeOfXSubSpan < kSizeOfX, "");
|
||||
constexpr float kEnergyOfX = 7.315563958160327f;
|
||||
constexpr float kEnergyOfXSubspan = 6.333327669592963f;
|
||||
|
||||
class VectorMathParametrization
|
||||
: public ::testing::TestWithParam<AvailableCpuFeatures> {};
|
||||
|
||||
TEST_P(VectorMathParametrization, TestDotProduct) {
|
||||
VectorMath vector_math(/*cpu_features=*/GetParam());
|
||||
EXPECT_FLOAT_EQ(vector_math.DotProduct(kX, kX), kEnergyOfX);
|
||||
EXPECT_FLOAT_EQ(
|
||||
vector_math.DotProduct({kX, kSizeOfXSubSpan}, {kX, kSizeOfXSubSpan}),
|
||||
kEnergyOfXSubspan);
|
||||
}
|
||||
|
||||
// Finds the relevant CPU features combinations to test.
|
||||
std::vector<AvailableCpuFeatures> GetCpuFeaturesToTest() {
|
||||
std::vector<AvailableCpuFeatures> v;
|
||||
v.push_back({/*sse2=*/false, /*avx2=*/false, /*neon=*/false});
|
||||
AvailableCpuFeatures available = GetAvailableCpuFeatures();
|
||||
if (available.avx2) {
|
||||
AvailableCpuFeatures features(
|
||||
{/*sse2=*/false, /*avx2=*/true, /*neon=*/false});
|
||||
v.push_back(features);
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
RnnVadTest,
|
||||
VectorMathParametrization,
|
||||
::testing::ValuesIn(GetCpuFeaturesToTest()),
|
||||
[](const ::testing::TestParamInfo<AvailableCpuFeatures>& info) {
|
||||
return info.param.ToString();
|
||||
});
|
||||
|
||||
} // namespace
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
Reference in New Issue
Block a user