Add x86 vector code using SSE 4.1 for MatrixBatchVectorMultiplyAccumulate and SparseMatrixBatchVectorMultiplyAccumulate, int8 versions only.

PiperOrigin-RevId: 255269233
This commit is contained in:
A. Unique TensorFlower 2019-06-26 14:49:46 -07:00 committed by TensorFlower Gardener
parent 7d02afd419
commit 869de82d13
6 changed files with 455 additions and 10 deletions

View File

@ -505,6 +505,25 @@ cc_library(
],
)
cc_library(
name = "sse_tensor_utils",
srcs = [
"compatibility.h",
"optimized/sse_tensor_utils.cc",
],
hdrs = [
"optimized/sse_tensor_utils.h",
"optimized/sse_tensor_utils_impl.h",
],
deps = [
":cpu_check",
":neon_tensor_utils",
":portable_tensor_utils",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:op_macros",
],
)
cc_library(
name = "kernel_utils",
srcs = ["kernel_utils.cc"],
@ -572,7 +591,7 @@ cc_library(
":neon_tensor_utils",
],
":haswell": [
":neon_tensor_utils",
":sse_tensor_utils",
],
":ios_armv7": [
":neon_tensor_utils",
@ -581,25 +600,25 @@ cc_library(
":neon_tensor_utils",
],
":ios_x86_64": [
":neon_tensor_utils",
":sse_tensor_utils",
],
":x86_64": [
":neon_tensor_utils",
":sse_tensor_utils",
],
":x86": [
":neon_tensor_utils",
":sse_tensor_utils",
],
":k8": [
":neon_tensor_utils",
":sse_tensor_utils",
],
":darwin": [
":neon_tensor_utils",
],
":darwin_x86_64": [
":neon_tensor_utils",
":sse_tensor_utils",
],
":freebsd": [
":neon_tensor_utils",
":sse_tensor_utils",
],
"//conditions:default": [
":portable_tensor_utils",
@ -808,6 +827,7 @@ cc_library(
hdrs = [
"optimized/cpu_check.h",
"optimized/neon_check.h",
"optimized/sse_check.h",
],
deps = [
"//tensorflow/lite/kernels:cpu_backend_context",

View File

@ -0,0 +1,34 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_CHECK_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_CHECK_H_
#if defined(__SSE4_1__)
// SSE 4.1 available: Use the SSE code.
#define SSE_OR_PORTABLE(funcname, ...) Sse##funcname(__VA_ARGS__)
#else
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
// No SSE 4.1 available: Fall back to NEON_OR_PORTABLE, potentially used with
// NEON_2_SSE translator library. As the library requires SSSE3, the fallback is
// generally using Portable code, only a narrow subset of processors supporting
// SSSE3 but no SSE4.1 is affected - but that includes the android_x86 ABI (not
// android_x86_64).
#define SSE_OR_PORTABLE(...) NEON_OR_PORTABLE(__VA_ARGS__)
#endif
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_CHECK_H_

View File

@ -0,0 +1,155 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/optimized/sse_tensor_utils_impl.h"
#ifdef __SSE4_1__
#include <emmintrin.h> // SSE2
#include <smmintrin.h> // SSE4.1
#include <tmmintrin.h> // SSSE3
#include "tensorflow/lite/kernels/internal/compatibility.h"
namespace tflite {
namespace tensor_utils {
namespace {
// Elementwise multiply two i8x8 vectors to i16x8, add elements pairwise and
// accumulate result to a i32x4 accumulator.
//
// Shared by the inner loop of MatrixBatchVectorMultiplyAccumulate(int8) and
// SparseMatrixBatchVectorMultiplyAccumulate(int8).
//
// x86 SSE has no i8*i8 instruction (only a u8*i8), so we need to do sign
// extension to 16 bit and do i16*i16 multiplications. There is an instruction
// to sign-extend i8x8 => i16x8 from the lower half of the register (used here),
// but there is no direct way to sign-extend the high half, only multiple
// instructions (see _mm_cmpgt_epi8 and _mm_unpackhi_epi8). Bottom line is, it
// is actually cheaper to only to process 8 elements = 64b at a time.
static inline __m128i MatrixBatchVectorMultiplyAccumulateLoopBodySse(
__m128i dotprod, __m128i a_8x8, __m128i b_8x8) {
// Sign extend i8 => i16
__m128i a_16x8 = _mm_cvtepi8_epi16(a_8x8); // SSE4.1
__m128i b_16x8 = _mm_cvtepi8_epi16(b_8x8); // SSE4.1
// sumprod[i] = a[2*i]*b[2*i] + a[2*i+1]*b[2*i+1] (i = 0..3)
__m128i sumprod_32x4 = _mm_madd_epi16(a_16x8, b_16x8); // SSE2
// i32x4 + i32x4
return _mm_add_epi32(dotprod, sumprod_32x4); // SSE2
}
// Horizontally add 4 int32 values stored in a single XMM register to int32_t.
static inline int32_t ReduceInt32x4(__m128i acc) {
acc = _mm_hadd_epi32(acc, acc); // SSSE3
// This second hadd could be only 64 bit, but 64 and 128 bit hadd has same
// latency on most CPUs, and it costs more to move. (Moving can be no-op, but
// nevertheless is an extra instruction occupying the decoder and I cache.)
acc = _mm_hadd_epi32(acc, acc); // SSSE3
// SSE4.1 instrinsic, but actually translated to SSE2 instruction (due to
// moving from 0th element).
return _mm_extract_epi32(acc, 0);
}
} // namespace
void SseMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, float* __restrict__ result, int result_stride) {
static constexpr int kBlockSize = 8;
for (int batch = 0; batch < n_batch; ++batch) {
const float batch_scaling_factor = scaling_factors[batch];
// Compute dot-product for every column.
for (int row = 0; row < m_rows; ++row, result += result_stride) {
// Get the address of the first element of the row.
const int8_t* row_ptr = matrix + row * m_cols;
// Initialize the dot product sum for the row to 0.
__m128i dotprod_32x4 = _mm_setzero_si128(); // SSE2
// For every block of kBlockSize 8-bit elements.
int col = 0;
for (; col < (m_cols & ~(kBlockSize - 1)); col += kBlockSize) {
// See comment at MatrixBatchVectorMultiplyAccumulateLoopBodySse why to
// load only 64 bits. _mm_loadl_epi64 requires SSE2.
const __m128i vec_8x8 =
_mm_loadl_epi64(reinterpret_cast<const __m128i*>(vectors + col));
const __m128i row_8x8 =
_mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr + col));
dotprod_32x4 = MatrixBatchVectorMultiplyAccumulateLoopBodySse(
dotprod_32x4, vec_8x8, row_8x8);
} // for col
// Horizontally add the 4 intermediate sum values to get the final
// dot-prod value for this row.
int32_t sum = ReduceInt32x4(dotprod_32x4);
// Postamble loop.
for (; col < m_cols; ++col) {
sum += row_ptr[col] * vectors[col];
} // for col
*result += sum * batch_scaling_factor;
} // for row
vectors += m_cols;
} // for batch
}
void SseSparseMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
const int m_cols, const int8_t* __restrict__ vectors,
const float* scaling_factors, int n_batch, float* __restrict__ result,
int result_stride) {
static const int kBlockSize = 16;
TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
const float batch_scaling_factor = scaling_factors[batch];
const uint8_t* ledger_ptr = ledger;
const int8_t* row_ptr = matrix;
for (int row = 0; row < m_rows; ++row, result += result_stride) {
// Initialize the dot product sum for the row to 0.
__m128i dotprod_32x4 = _mm_setzero_si128();
int num_nonzero_blocks = *ledger_ptr++;
for (int i = 0; i < num_nonzero_blocks; i++) {
const int col_index = *ledger_ptr++ * kBlockSize;
// With sparse models, we assume the block size is 16, we can't change
// it to 8 here to better fit SSE (see dense version). Instead, do the
// int8x8_t computation twice.
__m128i vec_8x8 = _mm_loadl_epi64(
reinterpret_cast<const __m128i*>(vectors + col_index));
__m128i row_8x8 =
_mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr));
dotprod_32x4 = MatrixBatchVectorMultiplyAccumulateLoopBodySse(
dotprod_32x4, vec_8x8, row_8x8);
vec_8x8 = _mm_loadl_epi64(
reinterpret_cast<const __m128i*>(vectors + col_index + 8));
row_8x8 =
_mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr + 8));
dotprod_32x4 = MatrixBatchVectorMultiplyAccumulateLoopBodySse(
dotprod_32x4, vec_8x8, row_8x8);
row_ptr += kBlockSize;
}
// Horizontally add the 4 intermediate sum values to get the final
// dot-prod value for this row.
int32_t dotprod = ReduceInt32x4(dotprod_32x4);
*result += dotprod * batch_scaling_factor;
} // for row
} // for batch
}
} // namespace tensor_utils
} // namespace tflite
#endif // __SSE4_1__

