diff --git a/tensorflow/lite/kernels/transpose_conv.cc b/tensorflow/lite/kernels/transpose_conv.cc
index 7a2b1a8dceb..497edac5762 100644
--- a/tensorflow/lite/kernels/transpose_conv.cc
+++ b/tensorflow/lite/kernels/transpose_conv.cc
@@ -80,8 +80,6 @@ struct OpData {
   int output_shift;
 
   // Per channel output multiplier and shift.
-  // TODO(b/144846950): Add channel dimension index for the kernel to be more
-  // flexible.
   std::vector<int32_t> per_channel_output_multiplier;
   std::vector<int32_t> per_channel_output_shift;
 
@@ -374,17 +372,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     const auto* affine_quantization =
         reinterpret_cast<TfLiteAffineQuantization*>(
             weights->quantization.params);
+    const int channels_out = weights->dims->data[0];
     TF_LITE_ENSURE(context, affine_quantization);
     TF_LITE_ENSURE(context, affine_quantization->scale);
-    const int number_channel = affine_quantization->scale->size;
-    data->per_channel_output_multiplier.resize(number_channel);
-    data->per_channel_output_shift.resize(number_channel);
+    TF_LITE_ENSURE(context, (affine_quantization->scale->size == 1 ||
+                             affine_quantization->scale->size == channels_out));
+
+    data->per_channel_output_multiplier.resize(channels_out);
+    data->per_channel_output_shift.resize(channels_out);
     TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
         context, input, weights, bias, output, kTfLiteActNone,
         &data->output_multiplier, &data->output_shift,
         &data->output_activation_min, &data->output_activation_max,
         data->per_channel_output_multiplier.data(),
-        data->per_channel_output_shift.data()));
+        data->per_channel_output_shift.data(), channels_out));
   }
 
   return kTfLiteOk;
diff --git a/tensorflow/lite/kernels/transpose_conv_test.cc b/tensorflow/lite/kernels/transpose_conv_test.cc
index c4d79036ca5..08b516fa8cf 100644
--- a/tensorflow/lite/kernels/transpose_conv_test.cc
+++ b/tensorflow/lite/kernels/transpose_conv_test.cc
@@ -462,6 +462,37 @@ TEST_P(TransposeConvOpTest, TestQuantizedPerChannelMultiChannel) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 5, 5, 2}));
 }
 
+// Test data copied from the float multi-channel test above.
+TEST_P(TransposeConvOpTest, TestQuantizedPerTensorMultiChannel) {
+  const std::initializer_list<float> filter_data = {
+      1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, 8, 10, 12, 14, 16, 18};
+  const std::initializer_list<int8_t> const_filter_data = {
+      7, 21, 35, 49, 64, 78, 92, 106, 120,
+      14, 28, 42, 56, 71, 85, 99, 113, 127};
+  PerChannelQuantizedTransposeConvOpModel model(
+      GetRegistration(), {1, 5, 5, 2},
+      {TensorType_INT8, {2, 3, 3, 1}, 0, 0, 0, 0, true, {18.0 / 127}, {0}, 0},
+      const_filter_data, {TensorType_INT8, {1, 2, 2, 1}, 0, 0, 4.0 / 255, -128},
+      {TensorType_INT8, {}, 0, 0, 1, -128}, Padding_VALID, 2, 2, GetTestType(),
+      /* version */ 2);
+  model.SetInput({1, 2, 3, 4});
+  if (GetTestType() == TestType::kDynamic) {
+    model.SetFilter(filter_data);
+  }
+  model.Invoke();
+
+  EXPECT_THAT(
+      model.GetDequantizedOutput(),
+      ElementsAreArray(ArrayFloatNear(
+          {1,  2,  3,  4,  7,  10, 6,  8,  10, 12, 7,  8,  9,  10, 25, 28, 18,
+           20, 22, 24, 16, 20, 24, 28, 62, 72, 42, 48, 54, 60, 21, 24, 27, 30,
+           61, 68, 36, 40, 44, 48, 39, 42, 45, 48, 103, 110, 60, 64, 68, 72},
+          1e-5)));
+
+  // GetOutputShape() should always be same as model.SetOutputShape(...);
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 5, 5, 2}));
+}
+
 class PerChannelQuantizedTransposeConvOpModel16x8
     : public BaseTransposeConvOpModel<int16_t> {
  public: