From 2abdad6ab165c8dbc459d9e2ba645127e8adc56e Mon Sep 17 00:00:00 2001 From: David Rim Date: Thu, 18 Feb 2021 20:25:36 -0800 Subject: [PATCH] Disable padded MatrixBatchVectorMultiply with sdot PiperOrigin-RevId: 358324561 Change-Id: I2ba23bf11c7b200e49cee1cdff096c3521f12e51 --- .../kernels/internal/optimized/neon_tensor_utils.cc | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc index 87b4d8d80f1..386ef4e0433 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -60,6 +60,10 @@ limitations under the License. #define TFLITE_UNLIKELY(x) (x) #endif +// TODO(b/180650471): Add back padded version of +// MatrixBatchVectorMultiplyAccumulate with sdot instruction. +#define ENABLE_PADDED_DOT_PROD false + namespace tflite { namespace tensor_utils { namespace { @@ -68,7 +72,6 @@ constexpr int kFloatValuesPerNeonVector = 4; constexpr int kInt16ValuesPerNeonVector = 8; constexpr int kInt8ValuesPerNeonVector = 16; constexpr int kNeonVectorAlignment = 4; - template <int PerNeonSize> inline int RoundDownVectors(int size) { return size & ~(PerNeonSize - 1); @@ -1054,7 +1057,8 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix, DotprodMatrixBatchFourVectorMultiplyAccumulate( matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result); return; - } else if (n_batch >= 2 && m_rows * m_cols >= 128 * 128) { + } else if (ENABLE_PADDED_DOT_PROD && n_batch >= 2 && + m_rows * m_cols >= 128 * 128) { DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate( matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result); return; @@ -1252,7 +1256,8 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl( matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result, per_channel_scale, input_offset, row_sums); return; - } else 
if (n_batch >= 2 && m_rows * m_cols >= 128 * 128) { + } else if (ENABLE_PADDED_DOT_PROD && n_batch >= 2 && + m_rows * m_cols >= 128 * 128) { DotprodMatrixBatchPaddedFourVectorMultiplyAccumulate( matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result, per_channel_scale, input_offset, row_sums);