View File

@ -0,0 +1,186 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_TENSOR_UTILS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_TENSOR_UTILS_H_
// Note: This file is a copy-paste version of neon_tensor_utils.h, only
// difference is in MatrixBatchVectorMultiplyAccumulate and
// SparseMatrixBatchVectorMultiplyAccumulate (other functions do not have SSE
// implementation yet).
// Note: Most of the functions below use NEON_OR_PORTABLE, through the Intel
// NEON_2_SSE translator library. If a native SSE version of a function is
// implemented, replace the appropriate one to SSE_OR_PORTABLE.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h"
#include "tensorflow/lite/kernels/internal/optimized/sse_check.h"
#include "tensorflow/lite/kernels/internal/optimized/sse_tensor_utils_impl.h"
#include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h"
namespace tflite {
namespace tensor_utils {
void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
int m_cols, const float* vector,
int n_batch, float* result,
int result_stride) {
NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
vector, n_batch, result, result_stride);
}
void MatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, float* __restrict__ result, int result_stride) {
SSE_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
vectors, scaling_factors, n_batch, result, result_stride);
}
void SparseMatrixBatchVectorMultiplyAccumulate(
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
float* __restrict__ result, int result_stride) {
NEON_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate, matrix, ledger,
m_rows, m_cols, vector, n_batch, result, result_stride);
}
void SparseMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
const int m_cols, const int8_t* __restrict__ vectors,
const float* scaling_factors, int n_batch, float* __restrict__ result,
int result_stride) {
SSE_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate, matrix, ledger,
m_rows, m_cols, vectors, scaling_factors, n_batch, result,
result_stride);
}
void VectorVectorCwiseProduct(const float* vector1, const float* vector2,
int v_size, float* result) {
NEON_OR_PORTABLE(VectorVectorCwiseProduct, vector1, vector2, v_size, result);
}
void VectorVectorCwiseProductAccumulate(const float* vector1,
const float* vector2, int v_size,
float* result) {
NEON_OR_PORTABLE(VectorVectorCwiseProductAccumulate, vector1, vector2, v_size,
result);
}
void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
const float* batch_vector, int n_batch,
float* result) {
NEON_OR_PORTABLE(VectorBatchVectorCwiseProduct, vector, v_size, batch_vector,
n_batch, result);
}
void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
const float* batch_vector,
int n_batch, float* result) {
NEON_OR_PORTABLE(VectorBatchVectorCwiseProductAccumulate, vector, v_size,
batch_vector, n_batch, result);
}
float VectorVectorDotProduct(const float* vector1, const float* vector2,
int v_size) {
return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size);
}
void BatchVectorBatchVectorDotProduct(const float* vector1,
const float* vector2, int v_size,
int n_batch, float* result,
int result_stride) {
NEON_OR_PORTABLE(BatchVectorBatchVectorDotProduct, vector1, vector2, v_size,
n_batch, result, result_stride);
}
void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
float* batch_vector) {
PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
}
void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
float* batch_vector) {
PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector);
}
void ApplySigmoidToVector(const float* vector, int v_size, float* result) {
PortableApplySigmoidToVector(vector, v_size, result);
}
void ApplyActivationToVector(const float* vector, int v_size,
TfLiteFusedActivation activation, float* result) {
PortableApplyActivationToVector(vector, v_size, activation, result);
}
void CopyVector(const float* vector, int v_size, float* result) {
PortableCopyVector(vector, v_size, result);
}
void Sub1Vector(const float* vector, int v_size, float* result) {
NEON_OR_PORTABLE(Sub1Vector, vector, v_size, result);
}
void ZeroVector(float* vector, int v_size) {
PortableZeroVector(vector, v_size);
}
float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); }
// Check if all entries of a vector are zero.
bool IsZeroVector(const float* vector, int v_size) {
return NEON_OR_PORTABLE(IsZeroVector, vector, v_size);
}
void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
float* result) {
NEON_OR_PORTABLE(VectorScalarMultiply, vector, v_size, scale, result);
}
void ClipVector(const float* vector, int v_size, float abs_limit,
float* result) {
NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result);
}
void SymmetricQuantizeFloats(const float* values, const int size,
int8_t* quantized_values, float* min_value,
float* max_value, float* scaling_factor) {
NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values,
min_value, max_value, scaling_factor);
}
void VectorShiftLeft(float* vector, int v_size, float shift_value) {
NEON_OR_PORTABLE(VectorShiftLeft, vector, v_size, shift_value);
}
void ReductionSumVector(const float* input_vector, float* output_vector,
int output_size, int reduction_size) {
NEON_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size,
reduction_size);
}
void MeanStddevNormalization(const float* input_vector, float* output_vector,
int v_size, int n_batch,
float normalization_epsilon) {
PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
normalization_epsilon);
}
} // namespace tensor_utils
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_TENSOR_UTILS_H_

View File

@ -0,0 +1,48 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_TENSOR_UTILS_IMPL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_TENSOR_UTILS_IMPL_H_
#include <cstdint>
#if defined(_MSC_VER)
#define __restrict__ __restrict
#endif
namespace tflite {
namespace tensor_utils {
#ifdef __SSE4_1__
// Matrix multiplication for quantized values using symmetric quantization.
void SseMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
const int8_t* __restrict__ vectors, const float* scaling_factors,
int n_batch, float* __restrict__ result, int result_stride);
// Matrix multiplication for quantized values using symmetric quantization.
// Sparse version.
void SseSparseMatrixBatchVectorMultiplyAccumulate(
const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
const int m_cols, const int8_t* __restrict__ vectors,
const float* scaling_factors, int n_batch, float* __restrict__ result,
int result_stride);
#endif // __SSE4_1__
} // namespace tensor_utils
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_TENSOR_UTILS_IMPL_H_

View File

@ -14,10 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#ifdef USE_NEON
#if defined(__SSE4_1__)
#include "tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h"
#elif defined(USE_NEON)
#include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h"
#else
#include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h"
#endif // USE_NEON
#endif // __SSE4_1__ or USE_NEON