Hybrid convolutions go through CpuBackendGEMM on x86

PiperOrigin-RevId: 321618575
Change-Id: I7122ab059604557eb7040ee8c6f41a1b2d709ca1
Authored by T.J. Alumbaugh on 2020-07-16 12:15:14 -07:00; committed by TensorFlower Gardener
parent c3f2d3d571
commit 446c04c268
6 changed files with 148 additions and 10 deletions
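
For context, the hybrid path touched by this commit works roughly as follows: the float activations are symmetrically quantized to int8 per batch, multiplied against the int8 filter weights in an int8 x int8 -> int32 GEMM (now routed through CpuBackendGemm on x86), and the int32 accumulators are rescaled back to float with the per-batch input scale times the filter scale. The scalar sketch below only illustrates that flow; it is not TFLite code, and the names, buffer layout, and helper are assumptions made for illustration.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Scalar sketch of the hybrid (float activations, int8 weights) matmul:
// per-batch symmetric quantization -> int8 GEMM into int32 -> rescale to
// float. Illustrative only; not a TFLite API.
void HybridMatmulReference(const std::vector<float>& input,     // [n_batch * n_input]
                           const std::vector<int8_t>& weights,  // [n_output * n_input]
                           const std::vector<float>& bias,      // [n_output]
                           float filter_scale, int n_batch, int n_input,
                           int n_output, std::vector<float>* result) {
  std::vector<int8_t> quantized(n_batch * n_input);
  std::vector<float> scaling_factors(n_batch);
  for (int b = 0; b < n_batch; ++b) {
    // Per-batch symmetric quantization: map the max magnitude to 127.
    float max_abs = 0.f;
    for (int i = 0; i < n_input; ++i) {
      max_abs = std::max(max_abs, std::abs(input[b * n_input + i]));
    }
    const float input_scale = max_abs > 0.f ? max_abs / 127.f : 1.f;
    scaling_factors[b] = input_scale * filter_scale;
    for (int i = 0; i < n_input; ++i) {
      quantized[b * n_input + i] =
          static_cast<int8_t>(std::round(input[b * n_input + i] / input_scale));
    }
  }
  result->assign(n_batch * n_output, 0.f);
  for (int b = 0; b < n_batch; ++b) {
    for (int o = 0; o < n_output; ++o) {
      // Integer accumulation (the part handed to cpu_backend_gemm on x86),
      // followed by the float rescale and bias add.
      int32_t acc = 0;
      for (int i = 0; i < n_input; ++i) {
        acc += static_cast<int32_t>(quantized[b * n_input + i]) *
               static_cast<int32_t>(weights[o * n_input + i]);
      }
      (*result)[b * n_output + o] = acc * scaling_factors[b] + bias[o];
    }
  }
}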

View File

@@ -888,13 +888,16 @@ void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
       GetTemporary(context, node, data->scaling_factors_index));
   // Per-batch input quantization for higher accuracy.
-  for (int b = 0; b < batch_size; ++b) {
-    float unused_min, unused_max;
-    const int offset = b * input_size;
-    tensor_utils::SymmetricQuantizeFloats(
-        input_ptr + offset, input_size, quantized_input_ptr_batch + offset,
-        &unused_min, &unused_max, &scaling_factors_ptr[b]);
-    scaling_factors_ptr[b] *= filter->params.scale;
-  }
+  {
+    ruy::profiler::ScopeLabel label("ConvHybridQuantizeInputs");
+    for (int b = 0; b < batch_size; ++b) {
+      float unused_min, unused_max;
+      const int offset = b * input_size;
+      tensor_utils::SymmetricQuantizeFloats(
+          input_ptr + offset, input_size, quantized_input_ptr_batch + offset,
+          &unused_min, &unused_max, &scaling_factors_ptr[b]);
+      scaling_factors_ptr[b] *= filter->params.scale;
+    }
+  }
 
   switch (kernel_type) {
@@ -902,8 +905,7 @@ void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
     case kGenericOptimized:
     case kMultithreadOptimized:
     case kCblasOptimized: {
-      // There is only one implementation for hybrid kernel. Note
-      // this does not make use of gemmlowp nor supports multithreading.
+      // There is only one implementation for hybrid kernel.
       ConvParams op_params;
       op_params.padding_type = PaddingType::kSame;
       op_params.padding_values.width = data->padding.width;
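
The first hunk above wraps the quantization loop in a ruy profiler scope so the time spent quantizing inputs is reported under its own label. A minimal sketch of the same pattern follows; the function and label names are made up, and the label only produces output when ruy's instrumentation-based profiler is compiled in.

#include "ruy/profiler/instrumentation.h"  // from @ruy

void QuantizeInputsWithProfiling(/* ...tensor arguments... */) {
  // Everything inside this block is attributed to "MyQuantizeStep" when the
  // ruy profiler is enabled; otherwise the label compiles down to (near)
  // nothing.
  {
    ruy::profiler::ScopeLabel label("MyQuantizeStep");
    // ... per-batch quantization work ...
  }
}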

View File

@@ -1295,6 +1295,47 @@ TEST_P(ConvolutionOpTest, SimpleTestHybridInt8) {
                   0.16)));
 }
 
+TEST_P(ConvolutionOpTest, SimpleTestHybridInt8Big) {
+  // A bigger variant of the simple hybrid test to ensure coverage on
+  // optimized paths that are only enabled at larger matrix sizes.
+  HybridConvolutionOpModel m(
+      GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
+      {TensorType_INT8, {8, 2, 2, 1}, 0, 0, 4.0 / 127.0, 0},
+      {TensorType_FLOAT32, {}});
+
+  m.SetInput({
+      // First batch
+      1, 1, 1, 1,  // row = 1
+      2, 2, 2, 2,  // row = 2
+      // Second batch
+      1, 2, 3, 4,  // row = 1
+      1, 2, 3, 4,  // row = 2
+  });
+  m.SetSignedFilter({
+      1, 2, 3, 4,    // first 2x2 filter
+      -1, 1, -1, 1,  // second 2x2 filter
+      -1, -1, 1, 1,  // third 2x2 filter
+      1, 1, 3, 3,    // fourth 2x2 filter
+      -1, -1, 3, 3,  // fifth 2x2 filter
+      4, 3, 2, 1,    // sixth 2x2 filter
+      2, 1, 1, 2,    // seventh 2x2 filter
+      1, -1, 2, -2,  // eighth 2x2 filter
+  });
+  m.SetBias({1, 2, 3, 4, 5, 6, 7, 8});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      18, 2, 5, 18, 15, 19, 16, 8,  // first batch, left
+                      18, 2, 5, 18, 15, 19, 16, 8,  // first batch, right
+                      17, 4, 3, 16, 11, 20, 16, 5,  // second batch, left
+                      37, 4, 3, 32, 19, 40, 28, 5   // second batch, right
+                  },
+                  0.17)));
+}
+
 // This test's output is equivalent to the SimpleTestHybrid
 // because we break each input into two channels, each with half of the value,
 // while keeping the filters for each channel equivalent.
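
As a quick cross-check of the expected values in the new test: with the model's default VALID padding and stride of 2 (an assumption based on the sibling hybrid tests), each output is a float dot product of a 2x2 input patch with a 2x2 filter plus the bias, and the 0.17 tolerance absorbs the hybrid quantization error. The snippet below works out the very first value.

#include <cassert>

int main() {
  // First batch, left 2x2 patch: row 1 = {1, 1}, row 2 = {2, 2}.
  const float patch[4] = {1, 1, 2, 2};
  // First 2x2 filter {1, 2, 3, 4} with bias 1, taken from the test above.
  const float filter[4] = {1, 2, 3, 4};
  float acc = 1.0f;  // bias
  for (int i = 0; i < 4; ++i) acc += patch[i] * filter[i];
  assert(acc == 18.0f);  // Matches the first expected output value.
  return 0;
}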

View File

@ -682,7 +682,9 @@ cc_library(
":portable_tensor_utils",
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:cpu_backend_context",
"//tensorflow/lite/kernels:cpu_backend_gemm",
"//tensorflow/lite/kernels:op_macros",
"@ruy//ruy/profiler:instrumentation",
],
)

View File

@@ -24,7 +24,10 @@ limitations under the License.
 #include <cstdint>
 
+#include "ruy/profiler/instrumentation.h"  // from @ruy
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
+#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
+#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
 
 namespace tflite {
@@ -170,6 +173,38 @@ void SseMatrixBatchVectorMultiplyAccumulateImpl(
   }  // for batch
 }
 
+void SseCpuBackendGemm(const int8_t* input, const int32_t* bias,
+                       const int8_t* input_to_gate_weights, int32_t n_batch,
+                       int32_t n_input, int32_t n_output, int32_t output_zp,
+                       int32_t* scratch, CpuBackendContext* context) {
+  using ::tflite::cpu_backend_gemm::Gemm;
+  using ::tflite::cpu_backend_gemm::GemmParams;
+  using ::tflite::cpu_backend_gemm::MatrixParams;
+
+  MatrixParams<int8_t> lhs_params;
+  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
+  lhs_params.rows = n_output;
+  lhs_params.cols = n_input;
+  lhs_params.cache_policy = cpu_backend_gemm::CachePolicy::kCacheIfLargeSpeedup;
+
+  MatrixParams<int8_t> rhs_params;
+  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
+  rhs_params.rows = n_input;
+  rhs_params.cols = n_batch;
+
+  MatrixParams<int32_t> dst_params;
+  dst_params.order = cpu_backend_gemm::Order::kColMajor;
+  dst_params.rows = n_output;
+  dst_params.cols = n_batch;
+
+  GemmParams<int32, int32> gemm_params;
+  if (bias) {
+    gemm_params.bias = bias;
+  }
+  cpu_backend_gemm::Gemm(lhs_params, input_to_gate_weights, rhs_params, input,
+                         dst_params, scratch, gemm_params, context);
+}
+
 void SseMatrixBatchVectorMultiplyAccumulate(
     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
     const int8_t* __restrict__ vectors,
@@ -181,6 +216,56 @@ void SseMatrixBatchVectorMultiplyAccumulate(
       /*row_sums=*/nullptr);
 }
 
+void SseMatrixBatchVectorMultiplyAccumulate(
+    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+    const int8_t* __restrict__ vectors,
+    const float* __restrict__ scaling_factors, int n_batch, int32_t* scratch,
+    float* __restrict__ result, CpuBackendContext* context) {
+  if (m_rows % 4 == 0) {
+    const int32_t* bias = static_cast<const int32_t*>(nullptr);
+    SseCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows,
+                      /*output_zp=*/0, scratch, context);
+
+    {
+      ruy::profiler::ScopeLabel label("HybridMultiplyScalingFactor");
+      // Multiply by float scaling factors and write to result.
+      const int total_size = n_batch * m_rows;
+      int i = 0;
+      for (; i <= total_size - 8; i += 8, result += 8) {
+        const float batch_scaling_factor0 = scaling_factors[i / m_rows];
+        const float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows];
+        const __m128 scaling_factor0 = _mm_set1_ps(batch_scaling_factor0);
+        const __m128 scaling_factor1 = _mm_set1_ps(batch_scaling_factor1);
+        const __m128i scratch_val0 =
+            _mm_loadu_si128(reinterpret_cast<const __m128i*>(scratch + i));
+        const __m128i scratch_val1 =
+            _mm_loadu_si128(reinterpret_cast<const __m128i*>(scratch + i + 4));
+        const __m128 float_val0 = _mm_cvtepi32_ps(scratch_val0);
+        const __m128 float_val1 = _mm_cvtepi32_ps(scratch_val1);
+        const __m128 prod0 = _mm_mul_ps(float_val0, scaling_factor0);
+        const __m128 result0 = _mm_add_ps(_mm_loadu_ps(result), prod0);
+        const __m128 prod1 = _mm_mul_ps(float_val1, scaling_factor1);
+        const __m128 result1 = _mm_add_ps(_mm_loadu_ps(result + 4), prod1);
+        _mm_storeu_ps(result, result0);
+        _mm_storeu_ps(result + 4, result1);
+      }
+      scratch += i;
+      for (; i < total_size; i++) {
+        const float batch_scaling_factor = scaling_factors[i / m_rows];
+        int32_t x = *(scratch++);
+        *result += x * batch_scaling_factor;
+        ++result;
+      }
+    }
+    return;
+  }
+
+  SseMatrixBatchVectorMultiplyAccumulateImpl(
+      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
+      /*per_channel_scale=*/nullptr, /*input_offset=*/nullptr,
+      /*row_sums=*/nullptr);
+}
+
 void SseMatrixBatchVectorMultiplyAccumulate(
     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
     const int8_t* __restrict__ vectors,
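
The SSE block in the new overload vectorizes the per-batch rescale of the CpuBackendGemm output, and the scalar remainder loop already spells out its semantics. For reference, the whole step amounts to the standalone sketch below (illustrative only, not part of the change).

#include <cstdint>

// Scalar equivalent of the post-GEMM scaling step: the int32 GEMM output in
// `scratch` stores each batch's m_rows accumulators contiguously, so entry i
// belongs to batch i / m_rows and is scaled by that batch's combined
// (input scale * filter scale) factor before being accumulated into `result`.
void ScaleAndAccumulate(const int32_t* scratch, const float* scaling_factors,
                        int n_batch, int m_rows, float* result) {
  const int total_size = n_batch * m_rows;
  for (int i = 0; i < total_size; ++i) {
    result[i] += scratch[i] * scaling_factors[i / m_rows];
  }
}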

View File

@@ -71,7 +71,7 @@ void MatrixBatchVectorMultiplyAccumulate(
     int32_t* __restrict__ scratch, float* __restrict__ result,
     CpuBackendContext* __restrict__ context) {
   SSE_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
-                  vectors, scaling_factors, n_batch, result);
+                  vectors, scaling_factors, n_batch, scratch, result, context);
 }
 
 void SparseMatrixBatchVectorMultiplyAccumulate1x4(
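
With the dispatch now forwarding scratch and context, a caller-side usage looks roughly like the sketch below. The wrapper is hypothetical; it assumes the tensor_utils overload that takes scratch and a CpuBackendContext, and that result already holds the bias/initial values, since the kernel accumulates into it.

#include <cstdint>
#include <vector>

#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"

// Hypothetical helper: multiply an int8 weight matrix [m_rows x m_cols] by
// n_batch int8 vectors and accumulate the per-batch-rescaled product into a
// float result. The scratch buffer must hold n_batch * m_rows int32 values so
// the x86 path can stage the CpuBackendGemm output before scaling.
void RunHybridMatmul(const int8_t* matrix, int m_rows, int m_cols,
                     const int8_t* vectors, const float* scaling_factors,
                     int n_batch, float* result) {
  tflite::CpuBackendContext context;
  std::vector<int32_t> scratch(n_batch * m_rows);
  tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch,
      scratch.data(), result, &context);
}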

View File

@@ -35,6 +35,14 @@ void SseMatrixBatchVectorMultiplyAccumulate(
     const float* __restrict__ scaling_factors, int n_batch,
     float* __restrict__ result);
 
+// Matrix multiplication for quantized values using symmetric quantization
+// with additional scratch memory for GEMM operation prior to scaling.
+void SseMatrixBatchVectorMultiplyAccumulate(
+    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+    const int8_t* __restrict__ vectors,
+    const float* __restrict__ scaling_factors, int n_batch, int32_t* scratch,
+    float* __restrict__ result, CpuBackendContext* context);
+
 // Matrix multiplication for quantized values using asymmetric quantization.
 void SseMatrixBatchVectorMultiplyAccumulate(
     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,