diff --git a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc index 7800e984c7f..b20628016f0 100644 --- a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc +++ b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc @@ -381,6 +381,7 @@ TransposeTest/.+ # transpose_conv_test -TransposeConvOpTest/TransposeConvOpTest.SimpleTestQuantizedPerChannelSingleChannel/0 +-TransposeConvOpTest/TransposeConvOpTest.SimpleTestQuantizedPerChannel16x8/0 -TransposeConvOpTest/TransposeConvOpTest.TestQuantizedPerChannelMultiChannel/0 # Const tensor only TransposeConvOpTest/TransposeConvOpTest/.+/0,29 diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h b/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h index 422adc2a333..f28b7cbddb7 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h @@ -119,6 +119,102 @@ inline void TransposeConv( } } +// int16 input (zero_point=0), int8 filter, int64 accumulator +inline void TransposeConv( + const ConvParams& params, const int32* output_multiplier, + const int32* output_shift, const RuntimeShape& input_shape, + const int16* input_data, const RuntimeShape& filter_shape, + const int8* filter_data, const RuntimeShape& bias_shape, + const std::int64_t* bias_data, const RuntimeShape& output_shape, + int16* output_data, const RuntimeShape& im2col_shape, int8* im2col_data, + std::int64_t* scratch_buffer) { + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const int pad_width = params.padding_values.width; + const int pad_height = params.padding_values.height; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + (void)im2col_data; // only used in optimized code. + (void)im2col_shape; // only used in optimized code. + + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); + const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3); + if (bias_data) { + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + } + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int32 output_activation_min = std::numeric_limits::min(); + const int32 output_activation_max = std::numeric_limits::max(); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + + const int num_elements = output_shape.FlatSize(); + // We need to initialize scratch_buffer to all 0s, as we apply the same + // 'scatter' based trick as in float version. + memset(scratch_buffer, 0, num_elements * sizeof(std::int64_t)); + + // Loop through input elements one at a time. + for (int batch = 0; batch < batches; ++batch) { + for (int in_y = 0; in_y < input_height; ++in_y) { + for (int in_x = 0; in_x < input_width; ++in_x) { + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { + // Loop through the output elements it will influence. + const int out_x_origin = (in_x * stride_width) - pad_width; + const int out_y_origin = (in_y * stride_height) - pad_height; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + for (int out_channel = 0; out_channel < output_depth; + ++out_channel) { + // Compute output element location. + const int out_x = out_x_origin + filter_x; + const int out_y = out_y_origin + filter_y; + // We cannot accumulate out of bounds. + if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) && + (out_y < output_height)) { + const int32 input_value = input_data[Offset( + input_shape, batch, in_y, in_x, in_channel)]; + const int32 filter_value = + filter_data[Offset(filter_shape, out_channel, filter_y, + filter_x, in_channel)]; + scratch_buffer[Offset(output_shape, batch, out_y, out_x, + out_channel)] += + input_value * filter_value; + } + } + } + } + } + } + } + } + + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int out_channel = 0; out_channel < output_depth; ++out_channel) { + std::int64_t acc = scratch_buffer[Offset(output_shape, batch, out_y, + out_x, out_channel)]; + if (bias_data) { + acc += bias_data[out_channel]; + } + int32 scaled_acc = MultiplyByQuantizedMultiplier( + acc, output_multiplier[out_channel], output_shift[out_channel]); + scaled_acc = std::max(scaled_acc, output_activation_min); + scaled_acc = std::min(scaled_acc, output_activation_max); + output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] = + static_cast(scaled_acc); + } + } + } + } +} + } // namespace reference_integer_ops } // namespace tflite diff --git a/tensorflow/lite/kernels/transpose_conv.cc b/tensorflow/lite/kernels/transpose_conv.cc index 9b2767f15a9..494433159d4 100644 --- a/tensorflow/lite/kernels/transpose_conv.cc +++ b/tensorflow/lite/kernels/transpose_conv.cc @@ -155,8 +155,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context, ++temporaries_count; } - // Allocate scratch buffer tensor for UInt8 inputs. - if (input_type == kTfLiteUInt8 || input_type == kTfLiteInt8) { + // Allocate scratch buffer tensor + if (input_type == kTfLiteUInt8 || input_type == kTfLiteInt8 || + input_type == kTfLiteInt16) { if (data->scratch_tensor_id == kTensorNotAllocated) { context->AddTensors(context, 1, &data->scratch_tensor_id); } @@ -227,13 +228,16 @@ TfLiteStatus ResizeAndTransposeWeights(TfLiteContext* context, GetTensorShape(transposed_weights), GetTensorData(transposed_weights)); } else if (weights->type == kTfLiteInt8) { + // int16 transpose_conv also with int8 weights optimized_ops::Transpose(transpose_params, input_shape, GetTensorData(weights), GetTensorShape(transposed_weights), GetTensorData(transposed_weights)); } else { - context->ReportError( - context, "Transpose conv only support float & uint8 & int8 right now."); + TF_LITE_KERNEL_LOG( + context, + "Only float32, uint8, int8, int16 is supported currently, got %s.", + TfLiteTypeGetName(weights->type)); return kTfLiteError; } @@ -263,9 +267,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 4); - TF_LITE_ENSURE(context, input->type == kTfLiteFloat32 || - input->type == kTfLiteUInt8 || - input->type == kTfLiteInt8); + TF_LITE_ENSURE(context, + input->type == kTfLiteFloat32 || input->type == kTfLiteUInt8 || + input->type == kTfLiteInt8 || input->type == kTfLiteInt16); if (has_bias) { bias = GetOptionalInputTensor(context, node, kBiasTensor); @@ -275,6 +279,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { if (input->type == kTfLiteInt8) { TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0); } + } else if (input->type == kTfLiteInt16) { + TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt64); + TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0); } else { TF_LITE_ENSURE_EQ(context, bias->type, input->type); } @@ -283,6 +290,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } } + if (input->type == kTfLiteInt16) { + TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteInt8); + TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + } else { + TF_LITE_ENSURE_EQ(context, weights->type, input->type); + } TF_LITE_ENSURE_EQ(context, output->type, input->type); // Ensure that weights and inputs have the same channel dimension. // Note: TOCO will reorder weights in the following format: OHWI. @@ -326,12 +340,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } } - if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { + if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 || + input->type == kTfLiteInt16) { node->temporaries->data[data->scratch_tensor_index] = data->scratch_tensor_id; TfLiteTensor* scratch_buffer = GetTemporary(context, node, data->scratch_tensor_index); - scratch_buffer->type = kTfLiteInt32; + if (input->type == kTfLiteInt16) { + scratch_buffer->type = kTfLiteInt64; + } else { + scratch_buffer->type = kTfLiteInt32; + } + scratch_buffer->allocation_type = kTfLiteDynamic; if (!IsConstantTensor(output_shape)) { SetTensorToDynamic(scratch_buffer); @@ -500,6 +520,37 @@ void EvalQuantizedPerChannel( } } +void EvalQuantizedPerChannel16x8( + TfLiteContext* context, const TfLiteTransposeConvParams* params, + OpData* data, const TfLiteTensor* input, const TfLiteTensor* weights, + const TfLiteTensor* transposed_weights, const TfLiteTensor* bias, + TfLiteTensor* col2im, TfLiteTensor* output, TfLiteTensor* scratch_buffer) { + tflite::ConvParams op_params; + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.padding_values.width_offset = data->padding.width_offset; + op_params.padding_values.height_offset = data->padding.height_offset; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + // Need to flip the sign of input offset to add it directly to the quantized + // buffer. + op_params.input_offset = -input->params.zero_point; + op_params.output_offset = output->params.zero_point; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + + // Need to add optimized kernel + reference_integer_ops::TransposeConv( + op_params, data->per_channel_output_multiplier.data(), + data->per_channel_output_shift.data(), GetTensorShape(input), + GetTensorData(input), GetTensorShape(weights), + GetTensorData(weights), GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + GetTensorData(output), GetTensorShape(col2im), + GetTensorData(col2im), GetTensorData(scratch_buffer)); +} + template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Retrieve tensors (All should be allocated by now) @@ -544,7 +595,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { filter_height, filter_width, params->padding, &unused_output_height, &unused_output_width); - // Currently support float32 and uint8. + // Currently support float32, uint8, int8, int16. switch (input->type) { case kTfLiteFloat32: { // Only for GenericOptimized path, we use transposed weights. @@ -589,6 +640,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { col2im, output, scratch_buffer); break; } + case kTfLiteInt16: { + TfLiteTensor* scratch_buffer = + GetTemporary(context, node, data->scratch_tensor_index); + if (IsDynamicTensor(scratch_buffer)) { + TF_LITE_ENSURE_OK(context, + ResizeTensor(context, output_shape, scratch_buffer)); + } + if (data->weights_are_transposed && !IsConstantTensor(weights)) { + ResizeAndTransposeWeights(context, weights, transposed_weights); + } + EvalQuantizedPerChannel16x8(context, params, data, input, weights, + transposed_weights, bias, col2im, output, + scratch_buffer); + break; + } default: context->ReportError(context, "Type '%s' is not currently supported.", TfLiteTypeGetName(input->type)); diff --git a/tensorflow/lite/kernels/transpose_conv_test.cc b/tensorflow/lite/kernels/transpose_conv_test.cc index 77dc22b13e8..25c55c95412 100644 --- a/tensorflow/lite/kernels/transpose_conv_test.cc +++ b/tensorflow/lite/kernels/transpose_conv_test.cc @@ -76,7 +76,10 @@ class BaseTransposeConvOpModel : public SingleOpModel { if (test_type == TestType::kDynamic) { PopulateTensor(output_shape_, output_shape_data); - PopulateTensor(filter_, filter_data); + if (!std::is_same::value && + !std::is_same::value) { + PopulateTensor(filter_, filter_data); + } } } @@ -85,6 +88,8 @@ class BaseTransposeConvOpModel : public SingleOpModel { QuantizeAndPopulate(input_, data); } else if (std::is_same::value) { QuantizeAndPopulate(input_, data); + } else if (std::is_same::value) { + QuantizeAndPopulate(input_, data); } else { PopulateTensor(input_, data); } @@ -315,6 +320,56 @@ TEST_P(TransposeConvOpTest, SimpleTestQuantized) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); } +TEST_P(TransposeConvOpTest, TwoFiltersTestQuantized) { + // Float would be {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + // 18} + std::initializer_list filter_data = {129, 131, 133, 135, 137, 139, + 141, 143, 145, 147, 149, 151, + 153, 155, 157, 159, 161, 163}; + QuantizedTransposeConvOpModel model( + GetRegistration(), {1, 4, 4, 1}, + {TensorType_UINT8, {1, 3, 3, 2}, -63.5, 64}, filter_data, + {TensorType_UINT8, {1, 4, 4, 2}, -63.5, 64}, + {TensorType_UINT8, {}, -4064, 4096}, Padding_SAME, 1, 1, GetTestType()); + model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}); + model.Invoke(); + + EXPECT_THAT(model.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + {192, 416, 576, 544, 672, 1344, 1696, 1440, 1504, 2720, 3072, + 2432, 1984, 3360, 3648, 2752}, + 1e-5))); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST_P(TransposeConvOpTest, PaddingValidTestQuantized) { + // Float would be {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + // 18} + std::initializer_list filter_data = {129, 131, 133, 135, 137, 139, + 141, 143, 145, 147, 149, 151, + 153, 155, 157, 159, 161, 163}; + QuantizedTransposeConvOpModel model( + GetRegistration(), {1, 6, 6, 1}, + {TensorType_UINT8, {1, 3, 3, 2}, -63.5, 64}, filter_data, + {TensorType_UINT8, {1, 4, 4, 2}, -63.5, 64}, + {TensorType_UINT8, {}, -4064, 4096}, Padding_VALID, 1, 1, GetTestType()); + model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}); + model.Invoke(); + + EXPECT_THAT(model.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + {0, 32, 64, 96, 128, 96, 64, 192, 416, + 576, 544, 352, 224, 672, 1344, 1696, 1440, 864, + 608, 1504, 2720, 3072, 2432, 1440, 864, 1984, 3360, + 3648, 2752, 1536, 704, 1536, 2528, 2720, 2016, 1088}, + 1e-5))); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 6, 6, 1})); +} + class PerChannelQuantizedTransposeConvOpModel : public BaseTransposeConvOpModel { public: @@ -325,10 +380,6 @@ class PerChannelQuantizedTransposeConvOpModel GetZeroPoint(output_)); } - void SetInput(const std::initializer_list& data) { - QuantizeAndPopulate(input_, data); - } - void SetFilter(const std::initializer_list& data) { PerChannelSymmetricQuantizeAndPopulate(filter_, data); } @@ -391,54 +442,78 @@ TEST_P(TransposeConvOpTest, TestQuantizedPerChannelMultiChannel) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 5, 5, 2})); } -TEST_P(TransposeConvOpTest, TwoFiltersTestQuantized) { - // Float would be {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, - // 18} - std::initializer_list filter_data = {129, 131, 133, 135, 137, 139, - 141, 143, 145, 147, 149, 151, - 153, 155, 157, 159, 161, 163}; - QuantizedTransposeConvOpModel model( - GetRegistration(), {1, 4, 4, 1}, - {TensorType_UINT8, {1, 3, 3, 2}, -63.5, 64}, filter_data, - {TensorType_UINT8, {1, 4, 4, 2}, -63.5, 64}, - {TensorType_UINT8, {}, -4064, 4096}, Padding_SAME, 1, 1, GetTestType()); - model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, - 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}); +class PerChannelQuantizedTransposeConvOpModel16x8 + : public BaseTransposeConvOpModel { + public: + using BaseTransposeConvOpModel::BaseTransposeConvOpModel; + + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } + + void SetFilter(const std::initializer_list& data) { + PerChannelSymmetricQuantizeAndPopulate(filter_, data); + } +}; + +TEST_P(TransposeConvOpTest, SimpleTestQuantizedPerChannel16x8) { + const std::initializer_list filter_data = { + // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel] + 1, 2, // out channel = 0, y = 0, x = 0 + 3, 4, // out channel = 0, y = 0, x = 1 + 3, 4, // out channel = 0, y = 1, x = 0 + 5, 6, // out channel = 0, y = 1, x = 1 + 7, 8, // out channel = 1, y = 0, x = 0 + 5, 6, // out channel = 1, y = 0, x = 1 + 3, 4, // out channel = 1, y = 1, x = 0 + 1, 2, // out channel = 1, y = 1, x = 1 + }; + PerChannelQuantizedTransposeConvOpModel16x8 model( + GetRegistration(), + /*output_shape_data=*/{1, 2, 3, 2}, + /*filter=*/ + {TensorType_INT8, + /*shape=*/{2, 2, 2, 2}, + /*min=*/-64, /*max=*/64, + /*scale=*/0, /*zero_point=*/0, + /*per_channel_quantization=*/true, + /*per_channel_quantization_scales=*/{7.0 / 127, 8.0 / 127}, + /*per_channel_quantization_offsets=*/{0, 0}, + /*channel_index=*/0}, + /*filter_data=*/{}, + /*input=*/ + {TensorType_INT16, + /*shape=*/{1, 2, 3, 2}, + /*min=*/0, /*max=*/0, + /*scale=*/4.0 / 127, + /*zero_point=*/0}, + /*output=*/ + {TensorType_INT16, + /*shape=*/{}, + /*min=*/0, /*max=*/0, + /*scale=*/1.0, + /*zero_point=*/0}, + /*padding=*/Padding_SAME, + /*stride_w=*/1, /*stride_h=*/1, GetTestType()); + model.SetInput({ + // [1 * 2 * 3 * 2] as [batch, y, x, input_channel] + 3, 2, // batch = 0, y = 0, x = 0 + 1, -1, // batch = 0, y = 0, x = 1 + -2, -3, // batch = 0, y = 0, x = 2 + 4, 3, // batch = 0, y = 1, x = 0 + 2, -2, // batch = 0, y = 1, x = 1 + -3, -4, // batch = 0, y = 1, x = 2 + }); + model.SetFilter(filter_data); model.Invoke(); EXPECT_THAT(model.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - {192, 416, 576, 544, 672, 1344, 1696, 1440, 1504, 2720, 3072, - 2432, 1984, 3360, 3648, 2752}, - 1e-5))); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); -} + {7, 37, 16, 26, -9, -39, 27, 69, 48, 42, -32, -74}, 1e-5))); -TEST_P(TransposeConvOpTest, PaddingValidTestQuantized) { - // Float would be {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, - // 18} - std::initializer_list filter_data = {129, 131, 133, 135, 137, 139, - 141, 143, 145, 147, 149, 151, - 153, 155, 157, 159, 161, 163}; - QuantizedTransposeConvOpModel model( - GetRegistration(), {1, 6, 6, 1}, - {TensorType_UINT8, {1, 3, 3, 2}, -63.5, 64}, filter_data, - {TensorType_UINT8, {1, 4, 4, 2}, -63.5, 64}, - {TensorType_UINT8, {}, -4064, 4096}, Padding_VALID, 1, 1, GetTestType()); - model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, - 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}); - model.Invoke(); - - EXPECT_THAT(model.GetDequantizedOutput(), - ElementsAreArray(ArrayFloatNear( - {0, 32, 64, 96, 128, 96, 64, 192, 416, - 576, 544, 352, 224, 672, 1344, 1696, 1440, 864, - 608, 1504, 2720, 3072, 2432, 1440, 864, 1984, 3360, - 3648, 2752, 1536, 704, 1536, 2528, 2720, 2016, 1088}, - 1e-5))); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 6, 6, 1})); + // GetOutputShape() should always be same as model.SetOutputShape(...); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3, 2})); } template