From 9be3ad213ba6826f9c067e4140d36d176ac4af85 Mon Sep 17 00:00:00 2001 From: Taehee Jeong Date: Tue, 22 Dec 2020 18:49:04 -0800 Subject: [PATCH] Add INT8 type support for reverse op * Added QI8 type support in tfl_ops.td * Added ReverseV2 op version 3 with int8 type support PiperOrigin-RevId: 348726608 Change-Id: I7003d7eff031e8ac12b55747fa5afaf9e3ab2a52 --- RELEASE.md | 1 + tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 6 ++--- tensorflow/lite/kernels/register.cc | 2 +- tensorflow/lite/kernels/reverse.cc | 8 +++--- tensorflow/lite/kernels/reverse_test.cc | 27 +++++++++++++++++++ .../lite/tools/optimize/operator_property.cc | 6 +++++ .../lite/tools/versioning/op_version.cc | 3 +++ .../lite/tools/versioning/runtime_version.cc | 1 + 8 files changed, 47 insertions(+), 7 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index b1847b7d587..46ae1ff396f 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -71,6 +71,7 @@ * TFLiteConverter exports models with SignatureDef * Interpreter supports getting a list of signatures and getting callable function for a given signaturedef. + * Add int8 support for `ReshapeV2`. * TF Core: * Corrected higher-order gradients of control flow constructs (`tf.cond`, `tf.while_loop`, and compositions like `tf.foldl`) computed with diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index c3930fcf463..161c2e03ebd 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -2720,7 +2720,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [ Args: tensor: A Tensor. Must be one of the following types: - uint8, int16, int32, int64, float32, bool Up to 8-D. + uint8, int8, int16, int32, int64, float32, bool Up to 8-D. axis: A Tensor. Must be one of the following types: int32, int64. with only 1 element which is the axis index. @@ -2729,12 +2729,12 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [ let arguments = ( ins - TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, TFL_Quint8, I1]>:$input, + TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, QI8, TFL_Quint8, I1]>:$input, TFL_I32Tensor:$axis ); let results = (outs - TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, TFL_Quint8, I1]>:$output); + TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, QI8, TFL_Quint8, I1]>:$output); } // Select has many instances in TF models where one or more of its operands diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index a57f358f1ab..1c7964d87f7 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -277,7 +277,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE()); AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2(), /* min_version = */ 1, - /* max_version = */ 2); + /* max_version = */ 3); AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N()); AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND(), /* min_version = */ 1, diff --git a/tensorflow/lite/kernels/reverse.cc b/tensorflow/lite/kernels/reverse.cc index a7ef54dae12..ff701272dd0 100644 --- a/tensorflow/lite/kernels/reverse.cc +++ b/tensorflow/lite/kernels/reverse.cc @@ -43,8 +43,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, NumDimensions(input) >= NumElements(axis)); if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 && - input->type != kTfLiteUInt8 && input->type != kTfLiteInt16 && - input->type != kTfLiteInt64 && input->type != kTfLiteBool) { + input->type != kTfLiteUInt8 && input->type != kTfLiteInt8 && + input->type != kTfLiteInt16 && input->type != kTfLiteInt64 && + input->type != kTfLiteBool) { context->ReportError(context, "Type '%s' is not supported by reverse.", TfLiteTypeGetName(input->type)); return kTfLiteError; @@ -94,7 +95,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTensorShape(output), GetTensorData(output)); break; } - case kTfLiteUInt8: { + case kTfLiteUInt8: + case kTfLiteInt8: { reference_ops::Reverse( axis, GetTensorShape(input), GetTensorData(input), GetTensorShape(output), GetTensorData(output)); diff --git a/tensorflow/lite/kernels/reverse_test.cc b/tensorflow/lite/kernels/reverse_test.cc index f1fcf67fd42..69000f7f9c0 100644 --- a/tensorflow/lite/kernels/reverse_test.cc +++ b/tensorflow/lite/kernels/reverse_test.cc @@ -164,6 +164,33 @@ TEST(ReverseOpTest, Uint8MultiDimensions) { 17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20})); } +// int8 tests +TEST(ReverseOpTest, Int8OneDimension) { + ReverseOpModel model({TensorType_INT8, {4}}, {TensorType_INT32, {1}}); + model.PopulateTensor(model.input(), {1, 2, -1, -2}); + model.PopulateTensor(model.axis(), {0}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({-2, -1, 2, 1})); +} + +TEST(ReverseOpTest, Int8MultiDimensions) { + ReverseOpModel model({TensorType_INT8, {4, 3, 2}}, + {TensorType_INT32, {1}}); + model.PopulateTensor( + model.input(), {-1, -2, -3, -4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, -21, -22, -23, -24}); + model.PopulateTensor(model.axis(), {1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2)); + EXPECT_THAT( + model.GetOutput(), + ElementsAreArray({5, 6, -3, -4, -1, -2, 11, 12, 9, 10, 7, 8, + 17, 18, 15, 16, 13, 14, -23, -24, -21, -22, 19, 20})); +} + // int16 tests TEST(ReverseOpTest, Int16OneDimension) { ReverseOpModel model({TensorType_INT16, {4}}, diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc index d5ef9451659..73a74c9c697 100644 --- a/tensorflow/lite/tools/optimize/operator_property.cc +++ b/tensorflow/lite/tools/optimize/operator_property.cc @@ -882,6 +882,12 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) { property.restrict_same_input_output_scale = true; property.version = 2; break; + case BuiltinOperator_REVERSE_V2: + property.inputs = {{0, {}}}; + property.outputs = {{0, {}}}; + property.restrict_same_input_output_scale = true; + property.version = 3; + break; case BuiltinOperator_SHAPE: property.inputs = {{0, {}}}; // Shape has no quantizable output. diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index c7c08c81c66..8b58f0afbb9 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -402,6 +402,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { } return 1; case BuiltinOperator_REVERSE_V2: + if (op_sig.input_types.at(0) == TensorType_INT8) { + return 3; + } if (op_sig.input_types.at(0) == TensorType_BOOL) { return 2; } diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index 3b418d9f526..46cd22d9f1b 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -332,6 +332,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_FILL, 2}, "2.3.0"}, {{BuiltinOperator_REVERSE_V2, 1}, "1.14.0"}, {{BuiltinOperator_REVERSE_V2, 2}, "2.2.0"}, + {{BuiltinOperator_REVERSE_V2, 3}, kPendingReleaseVersion}, {{BuiltinOperator_RANK, 1}, "1.14.0"}, {{BuiltinOperator_WHILE, 1}, "1.15.0"}, {{BuiltinOperator_CUMSUM, 1}, "2.4.0"},