Add INT8 type support for reverse op
* Added QI8 type support in tfl_ops.td * Added ReverseV2 op version 3 with int8 type support PiperOrigin-RevId: 348726608 Change-Id: I7003d7eff031e8ac12b55747fa5afaf9e3ab2a52
This commit is contained in:
parent
9c99eca7e6
commit
9be3ad213b
@ -71,6 +71,7 @@
|
|||||||
* TFLiteConverter exports models with SignatureDef
|
* TFLiteConverter exports models with SignatureDef
|
||||||
* Interpreter supports getting a list of signatures and getting callable
|
* Interpreter supports getting a list of signatures and getting callable
|
||||||
function for a given signaturedef.
|
function for a given signaturedef.
|
||||||
|
* Add int8 support for `ReshapeV2`.
|
||||||
* TF Core:
|
* TF Core:
|
||||||
* Corrected higher-order gradients of control flow constructs (`tf.cond`,
|
* Corrected higher-order gradients of control flow constructs (`tf.cond`,
|
||||||
`tf.while_loop`, and compositions like `tf.foldl`) computed with
|
`tf.while_loop`, and compositions like `tf.foldl`) computed with
|
||||||
|
@ -2720,7 +2720,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor: A Tensor. Must be one of the following types:
|
tensor: A Tensor. Must be one of the following types:
|
||||||
uint8, int16, int32, int64, float32, bool Up to 8-D.
|
uint8, int8, int16, int32, int64, float32, bool Up to 8-D.
|
||||||
|
|
||||||
axis: A Tensor. Must be one of the following types: int32, int64.
|
axis: A Tensor. Must be one of the following types: int32, int64.
|
||||||
with only 1 element which is the axis index.
|
with only 1 element which is the axis index.
|
||||||
@ -2729,12 +2729,12 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [
|
|||||||
|
|
||||||
let arguments = (
|
let arguments = (
|
||||||
ins
|
ins
|
||||||
TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, TFL_Quint8, I1]>:$input,
|
TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, QI8, TFL_Quint8, I1]>:$input,
|
||||||
TFL_I32Tensor:$axis
|
TFL_I32Tensor:$axis
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, TFL_Quint8, I1]>:$output);
|
TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, QI8, TFL_Quint8, I1]>:$output);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select has many instances in TF models where one or more of its operands
|
// Select has many instances in TF models where one or more of its operands
|
||||||
|
@ -277,7 +277,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE());
|
AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE());
|
||||||
AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2(),
|
AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2(),
|
||||||
/* min_version = */ 1,
|
/* min_version = */ 1,
|
||||||
/* max_version = */ 2);
|
/* max_version = */ 3);
|
||||||
AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N());
|
AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N());
|
||||||
AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND(),
|
AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND(),
|
||||||
/* min_version = */ 1,
|
/* min_version = */ 1,
|
||||||
|
@ -43,8 +43,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE(context, NumDimensions(input) >= NumElements(axis));
|
TF_LITE_ENSURE(context, NumDimensions(input) >= NumElements(axis));
|
||||||
|
|
||||||
if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
|
if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
|
||||||
input->type != kTfLiteUInt8 && input->type != kTfLiteInt16 &&
|
input->type != kTfLiteUInt8 && input->type != kTfLiteInt8 &&
|
||||||
input->type != kTfLiteInt64 && input->type != kTfLiteBool) {
|
input->type != kTfLiteInt16 && input->type != kTfLiteInt64 &&
|
||||||
|
input->type != kTfLiteBool) {
|
||||||
context->ReportError(context, "Type '%s' is not supported by reverse.",
|
context->ReportError(context, "Type '%s' is not supported by reverse.",
|
||||||
TfLiteTypeGetName(input->type));
|
TfLiteTypeGetName(input->type));
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
@ -94,7 +95,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
GetTensorShape(output), GetTensorData<float>(output));
|
GetTensorShape(output), GetTensorData<float>(output));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case kTfLiteUInt8: {
|
case kTfLiteUInt8:
|
||||||
|
case kTfLiteInt8: {
|
||||||
reference_ops::Reverse<uint8_t>(
|
reference_ops::Reverse<uint8_t>(
|
||||||
axis, GetTensorShape(input), GetTensorData<uint8_t>(input),
|
axis, GetTensorShape(input), GetTensorData<uint8_t>(input),
|
||||||
GetTensorShape(output), GetTensorData<uint8_t>(output));
|
GetTensorShape(output), GetTensorData<uint8_t>(output));
|
||||||
|
@ -164,6 +164,33 @@ TEST(ReverseOpTest, Uint8MultiDimensions) {
|
|||||||
17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
|
17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// int8 tests
|
||||||
|
TEST(ReverseOpTest, Int8OneDimension) {
|
||||||
|
ReverseOpModel<int8_t> model({TensorType_INT8, {4}}, {TensorType_INT32, {1}});
|
||||||
|
model.PopulateTensor<int8_t>(model.input(), {1, 2, -1, -2});
|
||||||
|
model.PopulateTensor<int32_t>(model.axis(), {0});
|
||||||
|
model.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
|
||||||
|
EXPECT_THAT(model.GetOutput(), ElementsAreArray({-2, -1, 2, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ReverseOpTest, Int8MultiDimensions) {
|
||||||
|
ReverseOpModel<int8_t> model({TensorType_INT8, {4, 3, 2}},
|
||||||
|
{TensorType_INT32, {1}});
|
||||||
|
model.PopulateTensor<int8_t>(
|
||||||
|
model.input(), {-1, -2, -3, -4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||||
|
13, 14, 15, 16, 17, 18, 19, 20, -21, -22, -23, -24});
|
||||||
|
model.PopulateTensor<int32_t>(model.axis(), {1});
|
||||||
|
model.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
|
||||||
|
EXPECT_THAT(
|
||||||
|
model.GetOutput(),
|
||||||
|
ElementsAreArray({5, 6, -3, -4, -1, -2, 11, 12, 9, 10, 7, 8,
|
||||||
|
17, 18, 15, 16, 13, 14, -23, -24, -21, -22, 19, 20}));
|
||||||
|
}
|
||||||
|
|
||||||
// int16 tests
|
// int16 tests
|
||||||
TEST(ReverseOpTest, Int16OneDimension) {
|
TEST(ReverseOpTest, Int16OneDimension) {
|
||||||
ReverseOpModel<int16_t> model({TensorType_INT16, {4}},
|
ReverseOpModel<int16_t> model({TensorType_INT16, {4}},
|
||||||
|
@ -882,6 +882,12 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
|
|||||||
property.restrict_same_input_output_scale = true;
|
property.restrict_same_input_output_scale = true;
|
||||||
property.version = 2;
|
property.version = 2;
|
||||||
break;
|
break;
|
||||||
|
case BuiltinOperator_REVERSE_V2:
|
||||||
|
property.inputs = {{0, {}}};
|
||||||
|
property.outputs = {{0, {}}};
|
||||||
|
property.restrict_same_input_output_scale = true;
|
||||||
|
property.version = 3;
|
||||||
|
break;
|
||||||
case BuiltinOperator_SHAPE:
|
case BuiltinOperator_SHAPE:
|
||||||
property.inputs = {{0, {}}};
|
property.inputs = {{0, {}}};
|
||||||
// Shape has no quantizable output.
|
// Shape has no quantizable output.
|
||||||
|
@ -402,6 +402,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
|||||||
}
|
}
|
||||||
return 1;
|
return 1;
|
||||||
case BuiltinOperator_REVERSE_V2:
|
case BuiltinOperator_REVERSE_V2:
|
||||||
|
if (op_sig.input_types.at(0) == TensorType_INT8) {
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
if (op_sig.input_types.at(0) == TensorType_BOOL) {
|
if (op_sig.input_types.at(0) == TensorType_BOOL) {
|
||||||
return 2;
|
return 2;
|
||||||
}
|
}
|
||||||
|
@ -332,6 +332,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
|
|||||||
{{BuiltinOperator_FILL, 2}, "2.3.0"},
|
{{BuiltinOperator_FILL, 2}, "2.3.0"},
|
||||||
{{BuiltinOperator_REVERSE_V2, 1}, "1.14.0"},
|
{{BuiltinOperator_REVERSE_V2, 1}, "1.14.0"},
|
||||||
{{BuiltinOperator_REVERSE_V2, 2}, "2.2.0"},
|
{{BuiltinOperator_REVERSE_V2, 2}, "2.2.0"},
|
||||||
|
{{BuiltinOperator_REVERSE_V2, 3}, kPendingReleaseVersion},
|
||||||
{{BuiltinOperator_RANK, 1}, "1.14.0"},
|
{{BuiltinOperator_RANK, 1}, "1.14.0"},
|
||||||
{{BuiltinOperator_WHILE, 1}, "1.15.0"},
|
{{BuiltinOperator_WHILE, 1}, "1.15.0"},
|
||||||
{{BuiltinOperator_CUMSUM, 1}, "2.4.0"},
|
{{BuiltinOperator_CUMSUM, 1}, "2.4.0"},
|
||||||
|
Loading…
x
Reference in New Issue
Block a user