Implement a Neon optimized function to find the argmax element in an array.

Finding the array element with the largest argmax is a fairly common
operation, so it makes sense to have a Neon optimized version. The
implementation is done by first finding both the min and max value, and
then returning whichever has the largest argmax.

Bug: chromium:12355
Change-Id: I088bd4f7d469b2424a7265de10fffb42764567a1
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/201622
Commit-Queue: Ivo Creusen <ivoc@webrtc.org>
Reviewed-by: Per Åhgren <peah@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#33052}
This commit is contained in:
Ivo Creusen
2021-01-20 16:26:43 +01:00
committed by Commit Bot
parent 03eed7c8d0
commit 6031b74664
6 changed files with 169 additions and 8 deletions

View File

@ -228,6 +228,25 @@ int32_t WebRtcSpl_MinValueW32Neon(const int32_t* vector, size_t length);
int32_t WebRtcSpl_MinValueW32_mips(const int32_t* vector, size_t length);
#endif
// Returns both the minimum and maximum values of a 16-bit vector.
//
// Input:
// - vector : 16-bit input vector.
// - length : Number of samples in vector.
// Ouput:
// - max_val : Maximum sample value in |vector|.
// - min_val : Minimum sample value in |vector|.
void WebRtcSpl_MinMaxW16(const int16_t* vector,
size_t length,
int16_t* min_val,
int16_t* max_val);
#if defined(WEBRTC_HAS_NEON)
void WebRtcSpl_MinMaxW16Neon(const int16_t* vector,
size_t length,
int16_t* min_val,
int16_t* max_val);
#endif
// Returns the vector index to the largest absolute value of a 16-bit vector.
//
// Input:
@ -240,6 +259,17 @@ int32_t WebRtcSpl_MinValueW32_mips(const int32_t* vector, size_t length);
// -32768 presenting an int16 absolute value of 32767).
size_t WebRtcSpl_MaxAbsIndexW16(const int16_t* vector, size_t length);
// Returns the element with the largest absolute value of a 16-bit vector. Note
// that this function can return a negative value.
//
// Input:
// - vector : 16-bit input vector.
// - length : Number of samples in vector.
//
// Return value : The element with the largest absolute value. Note that this
// may be a negative value.
int16_t WebRtcSpl_MaxAbsElementW16(const int16_t* vector, size_t length);
// Returns the vector index to the maximum sample value of a 16-bit vector.
//
// Input:

View File

@ -155,6 +155,15 @@ size_t WebRtcSpl_MaxAbsIndexW16(const int16_t* vector, size_t length) {
return index;
}
int16_t WebRtcSpl_MaxAbsElementW16(const int16_t* vector, size_t length) {
int16_t min_val, max_val;
WebRtcSpl_MinMaxW16(vector, length, &min_val, &max_val);
if (min_val == max_val || min_val < -max_val) {
return min_val;
}
return max_val;
}
// Index of maximum value in a word16 vector.
size_t WebRtcSpl_MaxIndexW16(const int16_t* vector, size_t length) {
size_t i = 0, index = 0;
@ -222,3 +231,26 @@ size_t WebRtcSpl_MinIndexW32(const int32_t* vector, size_t length) {
return index;
}
// Finds both the minimum and maximum elements in an array of 16-bit integers.
void WebRtcSpl_MinMaxW16(const int16_t* vector, size_t length,
int16_t* min_val, int16_t* max_val) {
#if defined(WEBRTC_HAS_NEON)
return WebRtcSpl_MinMaxW16Neon(vector, length, min_val, max_val);
#else
int16_t minimum = WEBRTC_SPL_WORD16_MAX;
int16_t maximum = WEBRTC_SPL_WORD16_MIN;
size_t i = 0;
RTC_DCHECK_GT(length, 0);
for (i = 0; i < length; i++) {
if (vector[i] < minimum)
minimum = vector[i];
if (vector[i] > maximum)
maximum = vector[i];
}
*min_val = minimum;
*max_val = maximum;
#endif
}

View File

@ -281,3 +281,53 @@ int32_t WebRtcSpl_MinValueW32Neon(const int32_t* vector, size_t length) {
return minimum;
}
// Finds both the minimum and maximum elements in an array of 16-bit integers.
void WebRtcSpl_MinMaxW16Neon(const int16_t* vector, size_t length,
int16_t* min_val, int16_t* max_val) {
int16_t minimum = WEBRTC_SPL_WORD16_MAX;
int16_t maximum = WEBRTC_SPL_WORD16_MIN;
size_t i = 0;
size_t residual = length & 0x7;
RTC_DCHECK_GT(length, 0);
const int16_t* p_start = vector;
int16x8_t min16x8 = vdupq_n_s16(WEBRTC_SPL_WORD16_MAX);
int16x8_t max16x8 = vdupq_n_s16(WEBRTC_SPL_WORD16_MIN);
// First part, unroll the loop 8 times.
for (i = 0; i < length - residual; i += 8) {
int16x8_t in16x8 = vld1q_s16(p_start);
min16x8 = vminq_s16(min16x8, in16x8);
max16x8 = vmaxq_s16(max16x8, in16x8);
p_start += 8;
}
#if defined(WEBRTC_ARCH_ARM64)
minimum = vminvq_s16(min16x8);
maximum = vmaxvq_s16(max16x8);
#else
int16x4_t min16x4 = vmin_s16(vget_low_s16(min16x8), vget_high_s16(min16x8));
min16x4 = vpmin_s16(min16x4, min16x4);
min16x4 = vpmin_s16(min16x4, min16x4);
minimum = vget_lane_s16(min16x4, 0);
int16x4_t max16x4 = vmax_s16(vget_low_s16(max16x8), vget_high_s16(max16x8));
max16x4 = vpmax_s16(max16x4, max16x4);
max16x4 = vpmax_s16(max16x4, max16x4);
maximum = vget_lane_s16(max16x4, 0);
#endif
// Second part, do the remaining iterations (if any).
for (i = residual; i > 0; i--) {
if (*p_start < minimum)
minimum = *p_start;
if (*p_start > maximum)
maximum = *p_start;
p_start++;
}
*min_val = minimum;
*max_val = maximum;
}

View File

@ -289,6 +289,12 @@ TEST(SplTest, MinMaxOperationsTest) {
WebRtcSpl_MinValueW32(vector32, kVectorSize));
EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MinIndexW16(vector16, kVectorSize));
EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MinIndexW32(vector32, kVectorSize));
EXPECT_EQ(WEBRTC_SPL_WORD16_MIN,
WebRtcSpl_MaxAbsElementW16(vector16, kVectorSize));
int16_t min_value, max_value;
WebRtcSpl_MinMaxW16(vector16, kVectorSize, &min_value, &max_value);
EXPECT_EQ(WEBRTC_SPL_WORD16_MIN, min_value);
EXPECT_EQ(12334, max_value);
// Test the cases where maximum values have to be caught
// outside of the unrolled loops in ARM-Neon.
@ -306,6 +312,11 @@ TEST(SplTest, MinMaxOperationsTest) {
EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MaxAbsIndexW16(vector16, kVectorSize));
EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MaxIndexW16(vector16, kVectorSize));
EXPECT_EQ(kVectorSize - 1, WebRtcSpl_MaxIndexW32(vector32, kVectorSize));
EXPECT_EQ(WEBRTC_SPL_WORD16_MAX,
WebRtcSpl_MaxAbsElementW16(vector16, kVectorSize));
WebRtcSpl_MinMaxW16(vector16, kVectorSize, &min_value, &max_value);
EXPECT_EQ(-29871, min_value);
EXPECT_EQ(WEBRTC_SPL_WORD16_MAX, max_value);
// Test the cases where multiple maximum and minimum values are present.
vector16[1] = WEBRTC_SPL_WORD16_MAX;
@ -332,6 +343,43 @@ TEST(SplTest, MinMaxOperationsTest) {
EXPECT_EQ(1u, WebRtcSpl_MaxIndexW32(vector32, kVectorSize));
EXPECT_EQ(6u, WebRtcSpl_MinIndexW16(vector16, kVectorSize));
EXPECT_EQ(6u, WebRtcSpl_MinIndexW32(vector32, kVectorSize));
EXPECT_EQ(WEBRTC_SPL_WORD16_MIN,
WebRtcSpl_MaxAbsElementW16(vector16, kVectorSize));
WebRtcSpl_MinMaxW16(vector16, kVectorSize, &min_value, &max_value);
EXPECT_EQ(WEBRTC_SPL_WORD16_MIN, min_value);
EXPECT_EQ(WEBRTC_SPL_WORD16_MAX, max_value);
// Test a one-element vector.
int16_t single_element_vector = 0;
EXPECT_EQ(0, WebRtcSpl_MaxAbsValueW16(&single_element_vector, 1));
EXPECT_EQ(0, WebRtcSpl_MaxValueW16(&single_element_vector, 1));
EXPECT_EQ(0, WebRtcSpl_MinValueW16(&single_element_vector, 1));
EXPECT_EQ(0u, WebRtcSpl_MaxAbsIndexW16(&single_element_vector, 1));
EXPECT_EQ(0u, WebRtcSpl_MaxIndexW16(&single_element_vector, 1));
EXPECT_EQ(0u, WebRtcSpl_MinIndexW16(&single_element_vector, 1));
EXPECT_EQ(0, WebRtcSpl_MaxAbsElementW16(&single_element_vector, 1));
WebRtcSpl_MinMaxW16(&single_element_vector, 1, &min_value, &max_value);
EXPECT_EQ(0, min_value);
EXPECT_EQ(0, max_value);
// Test a two-element vector with the values WEBRTC_SPL_WORD16_MIN and
// WEBRTC_SPL_WORD16_MAX.
int16_t two_element_vector[2] = {WEBRTC_SPL_WORD16_MIN,
WEBRTC_SPL_WORD16_MAX};
EXPECT_EQ(WEBRTC_SPL_WORD16_MAX,
WebRtcSpl_MaxAbsValueW16(two_element_vector, 2));
EXPECT_EQ(WEBRTC_SPL_WORD16_MAX,
WebRtcSpl_MaxValueW16(two_element_vector, 2));
EXPECT_EQ(WEBRTC_SPL_WORD16_MIN,
WebRtcSpl_MinValueW16(two_element_vector, 2));
EXPECT_EQ(0u, WebRtcSpl_MaxAbsIndexW16(two_element_vector, 2));
EXPECT_EQ(1u, WebRtcSpl_MaxIndexW16(two_element_vector, 2));
EXPECT_EQ(0u, WebRtcSpl_MinIndexW16(two_element_vector, 2));
EXPECT_EQ(WEBRTC_SPL_WORD16_MIN,
WebRtcSpl_MaxAbsElementW16(two_element_vector, 2));
WebRtcSpl_MinMaxW16(two_element_vector, 2, &min_value, &max_value);
EXPECT_EQ(WEBRTC_SPL_WORD16_MIN, min_value);
EXPECT_EQ(WEBRTC_SPL_WORD16_MAX, max_value);
}
TEST(SplTest, VectorOperationsTest) {