Cleanup: Rename MatrixBatchVectorMultiplyAccumulateLoopBodySse to DotProdInt8x4x4 to better reflect what it does. Also do the final accumulation step outside this function.
PiperOrigin-RevId: 293243254 Change-Id: I33a557cb42e336eb605b69e5c124e4df5b7bb33c
This commit is contained in:
parent
1eb0fdd149
commit
cd5281bd47
@ -25,22 +25,17 @@ namespace tflite {
|
||||
namespace tensor_utils {
|
||||
namespace {
|
||||
|
||||
// Elementwise multiply two i8x16 vectors to i16x8, add elements pairwise and
|
||||
// accumulate result to a i32x4 accumulator.
|
||||
//
|
||||
// Shared by the inner loop of MatrixBatchVectorMultiplyAccumulate(int8) and
|
||||
// SparseMatrixBatchVectorMultiplyAccumulate(int8).
|
||||
static inline __m128i MatrixBatchVectorMultiplyAccumulateLoopBodySse(
|
||||
__m128i dotprod, __m128i a_8x16, __m128i b_8x16) {
|
||||
// Dot product of four int8 vectors of 4 elements packed into a XMM register.
|
||||
// Result is four int32 scalars packed into a XMM register.
|
||||
// int8x4x4 · int8x4x4 => int32x4
|
||||
static inline __m128i DotProdInt8x4x4(__m128i a_8x16, __m128i b_8x16) {
|
||||
// Transfer sign from 'a' to 'b', as _mm_maddubs_epi16 treats 'a' unsigned.
|
||||
b_8x16 = _mm_sign_epi8(b_8x16, a_8x16);
|
||||
a_8x16 = _mm_abs_epi8(a_8x16);
|
||||
// sumprod[i] = a[2*i]*b[2*i] + a[2*i+1]*b[2*i+1] (i = 0..7)
|
||||
__m128i sumprod_16x8 = _mm_maddubs_epi16(a_8x16, b_8x16);
|
||||
// sumprod[i] = sumprod[2*i]*1 + sumprod[2*i+1]*1 (i = 0..3)
|
||||
__m128i sumprod_32x4 = _mm_madd_epi16(sumprod_16x8, _mm_set1_epi16(1));
|
||||
// accumlator += sumprod
|
||||
return _mm_add_epi32(dotprod, sumprod_32x4);
|
||||
return _mm_madd_epi16(sumprod_16x8, _mm_set1_epi16(1));
|
||||
}
|
||||
|
||||
// Horizontally add 4 int32 values stored in a single XMM register to int32_t.
|
||||
@ -76,8 +71,9 @@ void SseMatrixBatchVectorMultiplyAccumulate(
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(vectors + col));
|
||||
const __m128i row_8x16 =
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col));
|
||||
dotprod_32x4 = MatrixBatchVectorMultiplyAccumulateLoopBodySse(
|
||||
dotprod_32x4, vec_8x16, row_8x16);
|
||||
// dotprod += vec · row
|
||||
dotprod_32x4 =
|
||||
_mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16));
|
||||
} // for col
|
||||
// Horizontally add the 4 intermediate sum values to get the final
|
||||
// dot-prod value for this row.
|
||||
@ -113,8 +109,9 @@ void SseMatrixBatchVectorMultiplyAccumulate(
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(vectors + col));
|
||||
const __m128i row_8x16 =
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col));
|
||||
dotprod_32x4 = MatrixBatchVectorMultiplyAccumulateLoopBodySse(
|
||||
dotprod_32x4, vec_8x16, row_8x16);
|
||||
// dotprod += vec · row
|
||||
dotprod_32x4 =
|
||||
_mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16));
|
||||
|
||||
// Pairwise add 16x 8-bit values; equivalently, multipy-add with 1.
|
||||
// Result is 8x 16-bit values.
|
||||
@ -161,8 +158,9 @@ void SseSparseMatrixBatchVectorMultiplyAccumulate(
|
||||
reinterpret_cast<const __m128i*>(vectors + col_index));
|
||||
const __m128i row_8x16 =
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr));
|
||||
dotprod_32x4 = MatrixBatchVectorMultiplyAccumulateLoopBodySse(
|
||||
dotprod_32x4, vec_8x16, row_8x16);
|
||||
// dotprod += vec · row
|
||||
dotprod_32x4 =
|
||||
_mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16));
|
||||
row_ptr += kBlockSize;
|
||||
} // for col
|
||||
// Horizontally add the 4 intermediate sum values to get the final
|
||||
|
Loading…
x
Reference in New Issue
Block a user