Add x86 vector code using SSE 4.1 for MatrixBatchVectorMultiplyAccumulate and SparseMatrixBatchVectorMultiplyAccumulate, int8 versions only.
PiperOrigin-RevId: 255269233
This commit is contained in:
parent 7d02afd419
commit 869de82d13
tensorflow/lite/kernels/internal/BUILD

@@ -505,6 +505,25 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "sse_tensor_utils",
+    srcs = [
+        "compatibility.h",
+        "optimized/sse_tensor_utils.cc",
+    ],
+    hdrs = [
+        "optimized/sse_tensor_utils.h",
+        "optimized/sse_tensor_utils_impl.h",
+    ],
+    deps = [
+        ":cpu_check",
+        ":neon_tensor_utils",
+        ":portable_tensor_utils",
+        "//tensorflow/lite/c:c_api_internal",
+        "//tensorflow/lite/kernels:op_macros",
+    ],
+)
+
 cc_library(
     name = "kernel_utils",
     srcs = ["kernel_utils.cc"],
@@ -572,7 +591,7 @@ cc_library(
             ":neon_tensor_utils",
         ],
         ":haswell": [
-            ":neon_tensor_utils",
+            ":sse_tensor_utils",
         ],
         ":ios_armv7": [
             ":neon_tensor_utils",
@@ -581,25 +600,25 @@ cc_library(
             ":neon_tensor_utils",
         ],
         ":ios_x86_64": [
-            ":neon_tensor_utils",
+            ":sse_tensor_utils",
        ],
         ":x86_64": [
-            ":neon_tensor_utils",
+            ":sse_tensor_utils",
         ],
         ":x86": [
-            ":neon_tensor_utils",
+            ":sse_tensor_utils",
         ],
         ":k8": [
-            ":neon_tensor_utils",
+            ":sse_tensor_utils",
         ],
         ":darwin": [
             ":neon_tensor_utils",
         ],
         ":darwin_x86_64": [
-            ":neon_tensor_utils",
+            ":sse_tensor_utils",
         ],
         ":freebsd": [
-            ":neon_tensor_utils",
+            ":sse_tensor_utils",
         ],
         "//conditions:default": [
             ":portable_tensor_utils",
@@ -808,6 +827,7 @@ cc_library(
     hdrs = [
         "optimized/cpu_check.h",
         "optimized/neon_check.h",
+        "optimized/sse_check.h",
     ],
     deps = [
         "//tensorflow/lite/kernels:cpu_backend_context",
tensorflow/lite/kernels/internal/optimized/sse_check.h (new file, 34 lines)

@@ -0,0 +1,34 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_CHECK_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_CHECK_H_

#if defined(__SSE4_1__)
// SSE 4.1 available: Use the SSE code.
#define SSE_OR_PORTABLE(funcname, ...) Sse##funcname(__VA_ARGS__)

#else

#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"

// No SSE 4.1 available: Fall back to NEON_OR_PORTABLE, potentially used with
// NEON_2_SSE translator library. As the library requires SSSE3, the fallback is
// generally using Portable code, only a narrow subset of processors supporting
// SSSE3 but no SSE4.1 is affected - but that includes the android_x86 ABI (not
// android_x86_64).
#define SSE_OR_PORTABLE(...) NEON_OR_PORTABLE(__VA_ARGS__)
#endif

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_CHECK_H_
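For illustration only (not part of this commit): a minimal, self-contained sketch of how the SSE_OR_PORTABLE dispatch above is intended to be used by the wrapper headers. The ScaleVector functions are hypothetical placeholders, and the non-SSE branch forwards to a portable stand-in rather than NEON_OR_PORTABLE so the sketch compiles on its own; build with -msse4.1 (or equivalent) to exercise the Sse path.

#include <cstdio>

#if defined(__SSE4_1__)
// Same token-pasting dispatch as sse_check.h: SSE_OR_PORTABLE(Foo, args...)
// expands to SseFoo(args...).
#define SSE_OR_PORTABLE(funcname, ...) Sse##funcname(__VA_ARGS__)
#else
// Stand-in for the NEON_OR_PORTABLE fallback, kept portable for this sketch.
#define SSE_OR_PORTABLE(funcname, ...) Portable##funcname(__VA_ARGS__)
#endif

// Hypothetical kernels; in TF Lite these would be the Sse/Portable versions of
// e.g. MatrixBatchVectorMultiplyAccumulate.
void SseScaleVector(const float* in, int n, float s, float* out) {
  for (int i = 0; i < n; ++i) out[i] = in[i] * s;
}
void PortableScaleVector(const float* in, int n, float s, float* out) {
  for (int i = 0; i < n; ++i) out[i] = in[i] * s;
}

int main() {
  const float in[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  float out[4];
  // Resolved at preprocessing time to SseScaleVector or PortableScaleVector.
  SSE_OR_PORTABLE(ScaleVector, in, 4, 0.5f, out);
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);
  return 0;
}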
tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc (new file, 155 lines)

@@ -0,0 +1,155 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/optimized/sse_tensor_utils_impl.h"

#ifdef __SSE4_1__

#include <emmintrin.h>  // SSE2
#include <smmintrin.h>  // SSE4.1
#include <tmmintrin.h>  // SSSE3

#include "tensorflow/lite/kernels/internal/compatibility.h"

namespace tflite {
namespace tensor_utils {
namespace {

// Elementwise multiply two i8x8 vectors to i16x8, add elements pairwise and
// accumulate result to a i32x4 accumulator.
//
// Shared by the inner loop of MatrixBatchVectorMultiplyAccumulate(int8) and
// SparseMatrixBatchVectorMultiplyAccumulate(int8).
//
// x86 SSE has no i8*i8 instruction (only a u8*i8), so we need to do sign
// extension to 16 bit and do i16*i16 multiplications. There is an instruction
// to sign-extend i8x8 => i16x8 from the lower half of the register (used here),
// but there is no direct way to sign-extend the high half, only multiple
// instructions (see _mm_cmpgt_epi8 and _mm_unpackhi_epi8). Bottom line is, it
// is actually cheaper to only process 8 elements = 64b at a time.
static inline __m128i MatrixBatchVectorMultiplyAccumulateLoopBodySse(
    __m128i dotprod, __m128i a_8x8, __m128i b_8x8) {
  // Sign extend i8 => i16
  __m128i a_16x8 = _mm_cvtepi8_epi16(a_8x8);  // SSE4.1
  __m128i b_16x8 = _mm_cvtepi8_epi16(b_8x8);  // SSE4.1
  // sumprod[i] = a[2*i]*b[2*i] + a[2*i+1]*b[2*i+1] (i = 0..3)
  __m128i sumprod_32x4 = _mm_madd_epi16(a_16x8, b_16x8);  // SSE2
  // i32x4 + i32x4
  return _mm_add_epi32(dotprod, sumprod_32x4);  // SSE2
}

// Horizontally add 4 int32 values stored in a single XMM register to int32_t.
static inline int32_t ReduceInt32x4(__m128i acc) {
  acc = _mm_hadd_epi32(acc, acc);  // SSSE3
  // This second hadd could be only 64 bit, but 64 and 128 bit hadd has same
  // latency on most CPUs, and it costs more to move. (Moving can be no-op, but
  // nevertheless is an extra instruction occupying the decoder and I cache.)
  acc = _mm_hadd_epi32(acc, acc);  // SSSE3
  // SSE4.1 intrinsic, but actually translated to SSE2 instruction (due to
  // moving from 0th element).
  return _mm_extract_epi32(acc, 0);
}

}  // namespace

void SseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors, const float* scaling_factors,
    int n_batch, float* __restrict__ result, int result_stride) {
  static constexpr int kBlockSize = 8;
  for (int batch = 0; batch < n_batch; ++batch) {
    const float batch_scaling_factor = scaling_factors[batch];
    // Compute dot-product for every column.
    for (int row = 0; row < m_rows; ++row, result += result_stride) {
      // Get the address of the first element of the row.
      const int8_t* row_ptr = matrix + row * m_cols;

      // Initialize the dot product sum for the row to 0.
      __m128i dotprod_32x4 = _mm_setzero_si128();  // SSE2
      // For every block of kBlockSize 8-bit elements.
      int col = 0;
      for (; col < (m_cols & ~(kBlockSize - 1)); col += kBlockSize) {
        // See comment at MatrixBatchVectorMultiplyAccumulateLoopBodySse why to
        // load only 64 bits. _mm_loadl_epi64 requires SSE2.
        const __m128i vec_8x8 =
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(vectors + col));
        const __m128i row_8x8 =
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr + col));
        dotprod_32x4 = MatrixBatchVectorMultiplyAccumulateLoopBodySse(
            dotprod_32x4, vec_8x8, row_8x8);
      }  // for col
      // Horizontally add the 4 intermediate sum values to get the final
      // dot-prod value for this row.
      int32_t sum = ReduceInt32x4(dotprod_32x4);

      // Postamble loop.
      for (; col < m_cols; ++col) {
        sum += row_ptr[col] * vectors[col];
      }  // for col

      *result += sum * batch_scaling_factor;
    }  // for row

    vectors += m_cols;
  }  // for batch
}

void SseSparseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
    const int m_cols, const int8_t* __restrict__ vectors,
    const float* scaling_factors, int n_batch, float* __restrict__ result,
    int result_stride) {
  static const int kBlockSize = 16;
  TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);

  for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
    const float batch_scaling_factor = scaling_factors[batch];
    const uint8_t* ledger_ptr = ledger;
    const int8_t* row_ptr = matrix;
    for (int row = 0; row < m_rows; ++row, result += result_stride) {
      // Initialize the dot product sum for the row to 0.
      __m128i dotprod_32x4 = _mm_setzero_si128();
      int num_nonzero_blocks = *ledger_ptr++;
      for (int i = 0; i < num_nonzero_blocks; i++) {
        const int col_index = *ledger_ptr++ * kBlockSize;
        // With sparse models, we assume the block size is 16, we can't change
        // it to 8 here to better fit SSE (see dense version). Instead, do the
        // int8x8_t computation twice.
        __m128i vec_8x8 = _mm_loadl_epi64(
            reinterpret_cast<const __m128i*>(vectors + col_index));
        __m128i row_8x8 =
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr));
        dotprod_32x4 = MatrixBatchVectorMultiplyAccumulateLoopBodySse(
            dotprod_32x4, vec_8x8, row_8x8);
        vec_8x8 = _mm_loadl_epi64(
            reinterpret_cast<const __m128i*>(vectors + col_index + 8));
        row_8x8 =
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr + 8));
        dotprod_32x4 = MatrixBatchVectorMultiplyAccumulateLoopBodySse(
            dotprod_32x4, vec_8x8, row_8x8);
        row_ptr += kBlockSize;
      }
      // Horizontally add the 4 intermediate sum values to get the final
      // dot-prod value for this row.
      int32_t dotprod = ReduceInt32x4(dotprod_32x4);

      *result += dotprod * batch_scaling_factor;
    }  // for row
  }  // for batch
}

}  // namespace tensor_utils
}  // namespace tflite

#endif  // __SSE4_1__
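As a standalone illustration (not part of the commit) of the arithmetic in the loop body and reduction above, the sketch below computes a single 8-element int8 dot product with the same intrinsic sequence (sign-extend with _mm_cvtepi8_epi16, pairwise multiply-add with _mm_madd_epi16, then two _mm_hadd_epi32 steps and _mm_extract_epi32) and checks it against a plain scalar loop. It assumes SSE 4.1 is enabled at compile time, e.g. with -msse4.1.

#include <smmintrin.h>  // SSE4.1 (also provides the SSE2/SSSE3 intrinsics used)

#include <cstdint>
#include <cstdio>

int main() {
  const int8_t a[8] = {1, -2, 3, -4, 5, -6, 7, -8};
  const int8_t b[8] = {8, 7, -6, 5, -4, 3, 2, -1};

  // Load 8 x int8 (64 bits) into the low half of an XMM register.
  const __m128i a_8x8 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(a));
  const __m128i b_8x8 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(b));
  // Sign-extend the low 8 lanes from int8 to int16.
  const __m128i a_16x8 = _mm_cvtepi8_epi16(a_8x8);
  const __m128i b_16x8 = _mm_cvtepi8_epi16(b_8x8);
  // sumprod[i] = a[2i]*b[2i] + a[2i+1]*b[2i+1], giving four int32 lanes.
  __m128i acc = _mm_madd_epi16(a_16x8, b_16x8);
  // Horizontal reduction of the four int32 lanes, as in ReduceInt32x4.
  acc = _mm_hadd_epi32(acc, acc);
  acc = _mm_hadd_epi32(acc, acc);
  const int32_t simd_sum = _mm_extract_epi32(acc, 0);

  // Scalar reference for comparison.
  int32_t scalar_sum = 0;
  for (int i = 0; i < 8; ++i) scalar_sum += a[i] * b[i];

  std::printf("simd=%d scalar=%d\n", simd_sum, scalar_sum);  // prints identical values
  return 0;
}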
tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h (new file, 186 lines)

@@ -0,0 +1,186 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_TENSOR_UTILS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_TENSOR_UTILS_H_

// Note: This file is a copy-paste version of neon_tensor_utils.h, the only
// difference is in MatrixBatchVectorMultiplyAccumulate and
// SparseMatrixBatchVectorMultiplyAccumulate (other functions do not have SSE
// implementation yet).

// Note: Most of the functions below use NEON_OR_PORTABLE, through the Intel
// NEON_2_SSE translator library. If a native SSE version of a function is
// implemented, replace the appropriate one with SSE_OR_PORTABLE.

// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h"
#include "tensorflow/lite/kernels/internal/optimized/sse_check.h"
#include "tensorflow/lite/kernels/internal/optimized/sse_tensor_utils_impl.h"
#include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h"

namespace tflite {
namespace tensor_utils {

void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
                                         int m_cols, const float* vector,
                                         int n_batch, float* result,
                                         int result_stride) {
  NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
                   vector, n_batch, result, result_stride);
}

void MatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors, const float* scaling_factors,
    int n_batch, float* __restrict__ result, int result_stride) {
  SSE_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
                  vectors, scaling_factors, n_batch, result, result_stride);
}

void SparseMatrixBatchVectorMultiplyAccumulate(
    const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
    int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
    float* __restrict__ result, int result_stride) {
  NEON_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate, matrix, ledger,
                   m_rows, m_cols, vector, n_batch, result, result_stride);
}

void SparseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
    const int m_cols, const int8_t* __restrict__ vectors,
    const float* scaling_factors, int n_batch, float* __restrict__ result,
    int result_stride) {
  SSE_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate, matrix, ledger,
                  m_rows, m_cols, vectors, scaling_factors, n_batch, result,
                  result_stride);
}

void VectorVectorCwiseProduct(const float* vector1, const float* vector2,
                              int v_size, float* result) {
  NEON_OR_PORTABLE(VectorVectorCwiseProduct, vector1, vector2, v_size, result);
}

void VectorVectorCwiseProductAccumulate(const float* vector1,
                                        const float* vector2, int v_size,
                                        float* result) {
  NEON_OR_PORTABLE(VectorVectorCwiseProductAccumulate, vector1, vector2, v_size,
                   result);
}

void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
                                   const float* batch_vector, int n_batch,
                                   float* result) {
  NEON_OR_PORTABLE(VectorBatchVectorCwiseProduct, vector, v_size, batch_vector,
                   n_batch, result);
}

void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
                                             const float* batch_vector,
                                             int n_batch, float* result) {
  NEON_OR_PORTABLE(VectorBatchVectorCwiseProductAccumulate, vector, v_size,
                   batch_vector, n_batch, result);
}

float VectorVectorDotProduct(const float* vector1, const float* vector2,
                             int v_size) {
  return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size);
}

void BatchVectorBatchVectorDotProduct(const float* vector1,
                                      const float* vector2, int v_size,
                                      int n_batch, float* result,
                                      int result_stride) {
  NEON_OR_PORTABLE(BatchVectorBatchVectorDotProduct, vector1, vector2, v_size,
                   n_batch, result, result_stride);
}

void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
                          float* batch_vector) {
  PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
}

void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
                             float* batch_vector) {
  PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector);
}

void ApplySigmoidToVector(const float* vector, int v_size, float* result) {
  PortableApplySigmoidToVector(vector, v_size, result);
}

void ApplyActivationToVector(const float* vector, int v_size,
                             TfLiteFusedActivation activation, float* result) {
  PortableApplyActivationToVector(vector, v_size, activation, result);
}

void CopyVector(const float* vector, int v_size, float* result) {
  PortableCopyVector(vector, v_size, result);
}

void Sub1Vector(const float* vector, int v_size, float* result) {
  NEON_OR_PORTABLE(Sub1Vector, vector, v_size, result);
}

void ZeroVector(float* vector, int v_size) {
  PortableZeroVector(vector, v_size);
}

float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); }

// Check if all entries of a vector are zero.
bool IsZeroVector(const float* vector, int v_size) {
  return NEON_OR_PORTABLE(IsZeroVector, vector, v_size);
}

void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
                          float* result) {
  NEON_OR_PORTABLE(VectorScalarMultiply, vector, v_size, scale, result);
}
void ClipVector(const float* vector, int v_size, float abs_limit,
                float* result) {
  NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result);
}

void SymmetricQuantizeFloats(const float* values, const int size,
                             int8_t* quantized_values, float* min_value,
                             float* max_value, float* scaling_factor) {
  NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values,
                   min_value, max_value, scaling_factor);
}

void VectorShiftLeft(float* vector, int v_size, float shift_value) {
  NEON_OR_PORTABLE(VectorShiftLeft, vector, v_size, shift_value);
}

void ReductionSumVector(const float* input_vector, float* output_vector,
                        int output_size, int reduction_size) {
  NEON_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size,
                   reduction_size);
}

void MeanStddevNormalization(const float* input_vector, float* output_vector,
                             int v_size, int n_batch,
                             float normalization_epsilon) {
  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
                                  normalization_epsilon);
}

}  // namespace tensor_utils
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_TENSOR_UTILS_H_
tensorflow/lite/kernels/internal/optimized/sse_tensor_utils_impl.h (new file, 48 lines)

@@ -0,0 +1,48 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_TENSOR_UTILS_IMPL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_TENSOR_UTILS_IMPL_H_

#include <cstdint>

#if defined(_MSC_VER)
#define __restrict__ __restrict
#endif

namespace tflite {
namespace tensor_utils {

#ifdef __SSE4_1__

// Matrix multiplication for quantized values using symmetric quantization.
void SseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
    const int8_t* __restrict__ vectors, const float* scaling_factors,
    int n_batch, float* __restrict__ result, int result_stride);

// Matrix multiplication for quantized values using symmetric quantization.
// Sparse version.
void SseSparseMatrixBatchVectorMultiplyAccumulate(
    const int8_t* __restrict__ matrix, const uint8_t* ledger, const int m_rows,
    const int m_cols, const int8_t* __restrict__ vectors,
    const float* scaling_factors, int n_batch, float* __restrict__ result,
    int result_stride);

#endif  // __SSE4_1__

}  // namespace tensor_utils
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SSE_TENSOR_UTILS_IMPL_H_
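For reference (not part of the commit), here is a hypothetical scalar routine with the same signature as SseMatrixBatchVectorMultiplyAccumulate, spelling out the contract implied by the implementation above: matrix is a row-major m_rows x m_cols int8 matrix, vectors holds n_batch contiguous int8 vectors of length m_cols, each row/batch dot product is scaled by that batch's scaling factor, and the scaled value is accumulated into result, which advances by result_stride per row.

#include <cstdint>
#include <cstdio>

// Hypothetical scalar reference used only to illustrate the data layout and
// accumulation semantics; it is not the TF Lite implementation.
void ScalarMatrixBatchVectorMultiplyAccumulate(
    const int8_t* matrix, int m_rows, int m_cols, const int8_t* vectors,
    const float* scaling_factors, int n_batch, float* result,
    int result_stride) {
  for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
    for (int row = 0; row < m_rows; ++row, result += result_stride) {
      int32_t dotprod = 0;
      for (int col = 0; col < m_cols; ++col) {
        dotprod += matrix[row * m_cols + col] * vectors[col];
      }
      *result += dotprod * scaling_factors[batch];  // accumulate, not overwrite
    }
  }
}

int main() {
  // 2x4 row-major int8 matrix, one batch with a single length-4 int8 vector.
  const int8_t matrix[8] = {1, 2, 3, 4, -1, -2, -3, -4};
  const int8_t vector[4] = {1, 1, 1, 1};
  const float scaling_factors[1] = {0.5f};
  float result[2] = {0.0f, 0.0f};
  ScalarMatrixBatchVectorMultiplyAccumulate(matrix, 2, 4, vector,
                                            scaling_factors, 1, result,
                                            /*result_stride=*/1);
  std::printf("%g %g\n", result[0], result[1]);  // 5 -5
  return 0;
}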
tensorflow/lite/kernels/internal/tensor_utils.cc

@@ -14,10 +14,12 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
 
-#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
+#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
 
-#ifdef USE_NEON
+#if defined(__SSE4_1__)
+#include "tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h"
+#elif defined(USE_NEON)
 #include "tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h"
 #else
 #include "tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h"
-#endif  // USE_NEON
+#endif  // __SSE4_1__ or USE_NEON