diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index a028ab10580..2d4abf00e97 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -4049,6 +4050,93 @@ inline void TransposeConv( } } +inline void TransposeConv(const ConvParams& params, + const RuntimeShape& input_shape, + const uint8* input_data, + const RuntimeShape& filter_shape, + const uint8* filter_data, + const RuntimeShape& output_shape, uint8* output_data, + const RuntimeShape& im2col_shape, uint8* im2col_data, + int32* scratch_buffer) { + const int stride_width = params.stride_width; + const int stride_height = params.stride_height; + const int pad_width = params.padding_values.width; + const int pad_height = params.padding_values.height; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + (void)im2col_data; // only used in optimized code. + (void)im2col_shape; // only used in optimized code. + + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); + const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int filter_height = filter_shape.Dims(1); + const int filter_width = filter_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int32 input_offset = params.input_offset; + const int32 filter_offset = params.weights_offset; + const int32 output_offset = params.output_offset; + const int32 output_multiplier = params.output_multiplier; + const int output_shift = params.output_shift; + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + + const int num_elements = output_shape.FlatSize(); + // We need to initialize scratch_buffer to all 0s, as we apply the same + // 'scatter' based trick as in float version. + memset(scratch_buffer, 0, num_elements * sizeof(int32)); + + // Loop through input elements one at a time. + for (int batch = 0; batch < batches; ++batch) { + for (int in_y = 0; in_y < input_height; ++in_y) { + for (int in_x = 0; in_x < input_width; ++in_x) { + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { + // Loop through the output elements it will influence. + const int out_x_origin = (in_x * stride_width) - pad_width; + const int out_y_origin = (in_y * stride_height) - pad_height; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + for (int out_channel = 0; out_channel < output_depth; + ++out_channel) { + // Compute output element location. + const int out_x = out_x_origin + filter_x; + const int out_y = out_y_origin + filter_y; + // We cannot accumulate out of bounds. + if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) && + (out_y < output_height)) { + uint8 input_value = input_data[Offset( + input_shape, batch, in_y, in_x, in_channel)]; + uint8 filter_value = + filter_data[Offset(filter_shape, out_channel, filter_y, + filter_x, in_channel)]; + scratch_buffer[Offset(output_shape, batch, out_y, out_x, + out_channel)] += + (input_value + input_offset) * + (filter_value + filter_offset); + } + } + } + } + } + } + } + } + for (int i = 0; i < num_elements; ++i) { + int32 acc = scratch_buffer[i]; + acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift); + acc += output_offset; + // Clamp the output before converting back to uint8. + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[i] = static_cast(acc); + } +} + template inline bool EqualFn(T lhs, T rhs) { return lhs == rhs; diff --git a/tensorflow/lite/kernels/kernel_util.cc b/tensorflow/lite/kernels/kernel_util.cc index deb484b70f0..7ff61ac51ad 100644 --- a/tensorflow/lite/kernels/kernel_util.cc +++ b/tensorflow/lite/kernels/kernel_util.cc @@ -111,15 +111,22 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context, double* multiplier) { const double input_product_scale = input->params.scale * filter->params.scale; const double bias_scale = bias->params.scale; - const double output_scale = output->params.scale; - // TODO(ahentz): The following conditions must be guaranteed by the training // pipeline. TF_LITE_ENSURE(context, std::abs(input_product_scale - bias_scale) <= 1e-6 * std::min(input_product_scale, bias_scale)); - TF_LITE_ENSURE(context, input_product_scale >= 0); + return GetQuantizedConvolutionMultipler(context, input, filter, output, + multiplier); +} - *multiplier = input_product_scale / output_scale; +TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context, + const TfLiteTensor* input, + const TfLiteTensor* filter, + TfLiteTensor* output, + double* multiplier) { + const double input_product_scale = input->params.scale * filter->params.scale; + TF_LITE_ENSURE(context, input_product_scale >= 0); + *multiplier = input_product_scale / output->params.scale; return kTfLiteOk; } diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h index 94c9842b474..24a3438f88a 100644 --- a/tensorflow/lite/kernels/kernel_util.h +++ b/tensorflow/lite/kernels/kernel_util.h @@ -113,6 +113,12 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context, TfLiteTensor* output, double* multiplier); +TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context, + const TfLiteTensor* input, + const TfLiteTensor* filter, + TfLiteTensor* output, + double* multiplier); + // Calculates the useful quantized range of an activation layer given its // activation tensor. TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context, diff --git a/tensorflow/lite/kernels/transpose_conv.cc b/tensorflow/lite/kernels/transpose_conv.cc index 343f2ca59ba..af6df74091e 100644 --- a/tensorflow/lite/kernels/transpose_conv.cc +++ b/tensorflow/lite/kernels/transpose_conv.cc @@ -54,6 +54,19 @@ struct OpData { // im2col is the only temporary currently tracked, therefore always index 0. // If more temporaries are added, they should be properly tracked. int32_t im2col_index = 0; + + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; + + int scratch_tensor_index; }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -61,6 +74,9 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { // Instead, we allocate a new object to use as scratch space for im2col, and // to carry information from Prepare() to Eval(). auto* data = new OpData; + // Populate scratch_tensor_index. + context->AddTensors(context, /*tensors_to_add=*/1, + &data->scratch_tensor_index); eigen_support::IncrementUsageCounter(context); return data; } @@ -70,9 +86,9 @@ void Free(TfLiteContext* context, void* buffer) { delete reinterpret_cast(buffer); } -TfLiteStatus ResizeOutputTensor(TfLiteContext* context, - const TfLiteTensor* shape_tensor, - TfLiteTensor* output) { +TfLiteStatus ResizeTensor(TfLiteContext* context, + const TfLiteTensor* shape_tensor, + TfLiteTensor* tensor_to_resize) { // Currently only support int32 for output shape. if (shape_tensor->type != kTfLiteInt32) { context->ReportError(context, "Output shape is %d, not int32.", @@ -85,7 +101,7 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, shape->data[i] = GetTensorData(shape_tensor)[i]; } - return context->ResizeTensor(context, output, shape); + return context->ResizeTensor(context, tensor_to_resize, shape); } static TfLiteStatus AllocateIm2colTensorIfRequired(TfLiteContext* context, @@ -129,6 +145,8 @@ TfLiteStatus ResizeIm2ColTensor(TfLiteContext* context, } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + // Sanity checks on op TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -150,9 +168,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 4); - TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); - TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32); - TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + TF_LITE_ENSURE(context, + input->type == kTfLiteFloat32 || input->type == kTfLiteUInt8); + TF_LITE_ENSURE_EQ(context, weights->type, input->type); + TF_LITE_ENSURE_EQ(context, output->type, input->type); // Ensure that weights and inputs have the same channel dimension. // Note: TOCO will reorder weights in the following format: OHWI. TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3), @@ -163,13 +182,103 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { SetTensorToDynamic(output); SetTensorToDynamic(im2col); } else { - TF_LITE_ENSURE_STATUS(ResizeOutputTensor(context, output_shape, output)); + TF_LITE_ENSURE_STATUS(ResizeTensor(context, output_shape, output)); TF_LITE_ENSURE_STATUS( ResizeIm2ColTensor(context, output_shape, weights, input, im2col)); } + + if (input->type == kTfLiteUInt8) { + // Set up a scratch buffer tensor. + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(1); + node->temporaries->data[0] = data->scratch_tensor_index; + TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); + scratch_buffer->type = kTfLiteInt32; + scratch_buffer->allocation_type = kTfLiteArenaRw; + if (!IsConstantTensor(output_shape)) { + SetTensorToDynamic(scratch_buffer); + } else { + TF_LITE_ENSURE_STATUS( + ResizeTensor(context, output_shape, scratch_buffer)); + } + + // Calcuate output multiplier for quantization. + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, weights, output, &real_multiplier)); + int exponent; + // Populate quantization parameteters with multiplier and shift. + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = -exponent; + // Populate max and min activation range. + CalculateActivationRangeUint8(kTfLiteActNone, output, + &data->output_activation_min, + &data->output_activation_max); + } return kTfLiteOk; } +template +void EvalFloat(const TfLiteTransposeConvParams* params, const OpData* data, + const TfLiteTensor* input, const TfLiteTensor* weights, + TfLiteTensor* im2col, TfLiteTensor* output) { + tflite::ConvParams op_params; + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + switch (kernel_type) { + case kReference: { + reference_ops::TransposeConv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(weights), GetTensorData(weights), + GetTensorShape(output), GetTensorData(output), + GetTensorShape(im2col), GetTensorData(im2col)); + break; + } + case kGenericOptimized: { + optimized_ops::TransposeConv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(weights), GetTensorData(weights), + GetTensorShape(output), GetTensorData(output), + GetTensorShape(im2col), GetTensorData(im2col)); + break; + } + } +} + +void EvalQuantized(const TfLiteTransposeConvParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* weights, + TfLiteTensor* im2col, TfLiteTensor* output, + TfLiteTensor* scratch_buffer) { + int32_t input_offset = -input->params.zero_point; + int32_t filter_offset = -weights->params.zero_point; + int32_t output_offset = output->params.zero_point; + + tflite::ConvParams op_params; + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.input_offset = input_offset; + op_params.output_offset = output_offset; + op_params.weights_offset = filter_offset; + op_params.output_multiplier = data->output_multiplier; + op_params.output_shift = -data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + + // TODO(haoliang): Add optimized implementation later. + reference_ops::TransposeConv( + op_params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(weights), GetTensorData(weights), + GetTensorShape(output), GetTensorData(output), + GetTensorShape(im2col), GetTensorData(im2col), + GetTensorData(scratch_buffer)); +} + template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Retrieve tensors (All should be allocated by now) @@ -178,16 +287,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* input = GetInput(context, node, kDataInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - OpData* user_data = reinterpret_cast(node->user_data); + OpData* data = reinterpret_cast(node->user_data); TfLiteTensor* im2col = - &context->tensors[node->temporaries->data[user_data->im2col_index]]; + &context->tensors[node->temporaries->data[data->im2col_index]]; const auto* params = reinterpret_cast(node->builtin_data); // Resize any deferred dynamic tensors if (IsDynamicTensor(output)) { - TF_LITE_ENSURE_OK(context, - ResizeOutputTensor(context, output_shape, output)); + TF_LITE_ENSURE_OK(context, ResizeTensor(context, output_shape, output)); } if (IsDynamicTensor(im2col)) { TF_LITE_ENSURE_OK(context, ResizeIm2ColTensor(context, output_shape, @@ -200,45 +308,31 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const int filter_width = SizeOfDimension(weights, 2); const int filter_height = SizeOfDimension(weights, 1); - const int stride_width = params->stride_width; - const int stride_height = params->stride_height; + data->padding = ComputePaddingHeightWidth( + params->stride_height, params->stride_width, 1, height, width, + filter_height, filter_width, params->padding); - const TfLitePaddingValues& padding_size = - ComputePaddingHeightWidth(stride_height, stride_width, 1, height, width, - filter_height, filter_width, params->padding); - - // Currently only support float32. + // Currently support float32 and uint8. switch (input->type) { case kTfLiteFloat32: { - tflite::ConvParams op_params; - op_params.padding_type = PaddingType::kSame; - op_params.padding_values.width = padding_size.width; - op_params.padding_values.height = padding_size.height; - op_params.stride_width = stride_width; - op_params.stride_height = stride_height; - switch (kernel_type) { - case kReference: { - reference_ops::TransposeConv( - op_params, GetTensorShape(input), GetTensorData(input), - GetTensorShape(weights), GetTensorData(weights), - GetTensorShape(output), GetTensorData(output), - GetTensorShape(im2col), GetTensorData(im2col)); - break; - } - case kGenericOptimized: { - optimized_ops::TransposeConv( - op_params, GetTensorShape(input), GetTensorData(input), - GetTensorShape(weights), GetTensorData(weights), - GetTensorShape(output), GetTensorData(output), - GetTensorShape(im2col), GetTensorData(im2col)); - break; - } + EvalFloat(params, data, input, weights, im2col, output); + break; + } + case kTfLiteUInt8: { + // TODO(haoliang): support optimized implementation for quantized + // TransposeConv. + TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index*/ 0); + if (IsDynamicTensor(scratch_buffer)) { + TF_LITE_ENSURE_OK(context, + ResizeTensor(context, output_shape, scratch_buffer)); } + EvalQuantized(params, data, input, weights, im2col, output, + scratch_buffer); break; } default: - context->ReportError(context, "Type %d, not currently supported.", - input->type); + context->ReportError(context, "Type '%s' is not currently supported.", + TfLiteTypeGetName(input->type)); return kTfLiteError; } return kTfLiteOk; diff --git a/tensorflow/lite/kernels/transpose_conv_test.cc b/tensorflow/lite/kernels/transpose_conv_test.cc index 05d3451d005..516bae87a1d 100644 --- a/tensorflow/lite/kernels/transpose_conv_test.cc +++ b/tensorflow/lite/kernels/transpose_conv_test.cc @@ -35,12 +35,12 @@ namespace { using ::testing::ElementsAreArray; -class TransposeConvOpModel : public SingleOpModel { +class BaseTransposeConvOpModel : public SingleOpModel { public: - TransposeConvOpModel(TfLiteRegistration* registration, - const TensorData& filter, const TensorData& input, - const TensorData& output, Padding padding, int stride_w, - int stride_h) { + BaseTransposeConvOpModel(TfLiteRegistration* registration, + const TensorData& filter, const TensorData& input, + const TensorData& output, Padding padding, + int stride_w, int stride_h) { // Just to be confusing, transpose_conv has an _input_ named "output_shape" // that sets the shape of the output tensor of the op :). It must always be // an int32 1D four element tensor. @@ -63,18 +63,25 @@ class TransposeConvOpModel : public SingleOpModel { void SetOutputShape(std::initializer_list i) { PopulateTensor(output_shape_, i); } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int output_shape_; + int filter_; + int input_; + int output_; +}; + +class TransposeConvOpModel : public BaseTransposeConvOpModel { + public: + using BaseTransposeConvOpModel::BaseTransposeConvOpModel; + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetOutputShape() { return GetTensorShape(output_); } - - private: - int output_shape_; - int filter_; - int input_; - int output_; }; const auto kKernelMap = new std::map({ @@ -97,19 +104,20 @@ class TransposeConvOpTest : public SingleOpTest { // [1, 1, 1, 1 ], // "SAME") TEST_P(TransposeConvOpTest, SimpleTest) { - TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 3, 1}}, - {TensorType_FLOAT32, {1, 4, 4, 1}}, - {TensorType_FLOAT32, {}}, Padding_SAME, 1, 1); - m.SetOutputShape({1, 4, 4, 1}); - m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9}); - m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - m.Invoke(); + TransposeConvOpModel model(GetRegistration(), + {TensorType_FLOAT32, {1, 3, 3, 1}}, + {TensorType_FLOAT32, {1, 4, 4, 1}}, + {TensorType_FLOAT32, {}}, Padding_SAME, 1, 1); + model.SetOutputShape({1, 4, 4, 1}); + model.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9}); + model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + model.Invoke(); - EXPECT_THAT(m.GetOutput(), + EXPECT_THAT(model.GetOutput(), ElementsAreArray({29, 62, 83, 75, 99, 192, 237, 198, 207, 372, 417, 330, 263, 446, 485, 365})); - // GetOutputShape() should always be same as m.SetOutputShape(...); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); + // GetOutputShape() should always be same as model.SetOutputShape(...); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); } // Test case: @@ -125,19 +133,22 @@ TEST_P(TransposeConvOpTest, SimpleTest) { // And filter value is derived by: // filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[18, 1]) TEST_P(TransposeConvOpTest, TwoFiltersTest) { - TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 3, 2}}, - {TensorType_FLOAT32, {1, 4, 4, 2}}, - {TensorType_FLOAT32, {}}, Padding_SAME, 1, 1); - m.SetOutputShape({1, 4, 4, 1}); - m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); - m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}); - m.Invoke(); + TransposeConvOpModel model(GetRegistration(), + {TensorType_FLOAT32, {1, 3, 3, 2}}, + {TensorType_FLOAT32, {1, 4, 4, 2}}, + {TensorType_FLOAT32, {}}, Padding_SAME, 1, 1); + model.SetOutputShape({1, 4, 4, 1}); + model.SetFilter( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); + model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}); + model.Invoke(); - EXPECT_THAT(m.GetOutput(), + EXPECT_THAT(model.GetOutput(), ElementsAreArray({184, 412, 568, 528, 678, 1347, 1689, 1434, 1494, 2715, 3057, 2442, 1968, 3352, 3652, 2760})); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); } // Test case: @@ -153,22 +164,25 @@ TEST_P(TransposeConvOpTest, TwoFiltersTest) { // And filter value is derived by: // filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[1, 18]) TEST_P(TransposeConvOpTest, PaddingValidTest) { - TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 3, 2}}, - {TensorType_FLOAT32, {1, 4, 4, 2}}, - {TensorType_FLOAT32, {}}, Padding_VALID, 1, 1); - m.SetOutputShape({1, 6, 6, 1}); - m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); - m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}); - m.Invoke(); + TransposeConvOpModel model(GetRegistration(), + {TensorType_FLOAT32, {1, 3, 3, 2}}, + {TensorType_FLOAT32, {1, 4, 4, 2}}, + {TensorType_FLOAT32, {}}, Padding_VALID, 1, 1); + model.SetOutputShape({1, 6, 6, 1}); + model.SetFilter( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); + model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}); + model.Invoke(); EXPECT_THAT( - m.GetOutput(), + model.GetOutput(), ElementsAreArray({5, 22, 59, 101, 114, 83, 52, 184, 412, 568, 528, 344, 237, 678, 1347, 1689, 1434, 879, 597, 1494, 2715, 3057, 2442, 1431, 856, 1968, 3352, 3652, 2760, 1548, 689, 1534, 2543, 2729, 2010, 1103})); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 6, 6, 1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 6, 6, 1})); } // Test case: @@ -182,19 +196,20 @@ TEST_P(TransposeConvOpTest, PaddingValidTest) { // [1, 2, 2, 1 ], // "VALID") TEST_P(TransposeConvOpTest, StrideValidTest) { - TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 3, 1}}, - {TensorType_FLOAT32, {1, 2, 2, 1}}, - {TensorType_FLOAT32, {}}, Padding_VALID, 2, 2); - m.SetOutputShape({1, 5, 5, 1}); - m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9}); - m.SetInput({1, 2, 3, 4}); - m.Invoke(); + TransposeConvOpModel model(GetRegistration(), + {TensorType_FLOAT32, {1, 3, 3, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, Padding_VALID, 2, 2); + model.SetOutputShape({1, 5, 5, 1}); + model.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9}); + model.SetInput({1, 2, 3, 4}); + model.Invoke(); EXPECT_THAT( - m.GetOutput(), + model.GetOutput(), ElementsAreArray({1, 2, 5, 4, 6, 4, 5, 14, 10, 12, 10, 14, 36, 24, 30, 12, 15, 34, 20, 24, 21, 24, 55, 32, 36})); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 5, 5, 1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 5, 5, 1})); } // Test case: @@ -208,21 +223,23 @@ TEST_P(TransposeConvOpTest, StrideValidTest) { // [1, 2, 2, 1 ], // "VALID") TEST_P(TransposeConvOpTest, MultiChannelTest) { - TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 3, 3, 1}}, - {TensorType_FLOAT32, {1, 2, 2, 1}}, - {TensorType_FLOAT32, {}}, Padding_VALID, 2, 2); - m.SetOutputShape({1, 5, 5, 2}); - m.SetFilter({1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, 8, 10, 12, 14, 16, 18}); - m.SetInput({1, 2, 3, 4}); - m.Invoke(); + TransposeConvOpModel model(GetRegistration(), + {TensorType_FLOAT32, {2, 3, 3, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, Padding_VALID, 2, 2); + model.SetOutputShape({1, 5, 5, 2}); + model.SetFilter( + {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, 8, 10, 12, 14, 16, 18}); + model.SetInput({1, 2, 3, 4}); + model.Invoke(); EXPECT_THAT( - m.GetOutput(), + model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 7, 10, 6, 8, 10, 12, 7, 8, 9, 10, 25, 28, 18, 20, 22, 24, 16, 20, 24, 28, 62, 72, 42, 48, 54, 60, 21, 24, 27, 30, 61, 68, 36, 40, 44, 48, 39, 42, 45, 48, 103, 110, 60, 64, 68, 72})); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 5, 5, 2})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 5, 5, 2})); } // Test case: @@ -238,18 +255,100 @@ TEST_P(TransposeConvOpTest, MultiChannelTest) { // And filter value is derived by: // filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[-1]) TEST_P(TransposeConvOpTest, AccuracyTest) { - TransposeConvOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 3, 1}}, - {TensorType_FLOAT32, {1, 1, 2, 1}}, - {TensorType_FLOAT32, {}}, Padding_SAME, 3, 3); - m.SetOutputShape({1, 3, 4, 1}); - m.SetFilter({9, 5, 6, 9, 8, 5, 3, 1, 4}); - m.SetInput({323, 521}); - m.Invoke(); + TransposeConvOpModel model(GetRegistration(), + {TensorType_FLOAT32, {1, 3, 3, 1}}, + {TensorType_FLOAT32, {1, 1, 2, 1}}, + {TensorType_FLOAT32, {}}, Padding_SAME, 3, 3); + model.SetOutputShape({1, 3, 4, 1}); + model.SetFilter({9, 5, 6, 9, 8, 5, 3, 1, 4}); + model.SetInput({323, 521}); + model.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( - {1615., 1938., 4689., 2605., 2584., 1615., + EXPECT_THAT(model.GetOutput(), + ElementsAreArray( + ArrayFloatNear({1615., 1938., 4689., 2605., 2584., 1615., 4689., 4168., 323., 1292., 1563., 521.}))); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 4, 1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 3, 4, 1})); +} + +class QuantizedTransposeConvOpModel : public BaseTransposeConvOpModel { + public: + using BaseTransposeConvOpModel::BaseTransposeConvOpModel; + + void SetFilter(std::initializer_list f) { + QuantizeAndPopulate(filter_, f); + } + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST_P(TransposeConvOpTest, SimpleTestQuantized) { + QuantizedTransposeConvOpModel model( + GetRegistration(), {TensorType_UINT8, {1, 3, 3, 1}, -63.5, 64}, + {TensorType_UINT8, {1, 4, 4, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -508, 512}, Padding_SAME, 1, 1); + model.SetOutputShape({1, 4, 4, 1}); + model.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9}); + model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + model.Invoke(); + + EXPECT_THAT( + model.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({28, 64, 84, 76, 100, 192, 236, 200, 208, + 372, 416, 332, 264, 448, 484, 364}, + 1e-5))); + + // GetOutputShape() should always be same as model.SetOutputShape(...); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST_P(TransposeConvOpTest, TwoFiltersTestQuantized) { + QuantizedTransposeConvOpModel model( + GetRegistration(), {TensorType_UINT8, {1, 3, 3, 2}, -63.5, 64}, + {TensorType_UINT8, {1, 4, 4, 2}, -63.5, 64}, + {TensorType_UINT8, {}, -4064, 4096}, Padding_SAME, 1, 1); + model.SetOutputShape({1, 4, 4, 1}); + model.SetFilter( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); + model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}); + model.Invoke(); + + EXPECT_THAT(model.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + {192, 416, 576, 544, 672, 1344, 1696, 1440, 1504, 2720, 3072, + 2432, 1984, 3360, 3648, 2752}, + 1e-5))); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST_P(TransposeConvOpTest, PaddingValidTestQuantized) { + QuantizedTransposeConvOpModel model( + GetRegistration(), {TensorType_UINT8, {1, 3, 3, 2}, -63.5, 64}, + {TensorType_UINT8, {1, 4, 4, 2}, -63.5, 64}, + {TensorType_UINT8, {}, -4064, 4096}, Padding_VALID, 1, 1); + model.SetOutputShape({1, 6, 6, 1}); + model.SetFilter( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); + model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}); + model.Invoke(); + + EXPECT_THAT(model.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + {0, 32, 64, 96, 128, 96, 64, 192, 416, + 576, 544, 352, 224, 672, 1344, 1696, 1440, 864, + 608, 1504, 2720, 3072, 2432, 1440, 864, 1984, 3360, + 3648, 2752, 1536, 704, 1536, 2528, 2720, 2016, 1088}, + 1e-5))); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 6, 6, 1})); } INSTANTIATE_TEST_SUITE_P(