diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc
index 79e3b8b9a63..6db74b0b69a 100644
--- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc
+++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc
@@ -25,22 +25,17 @@ namespace tflite {
 namespace tensor_utils {
 namespace {

-// Elementwise multiply two i8x16 vectors to i16x8, add elements pairwise and
-// accumulate result to a i32x4 accumulator.
-//
-// Shared by the inner loop of MatrixBatchVectorMultiplyAccumulate(int8) and
-// SparseMatrixBatchVectorMultiplyAccumulate(int8).
-static inline __m128i MatrixBatchVectorMultiplyAccumulateLoopBodySse(
-    __m128i dotprod, __m128i a_8x16, __m128i b_8x16) {
+// Dot product of four int8 vectors of 4 elements packed into an XMM register.
+// Result is four int32 scalars packed into an XMM register.
+// int8x4x4 · int8x4x4 => int32x4
+static inline __m128i DotProdInt8x4x4(__m128i a_8x16, __m128i b_8x16) {
   // Transfer sign from 'a' to 'b', as _mm_maddubs_epi16 treats 'a' unsigned.
   b_8x16 = _mm_sign_epi8(b_8x16, a_8x16);
   a_8x16 = _mm_abs_epi8(a_8x16);
   // sumprod[i] = a[2*i]*b[2*i] + a[2*i+1]*b[2*i+1] (i = 0..7)
   __m128i sumprod_16x8 = _mm_maddubs_epi16(a_8x16, b_8x16);
   // sumprod[i] = sumprod[2*i]*1 + sumprod[2*i+1]*1 (i = 0..3)
-  __m128i sumprod_32x4 = _mm_madd_epi16(sumprod_16x8, _mm_set1_epi16(1));
-  // accumlator += sumprod
-  return _mm_add_epi32(dotprod, sumprod_32x4);
+  return _mm_madd_epi16(sumprod_16x8, _mm_set1_epi16(1));
 }

 // Horizontally add 4 int32 values stored in a single XMM register to int32_t.
@@ -76,8 +71,9 @@ void SseMatrixBatchVectorMultiplyAccumulate(
           _mm_loadu_si128(reinterpret_cast<const __m128i*>(vectors + col));
       const __m128i row_8x16 =
           _mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col));
-      dotprod_32x4 = MatrixBatchVectorMultiplyAccumulateLoopBodySse(
-          dotprod_32x4, vec_8x16, row_8x16);
+      // dotprod += vec · row
+      dotprod_32x4 =
+          _mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16));
     }  // for col
     // Horizontally add the 4 intermediate sum values to get the final
     // dot-prod value for this row.
@@ -113,8 +109,9 @@ void SseMatrixBatchVectorMultiplyAccumulate(
           _mm_loadu_si128(reinterpret_cast<const __m128i*>(vectors + col));
       const __m128i row_8x16 =
           _mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col));
-      dotprod_32x4 = MatrixBatchVectorMultiplyAccumulateLoopBodySse(
-          dotprod_32x4, vec_8x16, row_8x16);
+      // dotprod += vec · row
+      dotprod_32x4 =
+          _mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16));

       // Pairwise add 16x 8-bit values; equivalently, multipy-add with 1.
       // Result is 8x 16-bit values.
@@ -161,8 +158,9 @@ void SseSparseMatrixBatchVectorMultiplyAccumulate(
           reinterpret_cast<const __m128i*>(vectors + col_index));
       const __m128i row_8x16 =
           _mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr));
-      dotprod_32x4 = MatrixBatchVectorMultiplyAccumulateLoopBodySse(
-          dotprod_32x4, vec_8x16, row_8x16);
+      // dotprod += vec · row
+      dotprod_32x4 =
+          _mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16));
       row_ptr += kBlockSize;
     }  // for col
     // Horizontally add the 4 intermediate sum values to get the final
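
The refactored helper can be exercised outside TFLite with a small standalone harness. The sketch below is not part of the patch; the main() driver and the test values are illustrative only. It reproduces DotProdInt8x4x4 and checks its four 4-element int8 dot products against a plain scalar loop, assuming an SSSE3-capable x86 target (build with e.g. -mssse3).

// Standalone sketch of the int8x4x4 dot-product trick used in the patch.
#include <immintrin.h>  // SSSE3/SSE2 intrinsics
#include <cstdint>
#include <cstdio>

// Same technique as the patched helper: transfer the sign of 'a' onto 'b' so
// that _mm_maddubs_epi16 (which treats its first operand as unsigned) still
// produces the signed products, then reduce the 16-bit pairs to four int32
// sums with _mm_madd_epi16.
static inline __m128i DotProdInt8x4x4(__m128i a_8x16, __m128i b_8x16) {
  b_8x16 = _mm_sign_epi8(b_8x16, a_8x16);
  a_8x16 = _mm_abs_epi8(a_8x16);
  __m128i sumprod_16x8 = _mm_maddubs_epi16(a_8x16, b_8x16);
  return _mm_madd_epi16(sumprod_16x8, _mm_set1_epi16(1));
}

int main() {
  // Four blocks of four int8 values per operand (illustrative test data).
  const int8_t a[16] = {1, -2, 3, -4, 5,  6,  -7, 8,
                        -9, 10, 11, -12, 13, -14, 15, 16};
  const int8_t b[16] = {2,  3,  -4, 5,  -6, 7,  8,  -9,
                        10, -11, 12, 13, -14, 15, -16, 17};

  const __m128i va = _mm_loadu_si128(reinterpret_cast<const __m128i*>(a));
  const __m128i vb = _mm_loadu_si128(reinterpret_cast<const __m128i*>(b));
  int32_t simd[4];
  _mm_storeu_si128(reinterpret_cast<__m128i*>(simd), DotProdInt8x4x4(va, vb));

  // Scalar reference: dot product of each 4-element block.
  for (int i = 0; i < 4; ++i) {
    int32_t ref = 0;
    for (int j = 0; j < 4; ++j) ref += a[4 * i + j] * b[4 * i + j];
    std::printf("block %d: simd=%d scalar=%d\n", i, simd[i], ref);
  }
  return 0;
}

The sign transfer is the key design point: it lets _mm_maddubs_epi16, whose first operand is unsigned, be used for a signed-by-signed product without first widening the int8 inputs to 16 bits.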