Optimize int8 transpose_conv

PiperOrigin-RevId: 283881873
Change-Id: I568f5db0ba3663cb17208af41d30c2179e2e485c
Renjie Liu 2019-12-04 18:29:37 -08:00 committed by TensorFlower Gardener
parent 7a33433911
commit 5e94ded1de
4 changed files with 247 additions and 12 deletions

@@ -227,6 +227,7 @@ cc_library(
"optimized/integer_ops/mul.h",
"optimized/integer_ops/pooling.h",
"optimized/integer_ops/softmax.h",
"optimized/integer_ops/transpose_conv.h",
"optimized/optimized_ops.h",
],
copts = tflite_copts(),

@@ -0,0 +1,105 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_TRANSPOSE_CONV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_TRANSPOSE_CONV_H_
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
namespace tflite {
namespace optimized_integer_ops {
// TransposeConvV2 expects the weights in HWOI order.
inline void TransposeConvV2(
const ConvParams& params, const int32* output_multiplier,
const int32* output_shift, const RuntimeShape& input_shape,
const int8_t* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
const int8_t* hwoi_ordered_filter_data, const RuntimeShape& output_shape,
int8_t* output_data, const RuntimeShape& col2im_shape, int32_t* col2im_data,
int32_t* scratch_data, CpuBackendContext* cpu_backend_context) {
gemmlowp::ScopedProfilingLabel label("TransposeConvV2/int8");
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
const int batch_size = input_shape.Dims(0);
TFLITE_DCHECK(col2im_data);
TFLITE_DCHECK(hwoi_ordered_filter_data);
const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int output_image_size = output_height * output_width;
const int input_depth =
MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
const int output_depth =
MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
const int input_offset = input_image_size * input_depth;
const int output_offset = output_image_size * output_depth;
const int filter_height = hwoi_ordered_filter_shape.Dims(0);
const int filter_width = hwoi_ordered_filter_shape.Dims(1);
const int padding_top = params.padding_values.height;
const int padding_bottom =
params.padding_values.height + params.padding_values.height_offset;
const int padding_left = params.padding_values.width;
const int padding_right =
params.padding_values.width + params.padding_values.width_offset;
const int stride_height = params.stride_height;
const int stride_width = params.stride_width;
const int hwoi_ordered_filter_total_size =
filter_height * filter_width * output_depth;
cpu_backend_gemm::MatrixParams<int8_t> lhs_params;
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
lhs_params.rows = hwoi_ordered_filter_total_size;
lhs_params.cols = input_depth;
// Since our weights are symmetrically quantized, the zero point will always be 0.
lhs_params.zero_point = 0;
int32_t* scratch_data_p = scratch_data;
std::fill_n(scratch_data, output_offset * batch_size, static_cast<int32>(0));
for (int i = 0; i < batch_size; ++i) {
cpu_backend_gemm::MatrixParams<int8_t> rhs_params;
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
rhs_params.rows = input_depth;
rhs_params.cols = input_image_size;
rhs_params.zero_point = -params.input_offset;
cpu_backend_gemm::MatrixParams<int32_t> dst_params;
dst_params.order = cpu_backend_gemm::Order::kColMajor;
dst_params.rows = hwoi_ordered_filter_total_size;
dst_params.cols = input_image_size;
cpu_backend_gemm::GemmParams<int32_t, int32_t> gemm_params;
cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
input_data + input_offset * i, dst_params,
col2im_data, gemm_params, cpu_backend_context);
optimized_ops::Col2im(
col2im_data, output_depth, output_height, output_width, filter_height,
filter_width, padding_top, padding_left, padding_bottom, padding_right,
stride_height, stride_width, scratch_data_p);
scratch_data_p += output_offset;
}
optimized_ops::Quantize(output_multiplier, output_shift, output_depth,
output_shape.FlatSize(), params.output_offset,
scratch_data, output_data);
}
} // namespace optimized_integer_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_TRANSPOSE_CONV_H_
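To make the layout contract concrete: with the filter stored in HWOI order, flattening it row-major gives the [filter_height * filter_width * output_depth, input_depth] matrix used as lhs_params above, the activations provide the [input_depth, input_image_size] right-hand side, and the GEMM result is the column buffer that Col2im scatter-adds into the scratch output before the per-channel Quantize pass. Below is a minimal, self-contained sketch of the OHWI-to-HWOI reordering (illustrative only; the kernel obtains this layout via ResizeAndTransposeWeights in transpose_conv.cc, and this helper's name and signature are assumptions, not code from the diff).

// Illustrative sketch only (not from this commit): reorder an OHWI filter
// into the HWOI layout that TransposeConvV2 consumes. Dimension names follow
// the kernel above.
#include <cstdint>
#include <vector>

std::vector<int8_t> OhwiToHwoi(const std::vector<int8_t>& ohwi,
                               int output_depth, int filter_height,
                               int filter_width, int input_depth) {
  std::vector<int8_t> hwoi(ohwi.size());
  for (int o = 0; o < output_depth; ++o) {
    for (int h = 0; h < filter_height; ++h) {
      for (int w = 0; w < filter_width; ++w) {
        for (int i = 0; i < input_depth; ++i) {
          const int src =
              ((o * filter_height + h) * filter_width + w) * input_depth + i;
          const int dst =
              ((h * filter_width + w) * output_depth + o) * input_depth + i;
          hwoi[dst] = ohwi[src];
        }
      }
    }
  }
  return hwoi;
}

Flattened this way, each group of output_depth consecutive rows of the lhs matrix corresponds to one (h, w) filter tap, which is what lets Col2im walk the GEMM output tap by tap while accumulating into the strided, padded output window.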

@@ -5617,6 +5617,117 @@ inline void Quantize(int32_t multiplier, int32_t shift, int32_t total_size,
}
}
// TODO(b/145632530): Refactor other quantize per-channel to use this one.
inline void Quantize(const int32_t* multiplier, const int32_t* shift,
int32_t channel_size, int32_t total_size,
int32_t output_zp, int32_t* scratch, int8_t* output) {
gemmlowp::ScopedProfilingLabel label("Quantize/int8");
const int32_t output_min = std::numeric_limits<int8_t>::min();
const int32_t output_max = std::numeric_limits<int8_t>::max();
// Here we're trying to quantize the raw accumulators:
//        output_channels
//       data data data data data
// rows  data data data data data
//       data data data data data
//          ....
//
// In order to minimize reloads of the multipliers & shifts, once we load
// the multipliers & shifts, we load & quantize the raw accumulators for
// every row.
#ifdef USE_NEON
const int32x4_t output_offset_vec = vdupq_n_s32(output_zp);
const int32x4_t output_activation_min_vec = vdupq_n_s32(output_min);
const int32x4_t output_activation_max_vec = vdupq_n_s32(output_max);
const int32x4_t ones = vdupq_n_s32(1);
const int32x4_t minus_ones = vdupq_n_s32(-1);
const int32x4_t zeros = vdupq_n_s32(0);
#endif
TFLITE_DCHECK_EQ(total_size % channel_size, 0);
const int32_t rows = total_size / channel_size;
int c = 0;
while (c < channel_size) {
int target_output_depth = channel_size;
#ifdef USE_NEON
using gemmlowp::RoundingDivideByPOT;
for (; c <= channel_size - 4; c += 4) {
int32x4_t out_shift = vld1q_s32(shift + c);
const bool out_shift_all_less_than_zero =
(vgetq_lane_s32(out_shift, 0) < 0) &&
(vgetq_lane_s32(out_shift, 1) < 0) &&
(vgetq_lane_s32(out_shift, 2) < 0) &&
(vgetq_lane_s32(out_shift, 3) < 0);
const bool out_shift_all_greater_equal_than_zero =
(vgetq_lane_s32(out_shift, 0) >= 0) &&
(vgetq_lane_s32(out_shift, 1) >= 0) &&
(vgetq_lane_s32(out_shift, 2) >= 0) &&
(vgetq_lane_s32(out_shift, 3) >= 0);
if (!out_shift_all_less_than_zero &&
!out_shift_all_greater_equal_than_zero) {
// Fall back to the general (scalar) path for these 4 channels,
// then continue with the next group of 4.
target_output_depth = c + 4;
break;
}
int32x4_t out_mul = vld1q_s32(multiplier + c);
for (int n = 0; n < rows; ++n) {
int loc = n * channel_size + c;
int32x4_t acc = vld1q_s32(scratch + loc);
if (out_shift_all_less_than_zero) { // output_shift all < 0 case.
acc = vqrdmulhq_s32(acc, out_mul);
int32x4_t negative_out_shift = vmulq_n_s32(out_shift, -1);
int32x4_t mask =
vaddq_s32(vshlq_s32(ones, negative_out_shift), minus_ones);
int32x4_t remainder = vandq_s32(acc, mask);
int32x4_t shifted_right_mask = vshlq_s32(mask, minus_ones);
int32x4_t temp =
vandq_s32(vreinterpretq_s32_u32(vcltq_s32(acc, zeros)), ones);
int32x4_t threshold = vaddq_s32(shifted_right_mask, temp);
temp = vandq_s32(
vreinterpretq_s32_u32(vcgtq_s32(remainder, threshold)), ones);
int32x4_t shifted_right_acc = vshlq_s32(acc, out_shift);
acc = vaddq_s32(shifted_right_acc, temp);
} else { // output_shift all >= 0 case.
int32x4_t multiplier_power_of_two = vshlq_s32(ones, out_shift);
acc = vmulq_s32(acc, multiplier_power_of_two);
acc = vqrdmulhq_s32(acc, out_mul);
}
// Add the output offset.
acc = vaddq_s32(acc, output_offset_vec);
// Apply the activation function.
acc = vmaxq_s32(acc, output_activation_min_vec);
acc = vminq_s32(acc, output_activation_max_vec);
// Saturating cast to int8 and store to destination.
const int16x4_t acc_s16 = vqmovn_s32(acc);
const int16x8_t res_s16 = vcombine_s16(acc_s16, acc_s16);
const int8x8_t res_s8 = vqmovn_s16(res_s16);
vst1_lane_s8(output + loc + 0, res_s8, 0);
vst1_lane_s8(output + loc + 1, res_s8, 1);
vst1_lane_s8(output + loc + 2, res_s8, 2);
vst1_lane_s8(output + loc + 3, res_s8, 3);
}
}
#endif // USE_NEON
// Handle leftover values, one by one. This is very slow.
for (; c < target_output_depth; c++) {
for (int n = 0; n < rows; ++n) {
int loc = n * channel_size + c;
int32 acc = scratch[loc];
acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]);
acc += output_zp;
acc = std::max(acc, output_min);
acc = std::min(acc, output_max);
output[loc] = static_cast<int8>(acc);
}
}
}
}
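The NEON block above vectorizes the same math as the scalar leftover loop: when a channel's shift is negative, the accumulator goes through a saturating rounding doubling high multiply followed by a rounding right shift (the mask/remainder/threshold arithmetic gives round-to-nearest with ties away from zero); when the shift is non-negative, the accumulator is first scaled by 2^shift and then high-multiplied. Here is a standalone scalar sketch of that fixed-point path, written to mirror gemmlowp's conventions (the "Sketch" names are illustrative, not the library functions themselves).

// Standalone illustration of the requantization applied per channel above.
// These helpers mirror SaturatingRoundingDoublingHighMul / RoundingDivideByPOT
// as used by MultiplyByQuantizedMultiplier; the "Sketch" suffix marks them as
// assumptions for illustration, not the library code.
#include <algorithm>
#include <cstdint>
#include <limits>

int32_t SaturatingRoundingDoublingHighMulSketch(int32_t a, int32_t b) {
  // The only overflowing input pair is INT32_MIN * INT32_MIN; saturate it.
  if (a == std::numeric_limits<int32_t>::min() &&
      b == std::numeric_limits<int32_t>::min()) {
    return std::numeric_limits<int32_t>::max();
  }
  const int64_t ab = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  const int32_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
  return static_cast<int32_t>((ab + nudge) / (1ll << 31));
}

int32_t RoundingDivideByPOTSketch(int32_t x, int exponent) {
  // Same remainder/threshold trick as the vandq/vcgtq sequence above.
  const int32_t mask = (1 << exponent) - 1;
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

int8_t RequantizeOneSketch(int32_t acc, int32_t multiplier, int32_t shift,
                           int32_t output_zp) {
  if (shift >= 0) {
    acc = SaturatingRoundingDoublingHighMulSketch(acc * (1 << shift), multiplier);
  } else {
    acc = SaturatingRoundingDoublingHighMulSketch(acc, multiplier);
    acc = RoundingDivideByPOTSketch(acc, -shift);
  }
  acc += output_zp;
  acc = std::max(acc, static_cast<int32_t>(std::numeric_limits<int8_t>::min()));
  acc = std::min(acc, static_cast<int32_t>(std::numeric_limits<int8_t>::max()));
  return static_cast<int8_t>(acc);
}

Grouping channels in fours and keeping the row loop innermost means multiplier[c..c+3] and shift[c..c+3] are loaded once per group and then reused across every row, which is the reload-minimizing traversal described in the comment.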
// TransposeConvV2 expects the weights in HWOI order.
inline void TransposeConvV2(
const ConvParams& params, const RuntimeShape& input_shape,

@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -23,6 +24,8 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/eigen_support.h"
// NOLINTNEXTLINE - This header file shouldn't go to the top.
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/transpose_conv.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
// NOLINTNEXTLINE - This header file shouldn't go to the top.
#include "tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h"
@@ -422,6 +425,7 @@ void EvalQuantized(TfLiteContext* context,
}
}
template <KernelType kernel_type>
void EvalQuantizedPerChannel(TfLiteContext* context,
const TfLiteTransposeConvParams* params,
OpData* data, const TfLiteTensor* input,
@@ -444,15 +448,29 @@ void EvalQuantizedPerChannel(TfLiteContext* context,
op_params.quantized_activation_min = data->output_activation_min;
op_params.quantized_activation_max = data->output_activation_max;
// TODO(b/143380105): Need to add optimized kernel for int8 quantized
// transpose conv.
reference_integer_ops::TransposeConv(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int8>(input), GetTensorShape(weights),
GetTensorData<int8>(weights), GetTensorShape(output),
GetTensorData<int8>(output), GetTensorShape(col2im),
GetTensorData<int8>(col2im), GetTensorData<int32_t>(scratch_buffer));
switch (kernel_type) {
case kReference: {
reference_integer_ops::TransposeConv(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int8>(input), GetTensorShape(weights),
GetTensorData<int8>(weights), GetTensorShape(output),
GetTensorData<int8>(output), GetTensorShape(col2im),
GetTensorData<int8>(col2im), GetTensorData<int32_t>(scratch_buffer));
break;
}
case kGenericOptimized: {
optimized_integer_ops::TransposeConvV2(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int8>(input), GetTensorShape(transposed_weights),
GetTensorData<int8>(transposed_weights), GetTensorShape(output),
GetTensorData<int8>(output), GetTensorShape(col2im),
GetTensorData<int32>(col2im), GetTensorData<int32>(scratch_buffer),
CpuBackendContext::GetFromContext(context));
break;
}
}
}
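Since EvalQuantizedPerChannel is now templated on KernelType, the switch above is resolved per instantiation: each registered kernel variant fixes kernel_type at compile time, and Eval forwards it (see the updated call in the next hunk). A minimal, self-contained illustration of that dispatch pattern follows (names and bodies are placeholders, not TFLite code).

// Placeholder sketch of compile-time kernel dispatch; not TFLite code, only a
// model of how the KernelType template parameter selects a path.
#include <cstdio>

enum KernelType { kReference, kGenericOptimized };

template <KernelType kernel_type>
void EvalQuantizedPerChannelSketch() {
  switch (kernel_type) {
    case kReference:
      std::printf("int8 transpose conv: reference path\n");
      break;
    case kGenericOptimized:
      std::printf("int8 transpose conv: GEMM + Col2im + per-channel quantize\n");
      break;
  }
}

int main() {
  // Each registration would bind one variant; both are instantiated here for clarity.
  EvalQuantizedPerChannelSketch<kReference>();
  EvalQuantizedPerChannelSketch<kGenericOptimized>();
  return 0;
}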
template <KernelType kernel_type>
@@ -535,9 +553,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (data->weights_are_transposed && !IsConstantTensor(weights)) {
ResizeAndTransposeWeights(context, weights, transposed_weights);
}
EvalQuantizedPerChannel(context, params, data, input, weights,
transposed_weights, col2im, output,
scratch_buffer);
EvalQuantizedPerChannel<kernel_type>(context, params, data, input,
weights, transposed_weights, col2im,
output, scratch_buffer);
break;
}
default: