Optimize int8 transpose_conv

PiperOrigin-RevId: 283881873
Change-Id: I568f5db0ba3663cb17208af41d30c2179e2e485c

parent 7a33433911
commit 5e94ded1de
tensorflow/lite/kernels/internal/BUILD
@@ -227,6 +227,7 @@ cc_library(
         "optimized/integer_ops/mul.h",
         "optimized/integer_ops/pooling.h",
         "optimized/integer_ops/softmax.h",
+        "optimized/integer_ops/transpose_conv.h",
         "optimized/optimized_ops.h",
     ],
     copts = tflite_copts(),
tensorflow/lite/kernels/internal/optimized/integer_ops/transpose_conv.h
@@ -0,0 +1,105 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_TRANSPOSE_CONV_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_TRANSPOSE_CONV_H_
+
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+
+namespace tflite {
+namespace optimized_integer_ops {
+
+// TransposeConvV2 expects the weights in HWOI order.
+inline void TransposeConvV2(
+    const ConvParams& params, const int32* output_multiplier,
+    const int32* output_shift, const RuntimeShape& input_shape,
+    const int8_t* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
+    const int8_t* hwoi_ordered_filter_data, const RuntimeShape& output_shape,
+    int8_t* output_data, const RuntimeShape& col2im_shape, int32_t* col2im_data,
+    int32_t* scratch_data, CpuBackendContext* cpu_backend_context) {
+  gemmlowp::ScopedProfilingLabel label("TransposeConvV2/int8");
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
+  const int batch_size = input_shape.Dims(0);
+  TFLITE_DCHECK(col2im_data);
+  TFLITE_DCHECK(hwoi_ordered_filter_data);
+
+  const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int output_image_size = output_height * output_width;
+  const int input_depth =
+      MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
+  const int output_depth =
+      MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
+  const int input_offset = input_image_size * input_depth;
+  const int output_offset = output_image_size * output_depth;
+
+  const int filter_height = hwoi_ordered_filter_shape.Dims(0);
+  const int filter_width = hwoi_ordered_filter_shape.Dims(1);
+  const int padding_top = params.padding_values.height;
+  const int padding_bottom =
+      params.padding_values.height + params.padding_values.height_offset;
+  const int padding_left = params.padding_values.width;
+  const int padding_right =
+      params.padding_values.width + params.padding_values.width_offset;
+  const int stride_height = params.stride_height;
+  const int stride_width = params.stride_width;
+
+  const int hwoi_ordered_filter_total_size =
+      filter_height * filter_width * output_depth;
+
+  cpu_backend_gemm::MatrixParams<int8_t> lhs_params;
+  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
+  lhs_params.rows = hwoi_ordered_filter_total_size;
+  lhs_params.cols = input_depth;
+  // Since the weights are symmetrically quantized, the zero point is always 0.
+  lhs_params.zero_point = 0;
+
+  int32_t* scratch_data_p = scratch_data;
+  std::fill_n(scratch_data, output_offset * batch_size, static_cast<int32>(0));
+  for (int i = 0; i < batch_size; ++i) {
+    cpu_backend_gemm::MatrixParams<int8_t> rhs_params;
+    rhs_params.order = cpu_backend_gemm::Order::kColMajor;
+    rhs_params.rows = input_depth;
+    rhs_params.cols = input_image_size;
+    rhs_params.zero_point = -params.input_offset;
+
+    cpu_backend_gemm::MatrixParams<int32_t> dst_params;
+    dst_params.order = cpu_backend_gemm::Order::kColMajor;
+    dst_params.rows = hwoi_ordered_filter_total_size;
+    dst_params.cols = input_image_size;
+
+    cpu_backend_gemm::GemmParams<int32_t, int32_t> gemm_params;
+    cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
+                           input_data + input_offset * i, dst_params,
+                           col2im_data, gemm_params, cpu_backend_context);
+
+    optimized_ops::Col2im(
+        col2im_data, output_depth, output_height, output_width, filter_height,
+        filter_width, padding_top, padding_left, padding_bottom, padding_right,
+        stride_height, stride_width, scratch_data_p);
+
+    scratch_data_p += output_offset;
+  }
+
+  optimized_ops::Quantize(output_multiplier, output_shift, output_depth,
+                          output_shape.FlatSize(), params.output_offset,
+                          scratch_data, output_data);
+}
+
+}  // namespace optimized_integer_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_TRANSPOSE_CONV_H_
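For readers unfamiliar with the formulation above: each GEMM multiplies the HWOI-reshaped filters (filter_height * filter_width * output_depth rows by input_depth columns) against one batch of input pixels, producing one column of filter-tap contributions per input pixel; Col2im then scatter-adds those contributions into the strided, padded output. Below is a minimal scalar sketch of that scatter-add under simplified layout assumptions; Col2imScatterSketch and its parameters are illustrative, not the library's optimized_ops::Col2im.

#include <cstdint>
#include <vector>

// Minimal scalar sketch of the col2im scatter-add (illustrative only).
// gemm_out is column-major with one column per input pixel and one row per
// (filter_y, filter_x, output_channel) tap, matching the HWOI row order of
// the GEMM above. `output` must be pre-zeroed, as the kernel does with
// std::fill_n on the scratch buffer.
void Col2imScatterSketch(const std::vector<int32_t>& gemm_out, int in_h,
                         int in_w, int filter_h, int filter_w, int out_h,
                         int out_w, int channels, int stride_h, int stride_w,
                         int pad_top, int pad_left,
                         std::vector<int32_t>& output) {
  const int rows = filter_h * filter_w * channels;
  for (int iy = 0; iy < in_h; ++iy) {
    for (int ix = 0; ix < in_w; ++ix) {
      const int col = iy * in_w + ix;  // one GEMM column per input pixel
      for (int fy = 0; fy < filter_h; ++fy) {
        for (int fx = 0; fx < filter_w; ++fx) {
          // Output location this filter tap lands on for this input pixel.
          const int oy = iy * stride_h - pad_top + fy;
          const int ox = ix * stride_w - pad_left + fx;
          if (oy < 0 || oy >= out_h || ox < 0 || ox >= out_w) continue;
          for (int c = 0; c < channels; ++c) {
            const int row = (fy * filter_w + fx) * channels + c;
            output[(oy * out_w + ox) * channels + c] +=
                gemm_out[col * rows + row];  // column-major accumulator
          }
        }
      }
    }
  }
}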
tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -5617,6 +5617,117 @@ inline void Quantize(int32_t multiplier, int32_t shift, int32_t total_size,
   }
 }
 
+// TODO(b/145632530): Refactor other per-channel quantize paths to use this one.
+inline void Quantize(const int32_t* multiplier, const int32_t* shift,
+                     int32_t channel_size, int32_t total_size,
+                     int32_t output_zp, int32_t* scratch, int8_t* output) {
+  gemmlowp::ScopedProfilingLabel label("Quantize/int8");
+
+  const int32_t output_min = std::numeric_limits<int8_t>::min();
+  const int32_t output_max = std::numeric_limits<int8_t>::max();
+
+  // Here we're trying to quantize the raw accumulators:
+  //        output_channels
+  //       data data data data data
+  // rows  data data data data data
+  //       data data data data data
+  //       ....
+  //
+  // In order to minimize the reload of the multipliers & shifts, once we load
+  // the multipliers & shifts, we load & quantize the raw accumulators for
+  // every row.
+#ifdef USE_NEON
+  const int32x4_t output_offset_vec = vdupq_n_s32(output_zp);
+  const int32x4_t output_activation_min_vec = vdupq_n_s32(output_min);
+  const int32x4_t output_activation_max_vec = vdupq_n_s32(output_max);
+  const int32x4_t ones = vdupq_n_s32(1);
+  const int32x4_t minus_ones = vdupq_n_s32(-1);
+  const int32x4_t zeros = vdupq_n_s32(0);
+#endif
+
+  TFLITE_DCHECK_EQ(total_size % channel_size, 0);
+  const int32_t rows = total_size / channel_size;
+
+  int c = 0;
+
+  while (c < channel_size) {
+    int target_output_depth = channel_size;
+#ifdef USE_NEON
+    using gemmlowp::RoundingDivideByPOT;
+    for (; c <= channel_size - 4; c += 4) {
+      int32x4_t out_shift = vld1q_s32(shift + c);
+      const bool out_shift_all_less_than_zero =
+          (vgetq_lane_s32(out_shift, 0) < 0) &&
+          (vgetq_lane_s32(out_shift, 1) < 0) &&
+          (vgetq_lane_s32(out_shift, 2) < 0) &&
+          (vgetq_lane_s32(out_shift, 3) < 0);
+      const bool out_shift_all_greater_equal_than_zero =
+          (vgetq_lane_s32(out_shift, 0) >= 0) &&
+          (vgetq_lane_s32(out_shift, 1) >= 0) &&
+          (vgetq_lane_s32(out_shift, 2) >= 0) &&
+          (vgetq_lane_s32(out_shift, 3) >= 0);
+      if (!out_shift_all_less_than_zero &&
+          !out_shift_all_greater_equal_than_zero) {
+        // Fall back to the general path for this group of 4, then continue
+        // with the next 4.
+        target_output_depth = c + 4;
+        break;
+      }
+      int32x4_t out_mul = vld1q_s32(multiplier + c);
+      for (int n = 0; n < rows; ++n) {
+        int loc = n * channel_size + c;
+        int32x4_t acc = vld1q_s32(scratch + loc);
+        if (out_shift_all_less_than_zero) {  // output_shift all < 0 case.
+          acc = vqrdmulhq_s32(acc, out_mul);
+          int32x4_t negative_out_shift = vmulq_n_s32(out_shift, -1);
+          int32x4_t mask =
+              vaddq_s32(vshlq_s32(ones, negative_out_shift), minus_ones);
+          int32x4_t remainder = vandq_s32(acc, mask);
+          int32x4_t shifted_right_mask = vshlq_s32(mask, minus_ones);
+          int32x4_t temp =
+              vandq_s32(vreinterpretq_s32_u32(vcltq_s32(acc, zeros)), ones);
+          int32x4_t threshold = vaddq_s32(shifted_right_mask, temp);
+          temp = vandq_s32(
+              vreinterpretq_s32_u32(vcgtq_s32(remainder, threshold)), ones);
+          int32x4_t shifted_right_acc = vshlq_s32(acc, out_shift);
+          acc = vaddq_s32(shifted_right_acc, temp);
+        } else {  // output_shift all >= 0 case.
+          int32x4_t multiplier_power_of_two = vshlq_s32(ones, out_shift);
+          acc = vmulq_s32(acc, multiplier_power_of_two);
+          acc = vqrdmulhq_s32(acc, out_mul);
+        }
+        // Add the output offset.
+        acc = vaddq_s32(acc, output_offset_vec);
+        // Apply the activation function.
+        acc = vmaxq_s32(acc, output_activation_min_vec);
+        acc = vminq_s32(acc, output_activation_max_vec);
+        // Saturating cast to int8 and store to destination.
+        const int16x4_t acc_s16 = vqmovn_s32(acc);
+        const int16x8_t res_s16 = vcombine_s16(acc_s16, acc_s16);
+        const int8x8_t res_s8 = vqmovn_s16(res_s16);
+        vst1_lane_s8(output + loc + 0, res_s8, 0);
+        vst1_lane_s8(output + loc + 1, res_s8, 1);
+        vst1_lane_s8(output + loc + 2, res_s8, 2);
+        vst1_lane_s8(output + loc + 3, res_s8, 3);
+      }
+    }
+
+#endif  // USE_NEON
+    // Handle leftover values, one by one. This is very slow.
+    for (; c < target_output_depth; c++) {
+      for (int n = 0; n < rows; ++n) {
+        int loc = n * channel_size + c;
+        int32 acc = scratch[loc];
+        acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]);
+        acc += output_zp;
+        acc = std::max(acc, output_min);
+        acc = std::min(acc, output_max);
+        output[loc] = static_cast<int8>(acc);
+      }
+    }
+  }
+}
+
 // TransposeConvV2 expects the weights in HWOI order.
 inline void TransposeConvV2(
     const ConvParams& params, const RuntimeShape& input_shape,
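The NEON block above vectorizes the same arithmetic the scalar fallback performs through MultiplyByQuantizedMultiplier: a saturating rounding doubling high multiply (vqrdmulhq_s32), followed, when the per-channel shift is negative, by a rounding divide by a power of two that rounds to nearest with ties away from zero. A scalar sketch of that mask/remainder/threshold trick, as a simplified stand-in for gemmlowp::RoundingDivideByPOT, which the code actually uses:

#include <cstdint>

// Scalar sketch of the rounding right-shift the NEON lanes compute above:
// divide by 2^exponent, rounding to nearest with ties away from zero.
// Simplified stand-in, not gemmlowp's routine.
int32_t RoundingDivideByPOTSketch(int32_t x, int exponent) {
  const int32_t mask = (int32_t{1} << exponent) - 1;  // bits shifted out
  const int32_t remainder = x & mask;                 // matches vandq_s32
  // Halfway threshold; bumped by one for negative x so that exact ties
  // round away from zero (matches the vcltq_s32(acc, zeros) term).
  const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

For example, with exponent = 2 and x = -6: mask = 3, remainder = 2, threshold = 1 + 1 = 2, so the result is (-6 >> 2) + 0 = -2, i.e. -1.5 rounded away from zero.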
tensorflow/lite/kernels/transpose_conv.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <cassert>
 #include <cmath>
 #include <cstdio>
@@ -23,6 +24,8 @@ limitations under the License.
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
 #include "tensorflow/lite/kernels/eigen_support.h"
+// NOLINTNEXTLINE - This header file shouldn't go to the top.
+#include "tensorflow/lite/kernels/internal/optimized/integer_ops/transpose_conv.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 // NOLINTNEXTLINE - This header file shouldn't go to the top.
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h"
@@ -422,6 +425,7 @@ void EvalQuantized(TfLiteContext* context,
   }
 }
 
+template <KernelType kernel_type>
 void EvalQuantizedPerChannel(TfLiteContext* context,
                              const TfLiteTransposeConvParams* params,
                              OpData* data, const TfLiteTensor* input,
@@ -444,15 +448,29 @@ void EvalQuantizedPerChannel(TfLiteContext* context,
   op_params.quantized_activation_min = data->output_activation_min;
   op_params.quantized_activation_max = data->output_activation_max;
 
-  // TODO(b/143380105): Need to add optimized kernel for int8 quantized
-  // transpose conv.
-  reference_integer_ops::TransposeConv(
-      op_params, data->per_channel_output_multiplier.data(),
-      data->per_channel_output_shift.data(), GetTensorShape(input),
-      GetTensorData<int8>(input), GetTensorShape(weights),
-      GetTensorData<int8>(weights), GetTensorShape(output),
-      GetTensorData<int8>(output), GetTensorShape(col2im),
-      GetTensorData<int8>(col2im), GetTensorData<int32_t>(scratch_buffer));
+  switch (kernel_type) {
+    case kReference: {
+      reference_integer_ops::TransposeConv(
+          op_params, data->per_channel_output_multiplier.data(),
+          data->per_channel_output_shift.data(), GetTensorShape(input),
+          GetTensorData<int8>(input), GetTensorShape(weights),
+          GetTensorData<int8>(weights), GetTensorShape(output),
+          GetTensorData<int8>(output), GetTensorShape(col2im),
+          GetTensorData<int8>(col2im), GetTensorData<int32_t>(scratch_buffer));
+      break;
+    }
+    case kGenericOptimized: {
+      optimized_integer_ops::TransposeConvV2(
+          op_params, data->per_channel_output_multiplier.data(),
+          data->per_channel_output_shift.data(), GetTensorShape(input),
+          GetTensorData<int8>(input), GetTensorShape(transposed_weights),
+          GetTensorData<int8>(transposed_weights), GetTensorShape(output),
+          GetTensorData<int8>(output), GetTensorShape(col2im),
+          GetTensorData<int32>(col2im), GetTensorData<int32>(scratch_buffer),
+          CpuBackendContext::GetFromContext(context));
+      break;
+    }
+  }
 }
 
 template <KernelType kernel_type>
@@ -535,9 +553,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       if (data->weights_are_transposed && !IsConstantTensor(weights)) {
         ResizeAndTransposeWeights(context, weights, transposed_weights);
       }
-      EvalQuantizedPerChannel(context, params, data, input, weights,
-                              transposed_weights, col2im, output,
-                              scratch_buffer);
+      EvalQuantizedPerChannel<kernel_type>(context, params, data, input,
+                                           weights, transposed_weights, col2im,
+                                           output, scratch_buffer);
       break;
     }
     default:
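For context on how the kernel_type template argument reaches Eval and EvalQuantizedPerChannel: TFLite kernels normally bind it at registration time. A hedged sketch of that pattern follows; Init, Free, and Prepare are the file's existing handlers, while the registration function names here are assumptions, not shown in this diff.

// Sketch: instantiates the templated Eval for each kernel flavor.
// Registration names are illustrative assumptions.
TfLiteRegistration* Register_TRANSPOSECONV_REF() {
  static TfLiteRegistration r = {Init, Free, Prepare, Eval<kReference>};
  return &r;
}

TfLiteRegistration* Register_TRANSPOSECONV_GENERIC_OPT() {
  static TfLiteRegistration r = {Init, Free, Prepare,
                                 Eval<kGenericOptimized>};
  return &r;
}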