Add INT8 type support for reverse op

* Added QI8 type support in tfl_ops.td
* Added ReverseV2 op version 3 with int8 type support

PiperOrigin-RevId: 348726608
Change-Id: I7003d7eff031e8ac12b55747fa5afaf9e3ab2a52
This commit is contained in:
Taehee Jeong 2020-12-22 18:49:04 -08:00 committed by TensorFlower Gardener
parent 9c99eca7e6
commit 9be3ad213b
8 changed files with 47 additions and 7 deletions

View File

@ -71,6 +71,7 @@
* TFLiteConverter exports models with SignatureDef
* Interpreter supports getting a list of signatures and getting callable
function for a given signaturedef.
* Add int8 support for `ReshapeV2`.
* TF Core:
* Corrected higher-order gradients of control flow constructs (`tf.cond`,
`tf.while_loop`, and compositions like `tf.foldl`) computed with

View File

@ -2720,7 +2720,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [
Args:
tensor: A Tensor. Must be one of the following types:
uint8, int16, int32, int64, float32, bool Up to 8-D.
uint8, int8, int16, int32, int64, float32, bool Up to 8-D.
axis: A Tensor. Must be one of the following types: int32, int64.
with only 1 element which is the axis index.
@ -2729,12 +2729,12 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [
let arguments = (
ins
TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, TFL_Quint8, I1]>:$input,
TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, QI8, TFL_Quint8, I1]>:$input,
TFL_I32Tensor:$axis
);
let results = (outs
TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, TFL_Quint8, I1]>:$output);
TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, QI8, TFL_Quint8, I1]>:$output);
}
// Select has many instances in TF models where one or more of its operands

View File

@ -277,7 +277,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE());
AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2(),
/* min_version = */ 1,
/* max_version = */ 2);
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N());
AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND(),
/* min_version = */ 1,

View File

@ -43,8 +43,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, NumDimensions(input) >= NumElements(axis));
if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
input->type != kTfLiteUInt8 && input->type != kTfLiteInt16 &&
input->type != kTfLiteInt64 && input->type != kTfLiteBool) {
input->type != kTfLiteUInt8 && input->type != kTfLiteInt8 &&
input->type != kTfLiteInt16 && input->type != kTfLiteInt64 &&
input->type != kTfLiteBool) {
context->ReportError(context, "Type '%s' is not supported by reverse.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
@ -94,7 +95,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTensorShape(output), GetTensorData<float>(output));
break;
}
case kTfLiteUInt8: {
case kTfLiteUInt8:
case kTfLiteInt8: {
reference_ops::Reverse<uint8_t>(
axis, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(output), GetTensorData<uint8_t>(output));

View File

@ -164,6 +164,33 @@ TEST(ReverseOpTest, Uint8MultiDimensions) {
17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
}
// int8 tests
TEST(ReverseOpTest, Int8OneDimension) {
ReverseOpModel<int8_t> model({TensorType_INT8, {4}}, {TensorType_INT32, {1}});
model.PopulateTensor<int8_t>(model.input(), {1, 2, -1, -2});
model.PopulateTensor<int32_t>(model.axis(), {0});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({-2, -1, 2, 1}));
}
TEST(ReverseOpTest, Int8MultiDimensions) {
ReverseOpModel<int8_t> model({TensorType_INT8, {4, 3, 2}},
{TensorType_INT32, {1}});
model.PopulateTensor<int8_t>(
model.input(), {-1, -2, -3, -4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, -21, -22, -23, -24});
model.PopulateTensor<int32_t>(model.axis(), {1});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
EXPECT_THAT(
model.GetOutput(),
ElementsAreArray({5, 6, -3, -4, -1, -2, 11, 12, 9, 10, 7, 8,
17, 18, 15, 16, 13, 14, -23, -24, -21, -22, 19, 20}));
}
// int16 tests
TEST(ReverseOpTest, Int16OneDimension) {
ReverseOpModel<int16_t> model({TensorType_INT16, {4}},

View File

@ -882,6 +882,12 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
property.restrict_same_input_output_scale = true;
property.version = 2;
break;
case BuiltinOperator_REVERSE_V2:
property.inputs = {{0, {}}};
property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true;
property.version = 3;
break;
case BuiltinOperator_SHAPE:
property.inputs = {{0, {}}};
// Shape has no quantizable output.

View File

@ -402,6 +402,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
}
return 1;
case BuiltinOperator_REVERSE_V2:
if (op_sig.input_types.at(0) == TensorType_INT8) {
return 3;
}
if (op_sig.input_types.at(0) == TensorType_BOOL) {
return 2;
}

View File

@ -332,6 +332,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_FILL, 2}, "2.3.0"},
{{BuiltinOperator_REVERSE_V2, 1}, "1.14.0"},
{{BuiltinOperator_REVERSE_V2, 2}, "2.2.0"},
{{BuiltinOperator_REVERSE_V2, 3}, kPendingReleaseVersion},
{{BuiltinOperator_RANK, 1}, "1.14.0"},
{{BuiltinOperator_WHILE, 1}, "1.15.0"},
{{BuiltinOperator_CUMSUM, 1}, "2.4.0"},