diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h b/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h
index f28b7cbddb7..284c0f21db1 100644
--- a/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h
+++ b/tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h
@@ -22,13 +22,13 @@ namespace reference_integer_ops {
 // Fixed-point per-channel-quantization transpose convolution reference kernel.
 inline void TransposeConv(
-    const ConvParams& params, const int32* output_multiplier,
-    const int32* output_shift, const RuntimeShape& input_shape,
-    const int8* input_data, const RuntimeShape& filter_shape,
-    const int8* filter_data, const RuntimeShape& bias_shape,
-    const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
-    const RuntimeShape& im2col_shape, int8* im2col_data,
-    int32* scratch_buffer) {
+    const ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const RuntimeShape& input_shape,
+    const int8_t* input_data, const RuntimeShape& filter_shape,
+    const int8_t* filter_data, const RuntimeShape& bias_shape,
+    const int32_t* bias_data, const RuntimeShape& output_shape,
+    int8_t* output_data, const RuntimeShape& im2col_shape, int8_t* im2col_data,
+    int32_t* scratch_buffer) {
   const int stride_width = params.stride_width;
   const int stride_height = params.stride_height;
   const int pad_width = params.padding_values.width;
@@ -51,16 +51,16 @@ inline void TransposeConv(
   const int filter_width = filter_shape.Dims(2);
   const int output_height = output_shape.Dims(1);
   const int output_width = output_shape.Dims(2);
-  const int32 input_offset = params.input_offset;
-  const int32 output_offset = params.output_offset;
-  const int32 output_activation_min = std::numeric_limits<int8_t>::min();
-  const int32 output_activation_max = std::numeric_limits<int8_t>::max();
+  const int32_t input_offset = params.input_offset;
+  const int32_t output_offset = params.output_offset;
+  const int32_t output_activation_min = std::numeric_limits<int8_t>::min();
+  const int32_t output_activation_max = std::numeric_limits<int8_t>::max();
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
 
   const int num_elements = output_shape.FlatSize();
   // We need to initialize scratch_buffer to all 0s, as we apply the same
   // 'scatter' based trick as in float version.
-  memset(scratch_buffer, 0, num_elements * sizeof(int32));
+  memset(scratch_buffer, 0, num_elements * sizeof(int32_t));
 
   // Loop through input elements one at a time.
   for (int batch = 0; batch < batches; ++batch) {
@@ -80,9 +80,9 @@ inline void TransposeConv(
               // We cannot accumulate out of bounds.
               if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
                   (out_y < output_height)) {
-                const int8 input_value = input_data[Offset(
+                const int8_t input_value = input_data[Offset(
                     input_shape, batch, in_y, in_x, in_channel)];
-                const int8 filter_value =
+                const int8_t filter_value =
                     filter_data[Offset(filter_shape, out_channel, filter_y,
                                        filter_x, in_channel)];
                 scratch_buffer[Offset(output_shape, batch, out_y, out_x,
@@ -101,8 +101,8 @@ inline void TransposeConv(
     for (int out_y = 0; out_y < output_height; ++out_y) {
       for (int out_x = 0; out_x < output_width; ++out_x) {
         for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
-          int32 acc = scratch_buffer[Offset(output_shape, batch, out_y, out_x,
-                                            out_channel)];
+          int32_t acc = scratch_buffer[Offset(output_shape, batch, out_y, out_x,
+                                              out_channel)];
           if (bias_data) {
             acc += bias_data[out_channel];
           }
@@ -119,14 +119,14 @@ inline void TransposeConv(
   }
 }
 
-// int16 input (zero_point=0), int8 filter, int64 accumulator
+// int16_t input (zero_point=0), int8_t filter, int64 accumulator
 inline void TransposeConv(
-    const ConvParams& params, const int32* output_multiplier,
-    const int32* output_shift, const RuntimeShape& input_shape,
-    const int16* input_data, const RuntimeShape& filter_shape,
-    const int8* filter_data, const RuntimeShape& bias_shape,
+    const ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const RuntimeShape& input_shape,
+    const int16_t* input_data, const RuntimeShape& filter_shape,
+    const int8_t* filter_data, const RuntimeShape& bias_shape,
     const std::int64_t* bias_data, const RuntimeShape& output_shape,
-    int16* output_data, const RuntimeShape& im2col_shape, int8* im2col_data,
+    int16_t* output_data, const RuntimeShape& im2col_shape, int8_t* im2col_data,
     std::int64_t* scratch_buffer) {
   const int stride_width = params.stride_width;
   const int stride_height = params.stride_height;
@@ -150,8 +150,8 @@ inline void TransposeConv(
   const int filter_width = filter_shape.Dims(2);
   const int output_height = output_shape.Dims(1);
   const int output_width = output_shape.Dims(2);
-  const int32 output_activation_min = std::numeric_limits<int16_t>::min();
-  const int32 output_activation_max = std::numeric_limits<int16_t>::max();
+  const int32_t output_activation_min = std::numeric_limits<int16_t>::min();
+  const int32_t output_activation_max = std::numeric_limits<int16_t>::max();
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
 
   const int num_elements = output_shape.FlatSize();
@@ -177,9 +177,9 @@ inline void TransposeConv(
               // We cannot accumulate out of bounds.
               if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
                   (out_y < output_height)) {
-                const int32 input_value = input_data[Offset(
+                const int32_t input_value = input_data[Offset(
                     input_shape, batch, in_y, in_x, in_channel)];
-                const int32 filter_value =
+                const int32_t filter_value =
                     filter_data[Offset(filter_shape, out_channel, filter_y,
                                        filter_x, in_channel)];
                 scratch_buffer[Offset(output_shape, batch, out_y, out_x,
@@ -203,7 +203,7 @@ inline void TransposeConv(
         if (bias_data) {
           acc += bias_data[out_channel];
         }
-        int32 scaled_acc = MultiplyByQuantizedMultiplier(
+        int32_t scaled_acc = MultiplyByQuantizedMultiplier(
             acc, output_multiplier[out_channel], output_shift[out_channel]);
         scaled_acc = std::max(scaled_acc, output_activation_min);
         scaled_acc = std::min(scaled_acc, output_activation_max);
diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD
index 6dba309a2e3..965578da0ed 100644
--- a/tensorflow/lite/micro/kernels/BUILD
+++ b/tensorflow/lite/micro/kernels/BUILD
@@ -132,6 +132,7 @@ cc_library(
         "sub.cc",
         "svdf_common.cc",
         "tanh.cc",
+        "transpose_conv.cc",
         "unpack.cc",
     ] + select({
        "//conditions:default": [
@@ -162,10 +163,11 @@ cc_library(
     ],
     deps = [
         ":activation_utils",
-        ":kernel_util",
         ":fixedpoint_utils",
+        ":kernel_util",
         ":micro_utils",
         ":xtensa",
+        "@flatbuffers",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels:kernel_util",
         "//tensorflow/lite/kernels:op_macros",
@@ -178,7 +180,6 @@ cc_library(
         "//tensorflow/lite/kernels/internal:types",
         "//tensorflow/lite/micro:memory_helpers",
         "//tensorflow/lite/micro:micro_utils",
-        "@flatbuffers",
     ] + select({
         "//conditions:default": [],
         ":xtensa_hifimini": [
@@ -331,12 +332,30 @@ tflite_micro_cc_test(
     ],
 )
 
+cc_library(
+    name = "conv_test_common",
+    srcs = [
+        "conv_test_common.cc",
+    ],
+    hdrs = [
+        "conv_test.h",
+    ],
+    deps = [
+        ":kernel_runner",
+        ":micro_ops",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/micro:test_helpers",
+        "//tensorflow/lite/micro/testing:micro_test",
+    ],
+)
+
 tflite_micro_cc_test(
     name = "conv_test",
     srcs = [
         "conv_test.cc",
     ],
     deps = [
+        ":conv_test_common",
         ":kernel_runner",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/micro:micro_utils",
@@ -807,3 +826,18 @@ cc_test(
         "//tensorflow/lite/micro/testing:micro_test",
     ],
 )
+
+tflite_micro_cc_test(
+    name = "transpose_conv_test",
+    srcs = [
+        "transpose_conv_test.cc",
+    ],
+    deps = [
+        ":conv_test_common",
+        ":kernel_runner",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/micro:micro_utils",
+        "//tensorflow/lite/micro:test_helpers",
+        "//tensorflow/lite/micro/testing:micro_test",
+    ],
+)
diff --git a/tensorflow/lite/micro/kernels/conv_test.cc b/tensorflow/lite/micro/kernels/conv_test.cc
index d053e524481..d0576e80b8c 100644
--- a/tensorflow/lite/micro/kernels/conv_test.cc
+++ b/tensorflow/lite/micro/kernels/conv_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/lite/micro/kernels/conv_test.h"
+
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/micro/kernels/kernel_runner.h"
@@ -53,182 +55,6 @@ static TfLiteConvParams common_conv_params = {
     1,  // dilation_height_factor
 };
 
-template <typename T>
-TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size, T* output_data,
-                        int output_length, TfLiteConvParams* conv_params) {
-  int inputs_array_data[] = {3, 0, 1, 2};
-  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
-  int outputs_array_data[] = {1, 3};
-  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
-
-  const TfLiteRegistration registration = Register_CONV_2D();
-  micro::KernelRunner runner(
-      registration, tensors, tensors_size, inputs_array, outputs_array,
-      reinterpret_cast<void*>(conv_params), micro_test::reporter);
-
-  const char* init_data = reinterpret_cast<const char*>(conv_params);
-  TfLiteStatus status = runner.InitAndPrepare(init_data);
-  if (status != kTfLiteOk) {
-    return status;
-  }
-  return runner.Invoke();
-}
-
-template <typename T>
-TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
-                                 const T* expected_output_data, T* output_data,
-                                 int output_length,
-                                 TfLiteConvParams* conv_params,
-                                 float tolerance = 1e-5) {
-  TfLiteStatus status = InvokeConv(tensors, tensors_size, output_data,
-                                   output_length, conv_params);
-  if (status != kTfLiteOk) {
-    return status;
-  }
-  for (int i = 0; i < output_length; ++i) {
-    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i],
-                              tolerance);
-  }
-  return kTfLiteOk;
-}
-
-#if !defined(XTENSA)  // Needed to avoid build errors from unused functions.
-void TestConvFloat(const int* input_dims_data, const float* input_data,
-                   const int* filter_dims_data, const float* filter_data,
-                   const int* bias_dims_data, const float* bias_data,
-                   const int* output_dims_data,
-                   const float* expected_output_data, float* output_data,
-                   TfLiteConvParams* conv_params) {
-  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
-  TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data);
-  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
-  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
-  const int output_dims_count = ElementCount(*output_dims);
-  constexpr int inputs_size = 3;
-  constexpr int outputs_size = 1;
-  constexpr int tensors_size = inputs_size + outputs_size;
-  TfLiteTensor tensors[tensors_size] = {
-      CreateTensor(input_data, input_dims),
-      CreateTensor(filter_data, filter_dims),
-      CreateTensor(bias_data, bias_dims),
-      CreateTensor(output_data, output_dims),
-  };
-
-  TF_LITE_MICRO_EXPECT_EQ(
-      kTfLiteOk,
-      ValidateConvGoldens(tensors, tensors_size, expected_output_data,
-                          output_data, output_dims_count, conv_params));
-}
-
-void TestConvQuantizedPerLayer(
-    const int* input_dims_data, const float* input_data,
-    uint8_t* input_quantized, float input_scale, const int* filter_dims_data,
-    const float* filter_data, uint8_t* filter_quantized, float filter_scale,
-    const int* bias_dims_data, const float* bias_data, int32_t* bias_quantized,
-    const int* output_dims_data, const float* expected_output_data,
-    uint8_t* expected_output_quantized, uint8_t* output_data,
-    float output_scale, TfLiteConvParams* conv_params) {
-  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
-  TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data);
-  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
-  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
-  const int output_dims_count = ElementCount(*output_dims);
-
-  tflite::Quantize(expected_output_data, expected_output_quantized,
-                   output_dims_count, output_scale, 128);
-
-  constexpr int inputs_size = 3;
-  constexpr int outputs_size = 1;
-  constexpr int tensors_size = inputs_size + outputs_size;
-  TfLiteTensor tensors[tensors_size] = {
-      CreateQuantizedTensor(input_data, input_quantized, input_dims,
-                            input_scale, 128),
-      CreateQuantizedTensor(filter_data, filter_quantized, filter_dims,
-                            filter_scale, 128),
-      CreateQuantizedBiasTensor(bias_data, bias_quantized, bias_dims,
-                                input_scale, filter_scale),
-      CreateQuantizedTensor(output_data, output_dims, output_scale, 128)};
-
-  // TODO(njeff): Affine Quantization Params should be set on tensor creation.
-  float filter_scales[] = {1, filter_scale};
-  int filter_zero_points[] = {1, 128};
-  TfLiteAffineQuantization filter_quant = {FloatArrayFromFloats(filter_scales),
-                                           IntArrayFromInts(filter_zero_points),
-                                           0};
-  tensors[1].quantization = {kTfLiteAffineQuantization, &filter_quant};
-
-  TF_LITE_MICRO_EXPECT_EQ(
-      kTfLiteOk,
-      ValidateConvGoldens(tensors, tensors_size, expected_output_quantized,
-                          output_data, output_dims_count, conv_params));
-}
-
-void TestConvQuantizedPerChannel(
-    const int* input_dims_data, const float* input_data,
-    int8_t* input_quantized, float input_scale, int input_zero_point,
-    const int* filter_dims_data, const float* filter_data,
-    int8_t* filter_data_quantized, const int* bias_dims_data,
-    const float* bias_data, int32_t* bias_data_quantized, float* bias_scales,
-    int* bias_zero_points, const int* output_dims_data,
-    const float* expected_output_data, int8_t* expected_output_data_quantized,
-    int8_t* output_data, float output_scale, int output_zero_point,
-    TfLiteConvParams* conv_params) {
-  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
-  TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data);
-  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
-  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
-  const int output_dims_count = ElementCount(*output_dims);
-
-  int filter_zero_points[5];
-  float filter_scales[5];
-  TfLiteAffineQuantization filter_quant;
-  TfLiteAffineQuantization bias_quant;
-  TfLiteTensor input_tensor = CreateQuantizedTensor(
-      input_data, input_quantized, input_dims, input_scale, input_zero_point);
-  TfLiteTensor filter_tensor = CreateSymmetricPerChannelQuantizedTensor(
-      filter_data, filter_data_quantized, filter_dims, filter_scales,
-      filter_zero_points, &filter_quant, 0 /* quantized dimension */);
-  TfLiteTensor bias_tensor = CreatePerChannelQuantizedBiasTensor(
-      bias_data, bias_data_quantized, bias_dims, input_scale, &filter_scales[1],
-      bias_scales, bias_zero_points, &bias_quant, 0 /* quantized dimension */);
-  TfLiteTensor output_tensor = CreateQuantizedTensor(
-      output_data, output_dims, output_scale, output_zero_point);
-
-  // TODO(njeff): Affine Quantization Params should be set on tensor creation.
-  float input_scales[] = {1, input_scale};
-  int input_zero_points[] = {1, input_zero_point};
-  TfLiteAffineQuantization input_quant = {FloatArrayFromFloats(input_scales),
-                                          IntArrayFromInts(input_zero_points),
-                                          0};
-  input_tensor.quantization = {kTfLiteAffineQuantization, &input_quant};
-
-  float output_scales[] = {1, output_scale};
-  int output_zero_points[] = {1, output_zero_point};
-  TfLiteAffineQuantization output_quant = {FloatArrayFromFloats(output_scales),
-                                           IntArrayFromInts(output_zero_points),
-                                           0};
-  output_tensor.quantization = {kTfLiteAffineQuantization, &output_quant};
-
-  constexpr int inputs_size = 3;
-  constexpr int outputs_size = 1;
-  constexpr int tensors_size = inputs_size + outputs_size;
-  TfLiteTensor tensors[tensors_size] = {
-      input_tensor,
-      filter_tensor,
-      bias_tensor,
-      output_tensor,
-  };
-
-  tflite::Quantize(expected_output_data, expected_output_data_quantized,
-                   output_dims_count, output_scale, output_zero_point);
-  TF_LITE_MICRO_EXPECT_EQ(
-      kTfLiteOk,
-      ValidateConvGoldens(tensors, tensors_size, expected_output_data_quantized,
-                          output_data, output_dims_count, conv_params,
-                          1.0 /* tolerance */));
-}
-#endif  // !defined(XTENSA)
-
 }  // namespace
 }  // namespace testing
 }  // namespace tflite
@@ -245,8 +71,9 @@ TF_LITE_MICRO_TEST(SimpleTestFloat) {
       tflite::testing::kInputShape, tflite::testing::kInputData,
       tflite::testing::kFilterShape, tflite::testing::kFilterData,
       tflite::testing::kBiasShape, tflite::testing::kBiasData,
-      tflite::testing::kOutputShape, tflite::testing::kGoldenData, output_data,
-      &tflite::testing::common_conv_params);
+      tflite::testing::kOutputShape, tflite::testing::kGoldenData,
+      &tflite::testing::common_conv_params, tflite::Register_CONV_2D(),
+      output_data);
 }
 
 TF_LITE_MICRO_TEST(InputAndFilterSameWidthHeight) {
@@ -263,7 +90,8 @@ TF_LITE_MICRO_TEST(InputAndFilterSameWidthHeight) {
   tflite::testing::TestConvFloat(
       tflite::testing::kInputShape, tflite::testing::kInputData, kFilterShape,
       filter_values, kBiasShape, bias_values, kOutputShape, expected_output,
-      output_data, &tflite::testing::common_conv_params);
+      &tflite::testing::common_conv_params, tflite::Register_CONV_2D(),
+      output_data);
 }
 
 TF_LITE_MICRO_TEST(SimpleTestQuantized) {
@@ -285,8 +113,8 @@ TF_LITE_MICRO_TEST(SimpleTestQuantized) {
       tflite::testing::kFilterData, filter_quantized, filter_scale,
       tflite::testing::kBiasShape, tflite::testing::kBiasData, bias_quantized,
       tflite::testing::kOutputShape, tflite::testing::kGoldenData,
-      golden_quantized, output_data, output_scale,
-      &tflite::testing::common_conv_params);
+      golden_quantized, output_scale, &tflite::testing::common_conv_params,
+      tflite::Register_CONV_2D(), output_data);
 }
 
 TF_LITE_MICRO_TEST(InputOutputDifferentTypeIsError) {
@@ -312,9 +140,10 @@ TF_LITE_MICRO_TEST(InputOutputDifferentTypeIsError) {
                             /*zero_point=*/0),
   };
   TF_LITE_MICRO_EXPECT_EQ(
-      kTfLiteError, tflite::testing::InvokeConv(
-                        tensors, tensors_size, output_data, output_dims_count,
-                        &tflite::testing::common_conv_params));
+      kTfLiteError,
+      tflite::testing::InvokeConv(tensors, tensors_size, output_dims_count,
+                                  &tflite::testing::common_conv_params,
+                                  tflite::Register_CONV_2D(), output_data));
 }
 
 TF_LITE_MICRO_TEST(HybridModeIsError) {
@@ -342,9 +171,10 @@ TF_LITE_MICRO_TEST(HybridModeIsError) {
     CreateTensor(output_data, output_dims),
   };
   TF_LITE_MICRO_EXPECT_EQ(
-      kTfLiteError, tflite::testing::InvokeConv(
-                        tensors, tensors_size, output_data, output_dims_count,
-                        &tflite::testing::common_conv_params));
+      kTfLiteError,
+      tflite::testing::InvokeConv(tensors, tensors_size, output_dims_count,
+                                  &tflite::testing::common_conv_params,
+                                  tflite::Register_CONV_2D(), output_data));
 }
 
 TF_LITE_MICRO_TEST(SimpleTestDilatedQuantized) {
@@ -381,7 +211,8 @@ TF_LITE_MICRO_TEST(SimpleTestDilatedQuantized) {
       tflite::testing::kFilterShape, tflite::testing::kFilterData,
       filter_quantized, filter_scale, tflite::testing::kBiasShape,
       tflite::testing::kBiasData, bias_quantized, output_shape, golden_data,
-      golden_quantized, output_data, output_scale, &conv_params);
+      golden_quantized, output_scale, &conv_params, tflite::Register_CONV_2D(),
+      output_data);
 }
 
 TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannel) {
@@ -406,8 +237,9 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannel) {
       tflite::testing::kFilterShape, tflite::testing::kFilterData,
       filter_quantized, tflite::testing::kBiasShape, tflite::testing::kBiasData,
       bias_quantized, scales, zero_points, tflite::testing::kOutputShape,
-      tflite::testing::kGoldenData, golden_quantized, output_data, output_scale,
-      output_zero_point, &tflite::testing::common_conv_params);
+      tflite::testing::kGoldenData, golden_quantized, output_scale,
+      output_zero_point, &tflite::testing::common_conv_params,
+      tflite::Register_CONV_2D(), output_data);
 }
 
 TF_LITE_MICRO_TEST(SimpleTestDilatedQuantizedPerChannel) {
@@ -447,8 +279,8 @@ TF_LITE_MICRO_TEST(SimpleTestDilatedQuantizedPerChannel) {
       tflite::testing::kFilterShape, tflite::testing::kFilterData,
       filter_quantized, tflite::testing::kBiasShape, tflite::testing::kBiasData,
       bias_quantized, scales, zero_points, output_shape, golden_data,
-      golden_quantized, output_data, output_scale, output_zero_point,
-      &conv_params);
+      golden_quantized, output_scale, output_zero_point, &conv_params,
+      tflite::Register_CONV_2D(), output_data);
 }
 
 TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelRelu6) {
@@ -476,8 +308,9 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelRelu6) {
       tflite::testing::kFilterShape, tflite::testing::kFilterData,
       filter_quantized, tflite::testing::kBiasShape, bias_values,
       bias_quantized, scales, zero_points, tflite::testing::kOutputShape,
-      golden_data, golden_quantized, output_data, output_scale,
-      output_zero_point, &tflite::testing::common_conv_params);
+      golden_data, golden_quantized, output_scale, output_zero_point,
+      &tflite::testing::common_conv_params, tflite::Register_CONV_2D(),
+      output_data);
 }
 
 TF_LITE_MICRO_TEST(Kernel1x1QuantizedPerChannel) {
@@ -525,8 +358,8 @@ TF_LITE_MICRO_TEST(Kernel1x1QuantizedPerChannel) {
       input_shape, input_data, input_quantized, input_scale, input_zero_point,
       filter_shape, filter_data, filter_quantized, bias_shape, bias_data,
       bias_quantized, scales, zero_points, output_shape, golden_data,
-      golden_quantized, output_data, output_scale, output_zero_point,
-      &conv_params);
+      golden_quantized, output_scale, output_zero_point, &conv_params,
+      tflite::Register_CONV_2D(), output_data);
 }
 
 TF_LITE_MICRO_TEST(Kernel1x1QuantizedPerChannelRelu6) {
@@ -574,8 +407,8 @@ TF_LITE_MICRO_TEST(Kernel1x1QuantizedPerChannelRelu6) {
       input_shape, input_data, input_quantized, input_scale, input_zero_point,
       filter_shape, filter_data, filter_quantized, bias_shape, bias_data,
       bias_quantized, scales, zero_points, output_shape, golden_data,
-      golden_quantized, output_data, output_scale, output_zero_point,
-      &conv_params);
+      golden_quantized, output_scale, output_zero_point, &conv_params,
+      tflite::Register_CONV_2D(), output_data);
 }
 
 TF_LITE_MICRO_TEST(BroadcastPerLayerQuantizationToPerChannelShouldMatchGolden) {
@@ -660,8 +493,9 @@ TF_LITE_MICRO_TEST(BroadcastPerLayerQuantizationToPerChannelShouldMatchGolden) {
   TF_LITE_MICRO_EXPECT_EQ(
       kTfLiteOk,
       tflite::testing::ValidateConvGoldens(
-          tensors, tensors_size, golden_quantized, output_data,
-          output_dims_count, &tflite::testing::common_conv_params));
+          tensors, tensors_size, golden_quantized, output_dims_count,
+          &tflite::testing::common_conv_params,
+          tflite::Register_CONV_2D(), output_data));
 }
 #endif  // !defined(XTENSA)
 
@@ -735,19 +569,19 @@ TF_LITE_MICRO_TEST(FilterDimsNotMatchingAffineQuantization) {
   // (for broadcast case) nor the quantized dimension size.
   quant->scale->size = 2;
   TF_LITE_MICRO_EXPECT_EQ(
-      kTfLiteError,
-      tflite::testing::ValidateConvGoldens(
-          tensors, tensors_size, golden_quantized, output_data,
-          output_dims_count, &tflite::testing::common_conv_params));
+      kTfLiteError, tflite::testing::ValidateConvGoldens(
+                        tensors, tensors_size, golden_quantized,
+                        output_dims_count, &tflite::testing::common_conv_params,
+                        tflite::Register_CONV_2D(), output_data));
 
   // Set scale back to correct dimension, and make zero point array too short.
   quant->scale->size = tflite::testing::kFilterShape[0];
   quant->zero_point->size = 2;
   TF_LITE_MICRO_EXPECT_EQ(
-      kTfLiteError,
-      tflite::testing::ValidateConvGoldens(
-          tensors, tensors_size, golden_quantized, output_data,
-          output_dims_count, &tflite::testing::common_conv_params));
+      kTfLiteError, tflite::testing::ValidateConvGoldens(
+                        tensors, tensors_size, golden_quantized,
+                        output_dims_count, &tflite::testing::common_conv_params,
+                        tflite::Register_CONV_2D(), output_data));
 }
 
 TF_LITE_MICRO_TEST(Int8Input32x1Filter32x32ShouldMatchGolden) {
@@ -881,8 +715,9 @@ TF_LITE_MICRO_TEST(Int8Input32x1Filter32x32ShouldMatchGolden) {
   TF_LITE_MICRO_EXPECT_EQ(
       kTfLiteOk,
       tflite::testing::ValidateConvGoldens(
-          tensors, kTensorsSize, golden_quantized, output_quantized,
-          output_dims_count, &conv_params, kQuantizationTolerance));
+          tensors, kTensorsSize, golden_quantized, output_dims_count,
+          &conv_params, tflite::Register_CONV_2D(), output_quantized,
+          kQuantizationTolerance));
 }
 
 TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/kernels/conv_test.h b/tensorflow/lite/micro/kernels/conv_test.h
new file mode 100644
index 00000000000..ef04307cdcc
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/conv_test.h
@@ -0,0 +1,94 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/micro/kernels/kernel_runner.h"
+#include "tensorflow/lite/micro/kernels/micro_ops.h"
+#include "tensorflow/lite/micro/test_helpers.h"
+#include "tensorflow/lite/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace testing {
+
+TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size,
+                        int output_length, TfLiteConvParams* conv_params,
+                        TfLiteRegistration registration, float* output_data);
+
+TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size,
+                        int output_length, TfLiteConvParams* conv_params,
+                        TfLiteRegistration registration, int8_t* output_data);
+
+TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size,
+                        int output_length, TfLiteConvParams* conv_params,
+                        TfLiteRegistration registration, uint8_t* output_data);
+
+TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
+                                 const float* expected_output_data,
+                                 int output_length,
+                                 TfLiteConvParams* conv_params,
+                                 TfLiteRegistration registration,
+                                 float* output_data, float tolerance = 1e-5);
+
+TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
+                                 const int8_t* expected_output_data,
+                                 int output_length,
+                                 TfLiteConvParams* conv_params,
+                                 TfLiteRegistration registration,
+                                 int8_t* output_data, float tolerance = 1e-5);
+
+TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
+                                 const uint8_t* expected_output_data,
+                                 int output_length,
+                                 TfLiteConvParams* conv_params,
+                                 TfLiteRegistration registration,
+                                 uint8_t* output_data, float tolerance = 1e-5);
+
+void TestConvFloat(const int* input_dims_data, const float* input_data,
+                   const int* filter_dims_data, const float* filter_data,
+                   const int* bias_dims_data, const float* bias_data,
+                   const int* output_dims_data,
+                   const float* expected_output_data,
+                   TfLiteConvParams* conv_params,
+                   TfLiteRegistration registration, float* output_data);
+
+void TestConvQuantizedPerLayer(
+    const int* input_dims_data, const float* input_data,
+    uint8_t* input_quantized, float input_scale, const int* filter_dims_data,
+    const float* filter_data, uint8_t* filter_quantized, float filter_scale,
+    const int* bias_dims_data, const float* bias_data, int32_t* bias_quantized,
+    const int* output_dims_data, const float* expected_output_data,
+    uint8_t* expected_output_quantized, float output_scale,
+    TfLiteConvParams* conv_params, TfLiteRegistration registration,
+    uint8_t* output_data);
+
+void TestConvQuantizedPerChannel(
+    const int* input_dims_data, const float* input_data,
+    int8_t* input_quantized, float input_scale, int input_zero_point,
+    const int* filter_dims_data, const float* filter_data,
+    int8_t* filter_data_quantized, const int* bias_dims_data,
+    const float* bias_data, int32_t* bias_data_quantized, float* bias_scales,
+    int* bias_zero_points, const int* output_dims_data,
+    const float* expected_output_data, int8_t* expected_output_data_quantized,
+    float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
+    TfLiteRegistration registration, int8_t* output_data);
+
+}  // namespace testing
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
diff --git a/tensorflow/lite/micro/kernels/conv_test_common.cc b/tensorflow/lite/micro/kernels/conv_test_common.cc
new file mode 100644
index 00000000000..9b81e1961dc
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/conv_test_common.cc
@@ -0,0 +1,252 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/kernels/conv_test.h"
+
+namespace tflite {
+namespace testing {
+
+template <typename T>
+TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size,
+                        int output_length, TfLiteConvParams* conv_params,
+                        TfLiteRegistration registration, T* output_data) {
+  int inputs_array_data[] = {3, 0, 1, 2};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+  int outputs_array_data[] = {1, 3};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+
+  micro::KernelRunner runner(
+      registration, tensors, tensors_size, inputs_array, outputs_array,
+      reinterpret_cast<void*>(conv_params), micro_test::reporter);
+
+  const char* init_data = reinterpret_cast<const char*>(conv_params);
+  TfLiteStatus status = runner.InitAndPrepare(init_data);
+  if (status != kTfLiteOk) {
+    return status;
+  }
+  return runner.Invoke();
+}
+
+template <typename T>
+TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
+                                 const T* expected_output_data,
+                                 int output_length,
+                                 TfLiteConvParams* conv_params,
+                                 TfLiteRegistration registration,
+                                 T* output_data, float tolerance) {
+  TfLiteStatus status = InvokeConv(tensors, tensors_size, output_length,
+                                   conv_params, registration, output_data);
+  if (status != kTfLiteOk) {
+    return status;
+  }
+  for (int i = 0; i < output_length; ++i) {
+    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i],
+                              tolerance);
+  }
+  return kTfLiteOk;
+}
+
+TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size,
+                        int output_length, TfLiteConvParams* conv_params,
+                        TfLiteRegistration registration, float* output_data) {
+  return InvokeConv<float>(tensors, tensors_size, output_length, conv_params,
+                           registration, output_data);
+}
+
+TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size,
+                        int output_length, TfLiteConvParams* conv_params,
+                        TfLiteRegistration registration, int8_t* output_data) {
+  return InvokeConv<int8_t>(tensors, tensors_size, output_length, conv_params,
+                            registration, output_data);
+}
+
+TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size,
+                        int output_length, TfLiteConvParams* conv_params,
+                        TfLiteRegistration registration, uint8_t* output_data) {
+  return InvokeConv<uint8_t>(tensors, tensors_size, output_length, conv_params,
+                             registration, output_data);
+}
+
+TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
+                                 const float* expected_output_data,
+                                 int output_length,
+                                 TfLiteConvParams* conv_params,
+                                 TfLiteRegistration registration,
+                                 float* output_data, float tolerance) {
+  return ValidateConvGoldens<float>(tensors, tensors_size,
+                                    expected_output_data, output_length,
+                                    conv_params, registration, output_data,
+                                    tolerance);
+}
+
+TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
+                                 const int8_t* expected_output_data,
+                                 int output_length,
+                                 TfLiteConvParams* conv_params,
+                                 TfLiteRegistration registration,
+                                 int8_t* output_data, float tolerance) {
+  return ValidateConvGoldens<int8_t>(
+      tensors, tensors_size, expected_output_data, output_length, conv_params,
+      registration, output_data, tolerance);
+}
+
+TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
+                                 const uint8_t* expected_output_data,
+                                 int output_length,
+                                 TfLiteConvParams* conv_params,
+                                 TfLiteRegistration registration,
+                                 uint8_t* output_data, float tolerance) {
+  return ValidateConvGoldens<uint8_t>(
+      tensors, tensors_size, expected_output_data, output_length, conv_params,
+      registration, output_data, tolerance);
+}
+
+void TestConvFloat(const int* input_dims_data, const float* input_data,
+                   const int* filter_dims_data, const float* filter_data,
+                   const int* bias_dims_data, const float* bias_data,
+                   const int* output_dims_data,
+                   const float* expected_output_data,
+                   TfLiteConvParams* conv_params,
+                   TfLiteRegistration registration, float* output_data) {
+  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+  TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data);
+  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+  const int output_dims_count = ElementCount(*output_dims);
+  constexpr int inputs_size = 3;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateTensor(input_data, input_dims),
+      CreateTensor(filter_data, filter_dims),
+      CreateTensor(bias_data, bias_dims),
+      CreateTensor(output_data, output_dims),
+  };
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk, ValidateConvGoldens(tensors, tensors_size,
+                                     expected_output_data, output_dims_count,
+                                     conv_params, registration, output_data));
+}
+
+void TestConvQuantizedPerLayer(
+    const int* input_dims_data, const float* input_data,
+    uint8_t* input_quantized, float input_scale, const int* filter_dims_data,
+    const float* filter_data, uint8_t* filter_quantized, float filter_scale,
+    const int* bias_dims_data, const float* bias_data, int32_t* bias_quantized,
+    const int* output_dims_data, const float* expected_output_data,
+    uint8_t* expected_output_quantized, float output_scale,
+    TfLiteConvParams* conv_params, TfLiteRegistration registration,
+    uint8_t* output_data) {
+  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+  TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data);
+  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+  const int output_dims_count = ElementCount(*output_dims);
+
+  tflite::Quantize(expected_output_data, expected_output_quantized,
+                   output_dims_count, output_scale, 128);
+
+  constexpr int inputs_size = 3;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateQuantizedTensor(input_data, input_quantized, input_dims,
+                            input_scale, 128),
+      CreateQuantizedTensor(filter_data, filter_quantized, filter_dims,
+                            filter_scale, 128),
+      CreateQuantizedBiasTensor(bias_data, bias_quantized, bias_dims,
+                                input_scale, filter_scale),
+      CreateQuantizedTensor(output_data, output_dims, output_scale, 128)};
+
+  float filter_scales[] = {1, filter_scale};
+  int filter_zero_points[] = {1, 128};
+  TfLiteAffineQuantization filter_quant = {FloatArrayFromFloats(filter_scales),
+                                           IntArrayFromInts(filter_zero_points),
+                                           0};
+  tensors[1].quantization = {kTfLiteAffineQuantization, &filter_quant};
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk,
+      ValidateConvGoldens(tensors, tensors_size, expected_output_quantized,
+                          output_dims_count, conv_params, registration,
+                          output_data));
+}
+
+void TestConvQuantizedPerChannel(
+    const int* input_dims_data, const float* input_data,
+    int8_t* input_quantized, float input_scale, int input_zero_point,
+    const int* filter_dims_data, const float* filter_data,
+    int8_t* filter_data_quantized, const int* bias_dims_data,
+    const float* bias_data, int32_t* bias_data_quantized, float* bias_scales,
+    int* bias_zero_points, const int* output_dims_data,
+    const float* expected_output_data, int8_t* expected_output_data_quantized,
+    float output_scale, int output_zero_point, TfLiteConvParams* conv_params,
+    TfLiteRegistration registration, int8_t* output_data) {
+  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+  TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data);
+  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+  const int output_dims_count = ElementCount(*output_dims);
+
+  int filter_zero_points[5];
+  float filter_scales[5];
+  TfLiteAffineQuantization filter_quant;
+  TfLiteAffineQuantization bias_quant;
+  TfLiteTensor input_tensor = CreateQuantizedTensor(
+      input_data, input_quantized, input_dims, input_scale, input_zero_point);
+  TfLiteTensor filter_tensor = CreateSymmetricPerChannelQuantizedTensor(
+      filter_data, filter_data_quantized, filter_dims, filter_scales,
+      filter_zero_points, &filter_quant, 0 /* quantized dimension */);
+  TfLiteTensor bias_tensor = CreatePerChannelQuantizedBiasTensor(
+      bias_data, bias_data_quantized, bias_dims, input_scale, &filter_scales[1],
+      bias_scales, bias_zero_points, &bias_quant, 0 /* quantized dimension */);
+  TfLiteTensor output_tensor = CreateQuantizedTensor(
+      output_data, output_dims, output_scale, output_zero_point);
+
+  float input_scales[] = {1, input_scale};
+  int input_zero_points[] = {1, input_zero_point};
+  TfLiteAffineQuantization input_quant = {FloatArrayFromFloats(input_scales),
+                                          IntArrayFromInts(input_zero_points),
+                                          0};
+  input_tensor.quantization = {kTfLiteAffineQuantization, &input_quant};
+
+  float output_scales[] = {1, output_scale};
+  int output_zero_points[] = {1, output_zero_point};
+  TfLiteAffineQuantization output_quant = {FloatArrayFromFloats(output_scales),
+                                           IntArrayFromInts(output_zero_points),
+                                           0};
+  output_tensor.quantization = {kTfLiteAffineQuantization, &output_quant};
+
+  constexpr int inputs_size = 3;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      input_tensor,
+      filter_tensor,
+      bias_tensor,
+      output_tensor,
+  };
+
+  tflite::Quantize(expected_output_data, expected_output_data_quantized,
+                   output_dims_count, output_scale, output_zero_point);
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk,
+      ValidateConvGoldens(tensors, tensors_size, expected_output_data_quantized,
+                          output_dims_count, conv_params, registration,
+                          output_data, 1.0 /* tolerance */));
+}
+
+}  // namespace testing
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/micro_ops.h b/tensorflow/lite/micro/kernels/micro_ops.h
index a65fc4f6a15..f73b1d59183 100644
--- a/tensorflow/lite/micro/kernels/micro_ops.h
+++ b/tensorflow/lite/micro/kernels/micro_ops.h
@@ -37,6 +37,7 @@ TfLiteRegistration Register_QUANTIZE();
 TfLiteRegistration Register_SHAPE();
 TfLiteRegistration Register_SOFTMAX();
 TfLiteRegistration Register_SVDF();
+TfLiteRegistration Register_TRANSPOSE_CONV_2D();
 
 namespace ops {
 namespace micro {
diff --git a/tensorflow/lite/micro/kernels/transpose_conv.cc b/tensorflow/lite/micro/kernels/transpose_conv.cc
new file mode 100644
index 00000000000..9d981921168
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/transpose_conv.cc
@@ -0,0 +1,265 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/internal/reference/transpose_conv.h"
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/padding.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+
+namespace tflite {
+namespace {
+
+constexpr int kInputTensor = 0;
+constexpr int kFilterTensor = 1;
+constexpr int kBiasTensor = 2;
+constexpr int kOutputTensor = 0;
+
+// Conv is quantized along dimension 0:
+// https://www.tensorflow.org/lite/performance/quantization_spec
+constexpr int kConvQuantizedDimension = 0;
+
+struct OpData {
+  ConvParams params;
+
+  // A scratch buffer is required for quantized implementations.
+  int scratch_buffer_index;
+
+  // Multiplier and shift arrays are required for the int8 implementation.
+  int32_t* per_channel_output_multiplier;
+  int32_t* per_channel_output_shift;
+};
+
+inline PaddingType RuntimePaddingType(TfLitePadding padding) {
+  switch (padding) {
+    case TfLitePadding::kTfLitePaddingSame:
+      return PaddingType::kSame;
+    case TfLitePadding::kTfLitePaddingValid:
+      return PaddingType::kValid;
+    case TfLitePadding::kTfLitePaddingUnknown:
+    default:
+      return PaddingType::kNone;
+  }
+}
+
+TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
+                             const TfLiteConvParams* params, int width,
+                             int height, int filter_width, int filter_height,
+                             int out_width, int out_height,
+                             const TfLiteType data_type, OpData* data) {
+  bool has_bias = node->inputs->size == 3;
+  // Check number of inputs/outputs
+  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+  // Matching GetWindowedOutputSize in TensorFlow.
+  auto padding = params->padding;
+  TfLitePaddingValues padding_values = ComputePaddingHeightWidth(
+      params->stride_height, params->stride_width,
+      params->dilation_height_factor, params->dilation_width_factor, height,
+      width, filter_height, filter_width, padding, &out_height, &out_width);
+
+  data->params.padding_type = RuntimePaddingType(padding);
+  data->params.padding_values.width = padding_values.width;
+  data->params.padding_values.height = padding_values.height;
+
+  // Note that quantized inference requires that all tensors have their
+  // parameters set. This is usually done during quantized training.
+  if (data_type != kTfLiteFloat32) {
+    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+    TF_LITE_ENSURE(context, input != nullptr);
+    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+    TF_LITE_ENSURE(context, filter != nullptr);
+    const TfLiteTensor* bias =
+        GetOptionalInputTensor(context, node, kBiasTensor);
+    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+    TF_LITE_ENSURE(context, output != nullptr);
+    int output_channels = filter->dims->data[kConvQuantizedDimension];
+
+    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
+        context, input, filter, bias, output, params->activation,
+        &data->params.output_multiplier, &data->params.output_shift,
+        &data->params.quantized_activation_min,
+        &data->params.quantized_activation_max,
+        data->per_channel_output_multiplier,
+        reinterpret_cast<int*>(data->per_channel_output_shift),
+        output_channels));
+  }
+  return kTfLiteOk;
+}
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  TFLITE_DCHECK(node->builtin_data != nullptr);
+
+  OpData* data = static_cast<OpData*>(node->user_data);
+  const auto params = static_cast<const TfLiteConvParams*>(node->builtin_data);
+
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
+  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+  TF_LITE_ENSURE(context, filter != nullptr);
+
+  int input_width = input->dims->data[2];
+  int input_height = input->dims->data[1];
+  int filter_width = filter->dims->data[2];
+  int filter_height = filter->dims->data[1];
+  int output_width = output->dims->data[2];
+  int output_height = output->dims->data[1];
+
+  // Dynamically allocate per-channel quantization parameters.
+  const int num_channels = filter->dims->data[kConvQuantizedDimension];
+  data->per_channel_output_multiplier =
+      static_cast<int32_t*>(context->AllocatePersistentBuffer(
+          context, num_channels * sizeof(int32_t)));
+  data->per_channel_output_shift =
+      static_cast<int32_t*>(context->AllocatePersistentBuffer(
+          context, num_channels * sizeof(int32_t)));
+
+  // Quantized kernels use an int32 scratch buffer.
+  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
+    TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
+    TFLITE_DCHECK(context->RequestScratchBufferInArena(
+                      context,
+                      GetTensorShape(output).FlatSize() * sizeof(int32_t),
+                      &(data->scratch_buffer_index)) == kTfLiteOk);
+  }
+
+  // All per-channel quantized tensors need valid zero point and scale arrays.
+  if (input->type == kTfLiteInt8) {
+    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
+                      kTfLiteAffineQuantization);
+
+    const auto* affine_quantization =
+        static_cast<TfLiteAffineQuantization*>(filter->quantization.params);
+    TF_LITE_ENSURE(context, affine_quantization);
+    TF_LITE_ENSURE(context, affine_quantization->scale);
+    TF_LITE_ENSURE(context, affine_quantization->zero_point);
+
+    TF_LITE_ENSURE(context,
+                   affine_quantization->scale->size == 1 ||
+                       affine_quantization->scale->size ==
+                           filter->dims->data[kConvQuantizedDimension]);
+    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
+                      affine_quantization->zero_point->size);
+  }
+
+  TF_LITE_ENSURE_STATUS(CalculateOpData(
+      context, node, params, input_width, input_height, filter_width,
+      filter_height, output_width, output_height, input->type, data));
+
+  // Offsets (zero points)
+  data->params.input_offset = -input->params.zero_point;
+  data->params.weights_offset = -filter->params.zero_point;
+  data->params.output_offset = output->params.zero_point;
+
+  // Stride + dilation
+  data->params.stride_width = params->stride_width;
+  data->params.stride_height = params->stride_height;
+  data->params.dilation_width_factor = params->dilation_width_factor;
+  data->params.dilation_height_factor = params->dilation_height_factor;
+
+  float output_activation_min, output_activation_max;
+  CalculateActivationRange(params->activation, &output_activation_min,
+                           &output_activation_max);
+  data->params.float_activation_min = output_activation_min;
+  data->params.float_activation_max = output_activation_max;
+  return kTfLiteOk;
+}  // namespace conv
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteEvalTensor* input =
+      tflite::micro::GetEvalInput(context, node, kInputTensor);
+  const TfLiteEvalTensor* filter =
+      tflite::micro::GetEvalInput(context, node, kFilterTensor);
+  const TfLiteEvalTensor* bias =
+      (NumInputs(node) == 3)
+          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
+          : nullptr;
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+
+  TFLITE_DCHECK(node->user_data != nullptr);
+  const OpData& data = *(static_cast<const OpData*>(node->user_data));
+
+  TF_LITE_ENSURE_EQ(context, input->type, output->type);
+  TF_LITE_ENSURE_MSG(context, input->type == filter->type,
+                     "Hybrid models are not supported on TFLite Micro.");
+
+  switch (input->type) {  // Already know in/out types are same.
+    case kTfLiteFloat32: {
+      reference_ops::TransposeConv(
+          data.params, tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<float>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<float>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<float>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<float>(output),
+          tflite::micro::GetTensorShape(nullptr), nullptr);
+      break;
+    }
+    case kTfLiteInt8: {
+      int32_t* scratch_buffer = static_cast<int32_t*>(
+          context->GetScratchBuffer(context, data.scratch_buffer_index));
+      reference_integer_ops::TransposeConv(
+          data.params, data.per_channel_output_multiplier,
+          data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<int8_t>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int8_t>(output),
+          tflite::micro::GetTensorShape(nullptr), nullptr, scratch_buffer);
+      break;
+    }
+    default:
+      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
+                         TfLiteTypeGetName(input->type), input->type);
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace
+
+TfLiteRegistration Register_TRANSPOSE_CONV_2D() {
+  return {/*init=*/Init,
+          /*free=*/nullptr,
+          /*prepare=*/Prepare,
+          /*invoke=*/Eval,
+          /*profiling_string=*/nullptr,
+          /*builtin_code=*/0,
+          /*custom_name=*/nullptr,
+          /*version=*/0};
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/transpose_conv_test.cc b/tensorflow/lite/micro/kernels/transpose_conv_test.cc
new file mode 100644
index 00000000000..954fdc92ae3
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/transpose_conv_test.cc
@@ -0,0 +1,162 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/micro/kernels/conv_test.h"
+#include "tensorflow/lite/micro/kernels/kernel_runner.h"
+#include "tensorflow/lite/micro/micro_utils.h"
+#include "tensorflow/lite/micro/test_helpers.h"
+#include "tensorflow/lite/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+// Common inputs and outputs.
+constexpr int kInputElements = 32;
+static const int kInputShape[] = {4, 1, 4, 4, 2};
+static const float kInputData[kInputElements] = {
+    1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15, 16,
+    17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32};
+
+constexpr int kFilterElements = 18;
+static const int kFilterShape[] = {4, 1, 3, 3, 2};
+static const float kFilterData[kFilterElements] = {
+    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18};
+
+constexpr int kBiasElements = 1;
+static const int kBiasShape[] = {4, 1, 1, 1, 1};
+static const float kBiasData[kBiasElements] = {0};
+
+constexpr int kOutputElements = 16;
+static const int kOutputShape[] = {4, 1, 4, 4, 1};
+static const float kGoldenData[kOutputElements] = {
+    184,  412,  568,  528,  678,  1347, 1689, 1434,
+    1494, 2715, 3057, 2442, 1968, 3352, 3652, 2760};
+
+// Transpose conv uses TfLiteConvParams.
+static TfLiteConvParams common_conv_params = {kTfLitePaddingSame,  // padding
+                                              1,  // stride_width
+                                              1,  // stride_height
+                                              kTfLiteActNone,
+                                              1,
+                                              1};
+
+}  // namespace
+}  // namespace testing
+}  // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(SimpleTestFloat) {
+  float output_data[tflite::testing::kOutputElements];
+
+  tflite::testing::TestConvFloat(
+      tflite::testing::kInputShape, tflite::testing::kInputData,
+      tflite::testing::kFilterShape, tflite::testing::kFilterData,
+      tflite::testing::kBiasShape, tflite::testing::kBiasData,
+      tflite::testing::kOutputShape, tflite::testing::kGoldenData,
+      &tflite::testing::common_conv_params,
+      tflite::Register_TRANSPOSE_CONV_2D(), output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannel) {
+  int8_t output_data[tflite::testing::kOutputElements];
+
+  const float input_scale = 0.5f;
+  const float output_scale = 1.0f;
+  const int input_zero_point = 0;
+  const int output_zero_point = 0;
+
+  int8_t input_quantized[tflite::testing::kInputElements];
+  int8_t filter_quantized[tflite::testing::kFilterElements];
+  int32_t bias_quantized[tflite::testing::kBiasElements];
+  int8_t golden_quantized[tflite::testing::kOutputElements];
+  int zero_points[tflite::testing::kBiasElements + 1];
+  float scales[tflite::testing::kBiasElements + 1];
+
+  tflite::testing::TestConvQuantizedPerChannel(
+      tflite::testing::kInputShape, tflite::testing::kInputData,
+      input_quantized, input_scale, input_zero_point,
+      tflite::testing::kFilterShape, tflite::testing::kFilterData,
+      filter_quantized, tflite::testing::kBiasShape, tflite::testing::kBiasData,
+      bias_quantized, scales, zero_points, tflite::testing::kOutputShape,
+      tflite::testing::kGoldenData, golden_quantized, output_scale,
+      output_zero_point, &tflite::testing::common_conv_params,
+      tflite::Register_TRANSPOSE_CONV_2D(), output_data);
+}
+
+TF_LITE_MICRO_TEST(InputOutputDifferentTypeIsError) {
+  using tflite::testing::CreateQuantizedTensor;
+  using tflite::testing::CreateTensor;
+  using tflite::testing::IntArrayFromInts;
+
+  TfLiteIntArray* input_dims = IntArrayFromInts(tflite::testing::kInputShape);
+  TfLiteIntArray* filter_dims = IntArrayFromInts(tflite::testing::kFilterShape);
+  TfLiteIntArray* bias_dims = IntArrayFromInts(tflite::testing::kBiasShape);
+  TfLiteIntArray* output_dims = IntArrayFromInts(tflite::testing::kOutputShape);
+  const int output_dims_count = tflite::ElementCount(*output_dims);
+  constexpr int inputs_size = 3;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+
+  int8_t output_data[tflite::testing::kOutputElements];
+  TfLiteTensor tensors[tensors_size] = {
+      CreateTensor(tflite::testing::kInputData, input_dims),
+      CreateTensor(tflite::testing::kFilterData, filter_dims),
+      CreateTensor(tflite::testing::kBiasData, bias_dims),
+      CreateQuantizedTensor(output_data, output_dims, /*scale=*/1.0f,
+                            /*zero_point=*/0),
+  };
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteError, tflite::testing::InvokeConv(
+                        tensors, tensors_size, output_dims_count,
+                        &tflite::testing::common_conv_params,
+                        tflite::Register_TRANSPOSE_CONV_2D(), output_data));
+}
+
+TF_LITE_MICRO_TEST(HybridModeIsError) {
+  using tflite::testing::CreateQuantizedTensor;
+  using tflite::testing::CreateTensor;
+  using tflite::testing::IntArrayFromInts;
+
+  TfLiteIntArray* input_dims = IntArrayFromInts(tflite::testing::kInputShape);
+  TfLiteIntArray* filter_dims = IntArrayFromInts(tflite::testing::kFilterShape);
+  TfLiteIntArray* bias_dims = IntArrayFromInts(tflite::testing::kBiasShape);
+  TfLiteIntArray* output_dims = IntArrayFromInts(tflite::testing::kOutputShape);
+  const int output_dims_count = tflite::ElementCount(*output_dims);
+  constexpr int inputs_size = 3;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+
+  int8_t filter_data[tflite::testing::kFilterElements] = {};
+  float output_data[tflite::testing::kOutputElements];
+  TfLiteTensor tensors[tensors_size] = {
+      CreateTensor(tflite::testing::kInputData, input_dims),
+      CreateQuantizedTensor(filter_data, filter_dims,
+                            /*scale=*/1.0f,
+                            /*zero_point=*/0),
+      CreateTensor(tflite::testing::kBiasData, bias_dims),
+      CreateTensor(output_data, output_dims),
+  };
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteError, tflite::testing::InvokeConv(
+                        tensors, tensors_size, output_dims_count,
+                        &tflite::testing::common_conv_params,
+                        tflite::Register_TRANSPOSE_CONV_2D(), output_data));
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile
index 3e5452649ae..0e46139aa19 100644
--- a/tensorflow/lite/micro/tools/make/Makefile
+++ b/tensorflow/lite/micro/tools/make/Makefile
@@ -305,6 +305,7 @@ tensorflow/lite/micro/kernels/strided_slice_test.cc \
 tensorflow/lite/micro/kernels/sub_test.cc \
 tensorflow/lite/micro/kernels/svdf_test.cc \
 tensorflow/lite/micro/kernels/tanh_test.cc \
+tensorflow/lite/micro/kernels/transpose_conv_test.cc \
 tensorflow/lite/micro/kernels/unpack_test.cc \
 tensorflow/lite/micro/memory_planner/greedy_memory_planner_test.cc \
 tensorflow/lite/micro/memory_planner/linear_memory_planner_test.cc
@@ -318,6 +319,7 @@ tensorflow/lite/micro/kernels/circular_buffer.cc \
 tensorflow/lite/micro/kernels/comparisons.cc \
 tensorflow/lite/micro/kernels/concatenation.cc \
 tensorflow/lite/micro/kernels/conv.cc \
+tensorflow/lite/micro/kernels/conv_test_common.cc \
 tensorflow/lite/micro/kernels/depthwise_conv.cc \
 tensorflow/lite/micro/kernels/dequantize.cc \
 tensorflow/lite/micro/kernels/detection_postprocess.cc \
@@ -354,6 +356,7 @@ tensorflow/lite/micro/kernels/sub.cc \
 tensorflow/lite/micro/kernels/svdf.cc \
 tensorflow/lite/micro/kernels/svdf_common.cc \
 tensorflow/lite/micro/kernels/tanh.cc \
+tensorflow/lite/micro/kernels/transpose_conv.cc \
 tensorflow/lite/micro/kernels/unpack.cc
 
 MICROLITE_TEST_HDRS := \
@@ -418,6 +421,7 @@ tensorflow/lite/kernels/internal/reference/integer_ops/mean.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/mul.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h \
 tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h \
+tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h \
 tensorflow/lite/kernels/internal/reference/l2normalization.h \
 tensorflow/lite/kernels/internal/reference/maximum_minimum.h \
 tensorflow/lite/kernels/internal/reference/mul.h \
@@ -436,6 +440,7 @@ tensorflow/lite/kernels/internal/reference/sub.h \
 tensorflow/lite/kernels/internal/reference/logistic.h \
 tensorflow/lite/kernels/internal/reference/strided_slice.h \
 tensorflow/lite/kernels/internal/reference/tanh.h \
+tensorflow/lite/kernels/internal/reference/transpose_conv.h \
 tensorflow/lite/kernels/internal/cppmath.h \
 tensorflow/lite/kernels/internal/max.h \
 tensorflow/lite/kernels/internal/min.h \
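
A note on the "scatter" trick used by the reference kernels in this patch (illustrative aside, not part of the change): both TransposeConv kernels iterate over the input, multiply each input element against the whole filter, and accumulate the products into every output position that element influences. Because contributions from neighboring input elements overlap in the output, the scratch buffer must start at zero before the main loop. The standalone C++ sketch below shows the same pattern on a 1-D float example; all names and sizes in it (kStride, kPad, and so on) are invented for the illustration.

#include <cstdio>

// 1-D transpose convolution via scatter-accumulate, mirroring the loop
// structure of the reference kernels above. Sizes are hypothetical.
int main() {
  constexpr int kInputSize = 4;
  constexpr int kFilterSize = 3;
  constexpr int kStride = 2;
  constexpr int kPad = 0;
  constexpr int kOutputSize = (kInputSize - 1) * kStride + kFilterSize - kPad;

  const float input[kInputSize] = {1, 2, 3, 4};
  const float filter[kFilterSize] = {1, 0.5f, 0.25f};

  // The scratch buffer must be zero-initialized: overlapping contributions
  // from different input elements are accumulated into the same positions.
  float scratch[kOutputSize] = {};

  for (int in_x = 0; in_x < kInputSize; ++in_x) {
    const int out_x_origin = in_x * kStride - kPad;
    for (int filter_x = 0; filter_x < kFilterSize; ++filter_x) {
      const int out_x = out_x_origin + filter_x;
      // We cannot accumulate out of bounds.
      if (out_x >= 0 && out_x < kOutputSize) {
        scratch[out_x] += input[in_x] * filter[filter_x];
      }
    }
  }

  for (int out_x = 0; out_x < kOutputSize; ++out_x) {
    std::printf("%g ", scratch[out_x]);  // Prints: 1 0.5 2.25 1 3.5 1.5 4.75 2 1
  }
  std::printf("\n");
  return 0;
}

The quantized kernel follows the same structure, but accumulates into the int32_t scratch buffer requested in Prepare() and then requantizes each accumulated value with MultiplyByQuantizedMultiplier (plus bias and clamping) before writing the int8_t output.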