diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index d71b36547f2..e82d3c16b31 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -227,6 +227,7 @@ cc_library(
         "optimized/integer_ops/mul.h",
         "optimized/integer_ops/pooling.h",
         "optimized/integer_ops/softmax.h",
+        "optimized/integer_ops/transpose_conv.h",
         "optimized/optimized_ops.h",
     ],
     copts = tflite_copts(),
diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/transpose_conv.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/transpose_conv.h
new file mode 100644
index 00000000000..4d24ff65250
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/transpose_conv.h
@@ -0,0 +1,105 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_TRANSPOSE_CONV_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_TRANSPOSE_CONV_H_
+
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+
+namespace tflite {
+namespace optimized_integer_ops {
+
+// TransposeConvV2 expects the weights in HWOI order.
+inline void TransposeConvV2(
+    const ConvParams& params, const int32* output_multiplier,
+    const int32* output_shift, const RuntimeShape& input_shape,
+    const int8_t* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
+    const int8_t* hwoi_ordered_filter_data, const RuntimeShape& output_shape,
+    int8_t* output_data, const RuntimeShape& col2im_shape, int32_t* col2im_data,
+    int32_t* scratch_data, CpuBackendContext* cpu_backend_context) {
+  gemmlowp::ScopedProfilingLabel label("TransposeConvV2/int8");
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
+  const int batch_size = input_shape.Dims(0);
+  TFLITE_DCHECK(col2im_data);
+  TFLITE_DCHECK(hwoi_ordered_filter_data);
+
+  const int input_image_size = input_shape.Dims(1) * input_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int output_image_size = output_height * output_width;
+  const int input_depth =
+      MatchingDim(input_shape, 3, hwoi_ordered_filter_shape, 3);
+  const int output_depth =
+      MatchingDim(output_shape, 3, hwoi_ordered_filter_shape, 2);
+  const int input_offset = input_image_size * input_depth;
+  const int output_offset = output_image_size * output_depth;
+
+  const int filter_height = hwoi_ordered_filter_shape.Dims(0);
+  const int filter_width = hwoi_ordered_filter_shape.Dims(1);
+  const int padding_top = params.padding_values.height;
+  const int padding_bottom =
+      params.padding_values.height + params.padding_values.height_offset;
+  const int padding_left = params.padding_values.width;
+  const int padding_right =
+      params.padding_values.width + params.padding_values.width_offset;
+  const int stride_height = params.stride_height;
+  const int stride_width = params.stride_width;
+
+  const int hwoi_ordered_filter_total_size =
+      filter_height * filter_width * output_depth;
+
+  cpu_backend_gemm::MatrixParams<int8_t> lhs_params;
+  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
+  lhs_params.rows = hwoi_ordered_filter_total_size;
+  lhs_params.cols = input_depth;
+  // Since the weights are symmetrically quantized, their zero point is always 0.
+  lhs_params.zero_point = 0;
+
+  int32_t* scratch_data_p = scratch_data;
+  std::fill_n(scratch_data, output_offset * batch_size,
+              static_cast<int32_t>(0));
+  for (int i = 0; i < batch_size; ++i) {
+    cpu_backend_gemm::MatrixParams<int8_t> rhs_params;
+    rhs_params.order = cpu_backend_gemm::Order::kColMajor;
+    rhs_params.rows = input_depth;
+    rhs_params.cols = input_image_size;
+    rhs_params.zero_point = -params.input_offset;
+
+    cpu_backend_gemm::MatrixParams<int32_t> dst_params;
+    dst_params.order = cpu_backend_gemm::Order::kColMajor;
+    dst_params.rows = hwoi_ordered_filter_total_size;
+    dst_params.cols = input_image_size;
+
+    cpu_backend_gemm::GemmParams<int32_t, int32_t> gemm_params;
+    cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
+                           input_data + input_offset * i, dst_params,
+                           col2im_data, gemm_params, cpu_backend_context);
+
+    optimized_ops::Col2im(
+        col2im_data, output_depth, output_height, output_width, filter_height,
+        filter_width, padding_top, padding_left, padding_bottom, padding_right,
+        stride_height, stride_width, scratch_data_p);
+
+    scratch_data_p += output_offset;
+  }
+
+  optimized_ops::Quantize(output_multiplier, output_shift, output_depth,
+                          output_shape.FlatSize(), params.output_offset,
+                          scratch_data, output_data);
+}
+
+}  // namespace optimized_integer_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_TRANSPOSE_CONV_H_
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index 26005e069a7..b5ee08dd7f2 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -5617,6 +5617,117 @@ inline void Quantize(int32_t multiplier, int32_t shift, int32_t total_size,
   }
 }
 
+// TODO(b/145632530): Refactor other quantize per-channel to use this one.
+inline void Quantize(const int32_t* multiplier, const int32_t* shift,
+                     int32_t channel_size, int32_t total_size,
+                     int32_t output_zp, int32_t* scratch, int8_t* output) {
+  gemmlowp::ScopedProfilingLabel label("Quantize/int8");
+
+  const int32_t output_min = std::numeric_limits<int8_t>::min();
+  const int32_t output_max = std::numeric_limits<int8_t>::max();
+
+  // Here we're trying to quantize the raw accumulators:
+  //             output_channels
+  //       data data data data data
+  // rows  data data data data data
+  //       data data data data data
+  //           ....
+  //
+  // In order to minimize the reload of the multipliers & shifts, once we load
+  // the multipliers & shifts, we load & quantize the raw accumulators for
+  // every row.
+#ifdef USE_NEON
+  const int32x4_t output_offset_vec = vdupq_n_s32(output_zp);
+  const int32x4_t output_activation_min_vec = vdupq_n_s32(output_min);
+  const int32x4_t output_activation_max_vec = vdupq_n_s32(output_max);
+  const int32x4_t ones = vdupq_n_s32(1);
+  const int32x4_t minus_ones = vdupq_n_s32(-1);
+  const int32x4_t zeros = vdupq_n_s32(0);
+#endif
+
+  TFLITE_DCHECK_EQ(total_size % channel_size, 0);
+  const int32_t rows = total_size / channel_size;
+
+  int c = 0;
+
+  while (c < channel_size) {
+    int target_output_depth = channel_size;
+#ifdef USE_NEON
+    using gemmlowp::RoundingDivideByPOT;
+    for (; c <= channel_size - 4; c += 4) {
+      int32x4_t out_shift = vld1q_s32(shift + c);
+      const bool out_shift_all_less_than_zero =
+          (vgetq_lane_s32(out_shift, 0) < 0) &&
+          (vgetq_lane_s32(out_shift, 1) < 0) &&
+          (vgetq_lane_s32(out_shift, 2) < 0) &&
+          (vgetq_lane_s32(out_shift, 3) < 0);
+      const bool out_shift_all_greater_equal_than_zero =
+          (vgetq_lane_s32(out_shift, 0) >= 0) &&
+          (vgetq_lane_s32(out_shift, 1) >= 0) &&
+          (vgetq_lane_s32(out_shift, 2) >= 0) &&
+          (vgetq_lane_s32(out_shift, 3) >= 0);
+      if (!out_shift_all_less_than_zero &&
+          !out_shift_all_greater_equal_than_zero) {
+        // The shifts in this group of 4 have mixed signs: fall back to the
+        // general path for them, then continue with the next group of 4.
+        target_output_depth = c + 4;
+        break;
+      }
+      int32x4_t out_mul = vld1q_s32(multiplier + c);
+      for (int n = 0; n < rows; ++n) {
+        int loc = n * channel_size + c;
+        int32x4_t acc = vld1q_s32(scratch + loc);
+        if (out_shift_all_less_than_zero) {  // output_shift all < 0 case.
+          acc = vqrdmulhq_s32(acc, out_mul);
+          int32x4_t negative_out_shift = vmulq_n_s32(out_shift, -1);
+          int32x4_t mask =
+              vaddq_s32(vshlq_s32(ones, negative_out_shift), minus_ones);
+          int32x4_t remainder = vandq_s32(acc, mask);
+          int32x4_t shifted_right_mask = vshlq_s32(mask, minus_ones);
+          int32x4_t temp =
+              vandq_s32(vreinterpretq_s32_u32(vcltq_s32(acc, zeros)), ones);
+          int32x4_t threshold = vaddq_s32(shifted_right_mask, temp);
+          temp = vandq_s32(
+              vreinterpretq_s32_u32(vcgtq_s32(remainder, threshold)), ones);
+          int32x4_t shifted_right_acc = vshlq_s32(acc, out_shift);
+          acc = vaddq_s32(shifted_right_acc, temp);
+        } else {  // output_shift all >= 0 case.
+          int32x4_t multiplier_power_of_two = vshlq_s32(ones, out_shift);
+          acc = vmulq_s32(acc, multiplier_power_of_two);
+          acc = vqrdmulhq_s32(acc, out_mul);
+        }
+        // Add the output offset.
+        acc = vaddq_s32(acc, output_offset_vec);
+        // Apply the activation function.
+        acc = vmaxq_s32(acc, output_activation_min_vec);
+        acc = vminq_s32(acc, output_activation_max_vec);
+        // Saturating cast to int8 and store to destination.
+        const int16x4_t acc_s16 = vqmovn_s32(acc);
+        const int16x8_t res_s16 = vcombine_s16(acc_s16, acc_s16);
+        const int8x8_t res_s8 = vqmovn_s16(res_s16);
+        vst1_lane_s8(output + loc + 0, res_s8, 0);
+        vst1_lane_s8(output + loc + 1, res_s8, 1);
+        vst1_lane_s8(output + loc + 2, res_s8, 2);
+        vst1_lane_s8(output + loc + 3, res_s8, 3);
+      }
+    }
+
+#endif  // USE_NEON
+    // Handle leftover values, one by one. This is very slow.
+    for (; c < target_output_depth; c++) {
+      for (int n = 0; n < rows; ++n) {
+        int loc = n * channel_size + c;
+        int32 acc = scratch[loc];
+        acc = MultiplyByQuantizedMultiplier(acc, multiplier[c], shift[c]);
+        acc += output_zp;
+        acc = std::max(acc, output_min);
+        acc = std::min(acc, output_max);
+        output[loc] = static_cast<int8_t>(acc);
+      }
+    }
+  }
+}
+
 // TransposeConvV2 expect the weights in HWOI order.
 inline void TransposeConvV2(
     const ConvParams& params, const RuntimeShape& input_shape,
diff --git a/tensorflow/lite/kernels/transpose_conv.cc b/tensorflow/lite/kernels/transpose_conv.cc
index 0c62c305c0f..114b9ae48f4 100644
--- a/tensorflow/lite/kernels/transpose_conv.cc
+++ b/tensorflow/lite/kernels/transpose_conv.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+
 #include
 #include
 #include
@@ -23,6 +24,8 @@ limitations under the License.
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
 #include "tensorflow/lite/kernels/eigen_support.h"
+// NOLINTNEXTLINE - This header file shouldn't go to the top.
+#include "tensorflow/lite/kernels/internal/optimized/integer_ops/transpose_conv.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 // NOLINTNEXTLINE - This header file should't go to the top.
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h"
@@ -422,6 +425,7 @@ void EvalQuantized(TfLiteContext* context,
   }
 }
 
+template <KernelType kernel_type>
 void EvalQuantizedPerChannel(TfLiteContext* context,
                              const TfLiteTransposeConvParams* params,
                              OpData* data, const TfLiteTensor* input,
@@ -444,15 +448,29 @@ void EvalQuantizedPerChannel(TfLiteContext* context,
   op_params.quantized_activation_min = data->output_activation_min;
   op_params.quantized_activation_max = data->output_activation_max;
 
-  // TODO(b/143380105): Need to add optimized kernel for int8 quantized
-  // transpose conv.
-  reference_integer_ops::TransposeConv(
-      op_params, data->per_channel_output_multiplier.data(),
-      data->per_channel_output_shift.data(), GetTensorShape(input),
-      GetTensorData<int8>(input), GetTensorShape(weights),
-      GetTensorData<int8>(weights), GetTensorShape(output),
-      GetTensorData<int8>(output), GetTensorShape(col2im),
-      GetTensorData<int8>(col2im), GetTensorData<int32_t>(scratch_buffer));
+  switch (kernel_type) {
+    case kReference: {
+      reference_integer_ops::TransposeConv(
+          op_params, data->per_channel_output_multiplier.data(),
+          data->per_channel_output_shift.data(), GetTensorShape(input),
+          GetTensorData<int8>(input), GetTensorShape(weights),
+          GetTensorData<int8>(weights), GetTensorShape(output),
+          GetTensorData<int8>(output), GetTensorShape(col2im),
+          GetTensorData<int8>(col2im), GetTensorData<int32_t>(scratch_buffer));
+      break;
+    }
+    case kGenericOptimized: {
+      optimized_integer_ops::TransposeConvV2(
+          op_params, data->per_channel_output_multiplier.data(),
+          data->per_channel_output_shift.data(), GetTensorShape(input),
+          GetTensorData<int8>(input), GetTensorShape(transposed_weights),
+          GetTensorData<int8>(transposed_weights), GetTensorShape(output),
+          GetTensorData<int8>(output), GetTensorShape(col2im),
+          GetTensorData<int32_t>(col2im),
+          GetTensorData<int32_t>(scratch_buffer),
+          CpuBackendContext::GetFromContext(context));
+      break;
+    }
+  }
 }
 
 template <KernelType kernel_type>
@@ -535,9 +553,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       if (data->weights_are_transposed && !IsConstantTensor(weights)) {
         ResizeAndTransposeWeights(context, weights, transposed_weights);
       }
-      EvalQuantizedPerChannel(context, params, data, input, weights,
-                              transposed_weights, col2im, output,
-                              scratch_buffer);
+      EvalQuantizedPerChannel<kernel_type>(context, params, data, input,
+                                           weights, transposed_weights, col2im,
+                                           output, scratch_buffer);
       break;
     }
     default:
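
Note for reviewers: the scalar fallback loop in the new per-channel Quantize is what the NEON branch vectorizes. As a minimal sketch of that per-element math, assuming the usual gemmlowp-style fixed-point rounding, the standalone helpers below (SaturatingRoundingDoublingHighMul, RoundingDivideByPOT, RequantizeOne) are illustrative re-implementations written for this note and are not part of the change or of the TFLite utilities themselves.

// Scalar sketch of per-channel requantization (illustration only).
#include <algorithm>
#include <cstdint>
#include <limits>

// High 32 bits of 2*a*b with rounding; saturates the single overflow case.
inline int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
  const bool overflow = a == b && a == std::numeric_limits<int32_t>::min();
  const int64_t ab_64 = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  const int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
  const int32_t ab_x2_high32 =
      static_cast<int32_t>((ab_64 + nudge) / (1ll << 31));
  return overflow ? std::numeric_limits<int32_t>::max() : ab_x2_high32;
}

// Rounding arithmetic right shift by a non-negative exponent.
inline int32_t RoundingDivideByPOT(int32_t x, int exponent) {
  const int32_t mask = static_cast<int32_t>((1ll << exponent) - 1);
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

// Requantizes one int32 accumulator to int8 for a channel with the given
// multiplier/shift, then applies the output zero point and clamps.
inline int8_t RequantizeOne(int32_t acc, int32_t multiplier, int32_t shift,
                            int32_t output_zp) {
  if (shift >= 0) {
    acc = SaturatingRoundingDoublingHighMul(acc * (1 << shift), multiplier);
  } else {
    acc = RoundingDivideByPOT(
        SaturatingRoundingDoublingHighMul(acc, multiplier), -shift);
  }
  acc += output_zp;
  acc = std::max<int32_t>(acc, std::numeric_limits<int8_t>::min());
  acc = std::min<int32_t>(acc, std::numeric_limits<int8_t>::max());
  return static_cast<int8_t>(acc);
}

Under these assumptions, the scalar fallback in the patch is equivalent to writing output[loc] = RequantizeOne(scratch[loc], multiplier[c], shift[c], output_zp) for each element.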
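
Separately, the new kernel expects its filter argument in HWOI order, while transpose-conv weights are stored as OHWI; in the op this reordering happens in ResizeAndTransposeWeights. The standalone sketch below only illustrates the expected layout permutation; the helper name, signature, and use of std::vector are hypothetical and not part of the change.

// Illustration of the OHWI -> HWOI filter reordering the kernel assumes.
#include <cstdint>
#include <vector>

std::vector<int8_t> OhwiToHwoi(const int8_t* ohwi, int O, int H, int W, int I) {
  std::vector<int8_t> hwoi(static_cast<size_t>(O) * H * W * I);
  for (int o = 0; o < O; ++o) {
    for (int h = 0; h < H; ++h) {
      for (int w = 0; w < W; ++w) {
        for (int i = 0; i < I; ++i) {
          // OHWI flat index: ((o*H + h)*W + w)*I + i
          // HWOI flat index: ((h*W + w)*O + o)*I + i
          hwoi[((static_cast<size_t>(h) * W + w) * O + o) * I + i] =
              ohwi[((static_cast<size_t>(o) * H + h) * W + w) * I + i];
        }
      }
    }
  }
  return hwoi;
}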