diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD index 3011c01cdeb..c10d4465e5c 100644 --- a/tensorflow/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -296,6 +296,7 @@ tf_cc_test( "//tensorflow/lite/tools/optimize:testdata/split.bin", "//tensorflow/lite/tools/optimize:testdata/svdf_calibrated.bin", "//tensorflow/lite/tools/optimize:testdata/svdf_quantized.bin", + "//tensorflow/lite/tools/optimize:testdata/transpose.bin", "//tensorflow/lite/tools/optimize:testdata/unpack.bin", ], tags = [ diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc index f8f1a9d4113..36b35af0065 100644 --- a/tensorflow/lite/tools/optimize/quantize_model_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -1454,6 +1454,50 @@ TEST_F(QuantizeUnpackTest, VerifyUnpack) { unpack_output_1->quantization->zero_point[0]); } +class QuantizeTransposeTest : public QuantizeModelTest { + protected: + QuantizeTransposeTest() { + input_model_ = ReadModel(internal::kModelWithTranspose); + readonly_model_ = input_model_->GetModel(); + readonly_model_->UnPackTo(&model_); + } +}; + +TEST_F(QuantizeTransposeTest, VerifyTranspose) { + auto status = QuantizeModel(&builder_, &model_, &error_reporter_); + + ASSERT_EQ(kTfLiteOk, status); + + const auto subgraph = model_.subgraphs[0].get(); + auto op = subgraph->operators[1].get(); + + auto float_graph = readonly_model_->subgraphs()->Get(0); + + ASSERT_EQ(model_.operator_codes[op->opcode_index].get()->builtin_code, + BuiltinOperator_TRANSPOSE); + + // The model should only have one input and one outputs. + EXPECT_EQ(subgraph->inputs.size(), 1); + EXPECT_EQ(subgraph->outputs.size(), 1); + + // Get transpose input and output tensors + auto transpose_input = subgraph->tensors[op->inputs[0]].get(); + auto transpose_output = subgraph->tensors[op->outputs[0]].get(); + + // Verify transpose input is quantized. + ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(), + TensorType_FLOAT32); + EXPECT_EQ(transpose_input->type, TensorType_INT8); + + // Ensure quantization parameters before and after transpose + // are preserved after quantization for all outputs of + // transpose. + EXPECT_FLOAT_EQ(transpose_input->quantization->scale[0], + transpose_output->quantization->scale[0]); + EXPECT_EQ(transpose_input->quantization->zero_point[0], + transpose_output->quantization->zero_point[0]); +} + } // namespace } // namespace optimize } // namespace tflite diff --git a/tensorflow/lite/tools/optimize/test_util.cc b/tensorflow/lite/tools/optimize/test_util.cc index 7d5e9d65f06..61e82ed3e34 100644 --- a/tensorflow/lite/tools/optimize/test_util.cc +++ b/tensorflow/lite/tools/optimize/test_util.cc @@ -61,6 +61,8 @@ const char* kModelWithMaximumOp = "maximum.bin"; const char* kLstmCalibrated2 = "lstm_calibrated2.bin"; const char* kLstmQuantized2 = "lstm_quantized2.bin"; +const char* kModelWithTranspose = "transpose.bin"; + const char* kSvdfCalibrated = "svdf_calibrated.bin"; const char* kSvdfQuantized = "svdf_quantized.bin"; diff --git a/tensorflow/lite/tools/optimize/test_util.h b/tensorflow/lite/tools/optimize/test_util.h index abcdbc21d36..4d2eadf283f 100644 --- a/tensorflow/lite/tools/optimize/test_util.h +++ b/tensorflow/lite/tools/optimize/test_util.h @@ -98,6 +98,9 @@ extern const char* kModelWithMaximumOp; extern const char* kLstmCalibrated2; extern const char* kLstmQuantized2; +// Test model with a transpose op. +extern const char* kModelWithTranspose; + // Test model with SVDF op. extern const char* kSvdfCalibrated; extern const char* kSvdfQuantized; diff --git a/tensorflow/lite/tools/optimize/testdata/transpose.bin b/tensorflow/lite/tools/optimize/testdata/transpose.bin new file mode 100644 index 00000000000..a76886e5b47 Binary files /dev/null and b/tensorflow/lite/tools/optimize/testdata/transpose.bin differ