Cleanup: Rename MatrixBatchVectorMultiplyAccumulateLoopBodySse to DotProdInt8x4x4 to better reflect what it does. Also do the final accumulation step outside this function.

PiperOrigin-RevId: 293243254
Change-Id: I33a557cb42e336eb605b69e5c124e4df5b7bb33c
This commit is contained in:
Robert David 2020-02-04 15:30:21 -08:00 committed by TensorFlower Gardener
parent 1eb0fdd149
commit cd5281bd47

View File

@ -25,22 +25,17 @@ namespace tflite {
namespace tensor_utils {
namespace {
// Elementwise multiply two i8x16 vectors to i16x8, add elements pairwise and
// accumulate result to a i32x4 accumulator.
//
// Shared by the inner loop of MatrixBatchVectorMultiplyAccumulate(int8) and
// SparseMatrixBatchVectorMultiplyAccumulate(int8).
static inline __m128i MatrixBatchVectorMultiplyAccumulateLoopBodySse(
__m128i dotprod, __m128i a_8x16, __m128i b_8x16) {
// Dot product of four int8 vectors of 4 elements packed into a XMM register.
// Result is four int32 scalars packed into a XMM register.
// int8x4x4 · int8x4x4 => int32x4
static inline __m128i DotProdInt8x4x4(__m128i a_8x16, __m128i b_8x16) {
// Transfer sign from 'a' to 'b', as _mm_maddubs_epi16 treats 'a' unsigned.
b_8x16 = _mm_sign_epi8(b_8x16, a_8x16);
a_8x16 = _mm_abs_epi8(a_8x16);
// sumprod[i] = a[2*i]*b[2*i] + a[2*i+1]*b[2*i+1] (i = 0..7)
__m128i sumprod_16x8 = _mm_maddubs_epi16(a_8x16, b_8x16);
// sumprod[i] = sumprod[2*i]*1 + sumprod[2*i+1]*1 (i = 0..3)
__m128i sumprod_32x4 = _mm_madd_epi16(sumprod_16x8, _mm_set1_epi16(1));
// accumulator += sumprod
return _mm_add_epi32(dotprod, sumprod_32x4);
return _mm_madd_epi16(sumprod_16x8, _mm_set1_epi16(1));
}
// Horizontally add 4 int32 values stored in a single XMM register to int32_t.
@ -76,8 +71,9 @@ void SseMatrixBatchVectorMultiplyAccumulate(
_mm_loadu_si128(reinterpret_cast<const __m128i*>(vectors + col));
const __m128i row_8x16 =
_mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col));
dotprod_32x4 = MatrixBatchVectorMultiplyAccumulateLoopBodySse(
dotprod_32x4, vec_8x16, row_8x16);
// dotprod += vec · row
dotprod_32x4 =
_mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16));
} // for col
// Horizontally add the 4 intermediate sum values to get the final
// dot-prod value for this row.
@ -113,8 +109,9 @@ void SseMatrixBatchVectorMultiplyAccumulate(
_mm_loadu_si128(reinterpret_cast<const __m128i*>(vectors + col));
const __m128i row_8x16 =
_mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col));
dotprod_32x4 = MatrixBatchVectorMultiplyAccumulateLoopBodySse(
dotprod_32x4, vec_8x16, row_8x16);
// dotprod += vec · row
dotprod_32x4 =
_mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16));
// Pairwise add 16x 8-bit values; equivalently, multiply-add with 1.
// Result is 8x 16-bit values.
@ -161,8 +158,9 @@ void SseSparseMatrixBatchVectorMultiplyAccumulate(
reinterpret_cast<const __m128i*>(vectors + col_index));
const __m128i row_8x16 =
_mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr));
dotprod_32x4 = MatrixBatchVectorMultiplyAccumulateLoopBodySse(
dotprod_32x4, vec_8x16, row_8x16);
// dotprod += vec · row
dotprod_32x4 =
_mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16));
row_ptr += kBlockSize;
} // for col
// Horizontally add the 4 intermediate sum values to get the final