Add INT8 type support for reverse op
* Added QI8 type support in tfl_ops.td
* Added ReverseV2 op version 3 with int8 type support

PiperOrigin-RevId: 348726608
Change-Id: I7003d7eff031e8ac12b55747fa5afaf9e3ab2a52
parent 9c99eca7e6
commit 9be3ad213b
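For orientation, here is a rough sketch of what the change enables, written against the ReverseOpModel test helper that appears in the test diff below; the test name, shape, and values are made up for illustration and are not part of this change:

// Illustrative only: an int8 tensor can now be reversed; before this change
// an int8 input would have been rejected by the type check in Prepare().
TEST(ReverseOpTest, Int8IllustrativeExample) {
  ReverseOpModel<int8_t> model({TensorType_INT8, {3}}, {TensorType_INT32, {1}});
  model.PopulateTensor<int8_t>(model.input(), {10, -20, 30});
  model.PopulateTensor<int32_t>(model.axis(), {0});
  model.Invoke();

  EXPECT_THAT(model.GetOutput(), ElementsAreArray({30, -20, 10}));
}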
@@ -71,6 +71,7 @@
     * TFLiteConverter exports models with SignatureDef
     * Interpreter supports getting a list of signatures and getting callable
       function for a given signaturedef.
+    * Add int8 support for `ReverseV2`.
   * TF Core:
     * Corrected higher-order gradients of control flow constructs (`tf.cond`,
       `tf.while_loop`, and compositions like `tf.foldl`) computed with
@@ -2720,7 +2720,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [
     Args:
       tensor: A Tensor. Must be one of the following types:
-        uint8, int16, int32, int64, float32, bool Up to 8-D.
+        uint8, int8, int16, int32, int64, float32, bool Up to 8-D.

       axis: A Tensor. Must be one of the following types: int32, int64.
         with only 1 element which is the axis index.
@@ -2729,12 +2729,12 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [
   let arguments = (
     ins
-    TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, TFL_Quint8, I1]>:$input,
+    TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, QI8, TFL_Quint8, I1]>:$input,
     TFL_I32Tensor:$axis
   );

   let results = (outs
-      TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, TFL_Quint8, I1]>:$output);
+      TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, QI8, TFL_Quint8, I1]>:$output);
 }

 // Select has many instances in TF models where one or more of its operands
@@ -277,7 +277,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE());
   AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2(),
              /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N());
   AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND(),
              /* min_version = */ 1,
@@ -43,8 +43,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE(context, NumDimensions(input) >= NumElements(axis));

   if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
-      input->type != kTfLiteUInt8 && input->type != kTfLiteInt16 &&
-      input->type != kTfLiteInt64 && input->type != kTfLiteBool) {
+      input->type != kTfLiteUInt8 && input->type != kTfLiteInt8 &&
+      input->type != kTfLiteInt16 && input->type != kTfLiteInt64 &&
+      input->type != kTfLiteBool) {
     context->ReportError(context, "Type '%s' is not supported by reverse.",
                          TfLiteTypeGetName(input->type));
     return kTfLiteError;
@@ -94,7 +95,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           GetTensorShape(output), GetTensorData<float>(output));
       break;
     }
-    case kTfLiteUInt8: {
+    case kTfLiteUInt8:
+    case kTfLiteInt8: {
       reference_ops::Reverse<uint8_t>(
           axis, GetTensorShape(input), GetTensorData<uint8_t>(input),
           GetTensorShape(output), GetTensorData<uint8_t>(output));
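The int8 case can fall through to the existing uint8_t instantiation because Reverse only permutes elements and never interprets their values, so an 8-bit payload is copied bit-for-bit regardless of signedness. A minimal standalone sketch of that argument (Reverse1D is a simplified stand-in for reference_ops::Reverse, not the actual implementation):

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Simplified stand-in for a reverse kernel on a flat 1-D buffer: it only
// moves elements, it never looks at their values.
template <typename T>
void Reverse1D(const T* in, T* out, int n) {
  for (int i = 0; i < n; ++i) out[i] = in[n - 1 - i];
}

int main() {
  std::vector<int8_t> input = {1, 2, -1, -2};
  const int n = static_cast<int>(input.size());

  // Path A: reverse directly as int8_t.
  std::vector<int8_t> direct(input.size());
  Reverse1D(input.data(), direct.data(), n);

  // Path B: view the same buffer as uint8_t (roughly what the kernel does by
  // reading the int8 tensor through GetTensorData<uint8_t>) and reverse bytes.
  std::vector<int8_t> via_uint8(input.size());
  Reverse1D(reinterpret_cast<const uint8_t*>(input.data()),
            reinterpret_cast<uint8_t*>(via_uint8.data()), n);

  // Both paths yield {-2, -1, 2, 1}; the 8-bit payload is untouched.
  assert(std::memcmp(direct.data(), via_uint8.data(), input.size()) == 0);
  return 0;
}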
@@ -164,6 +164,33 @@ TEST(ReverseOpTest, Uint8MultiDimensions) {
                         17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
 }

+// int8 tests
+TEST(ReverseOpTest, Int8OneDimension) {
+  ReverseOpModel<int8_t> model({TensorType_INT8, {4}}, {TensorType_INT32, {1}});
+  model.PopulateTensor<int8_t>(model.input(), {1, 2, -1, -2});
+  model.PopulateTensor<int32_t>(model.axis(), {0});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4));
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({-2, -1, 2, 1}));
+}
+
+TEST(ReverseOpTest, Int8MultiDimensions) {
+  ReverseOpModel<int8_t> model({TensorType_INT8, {4, 3, 2}},
+                               {TensorType_INT32, {1}});
+  model.PopulateTensor<int8_t>(
+      model.input(), {-1, -2, -3, -4, 5, 6, 7, 8, 9, 10, 11, 12,
+                      13, 14, 15, 16, 17, 18, 19, 20, -21, -22, -23, -24});
+  model.PopulateTensor<int32_t>(model.axis(), {1});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
+  EXPECT_THAT(
+      model.GetOutput(),
+      ElementsAreArray({5, 6, -3, -4, -1, -2, 11, 12, 9, 10, 7, 8,
+                        17, 18, 15, 16, 13, 14, -23, -24, -21, -22, 19, 20}));
+}
+
 // int16 tests
 TEST(ReverseOpTest, Int16OneDimension) {
   ReverseOpModel<int16_t> model({TensorType_INT16, {4}},
@@ -882,6 +882,12 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
       property.restrict_same_input_output_scale = true;
       property.version = 2;
       break;
+    case BuiltinOperator_REVERSE_V2:
+      property.inputs = {{0, {}}};
+      property.outputs = {{0, {}}};
+      property.restrict_same_input_output_scale = true;
+      property.version = 3;
+      break;
     case BuiltinOperator_SHAPE:
       property.inputs = {{0, {}}};
       // Shape has no quantizable output.
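The quantizer entry sets restrict_same_input_output_scale because ReverseV2 copies the stored int8 values verbatim; if the output used different quantization parameters, the dequantized values would silently change. A minimal sketch of that invariant, using a struct that mirrors the scale/zero_point fields of TfLiteQuantizationParams (SameScaleInvariantHolds is a made-up name, not a TFLite API):

#include <cassert>
#include <cstdint>

// Mirrors the fields of TfLiteQuantizationParams: real = scale * (q - zero_point).
struct QuantParams {
  float scale;
  int32_t zero_point;
};

// ReverseV2 copies the int8 payload q verbatim, so the dequantized value is
// preserved only when input and output share the same quantization parameters.
bool SameScaleInvariantHolds(const QuantParams& in, const QuantParams& out) {
  return in.scale == out.scale && in.zero_point == out.zero_point;
}

int main() {
  QuantParams in{0.5f, -3};
  assert(SameScaleInvariantHolds(in, in));           // what the quantizer enforces
  assert(!SameScaleInvariantHolds(in, {0.25f, 0}));  // would silently rescale data
  return 0;
}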
@@ -402,6 +402,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       }
       return 1;
     case BuiltinOperator_REVERSE_V2:
+      if (op_sig.input_types.at(0) == TensorType_INT8) {
+        return 3;
+      }
       if (op_sig.input_types.at(0) == TensorType_BOOL) {
         return 2;
       }
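Read top to bottom, this branch stamps the op with the lowest version that covers the input type: int8 → version 3 (this change), bool → version 2, everything else → version 1. A standalone mirror of that mapping, just to make the scheme explicit (the enum and GetReverseV2Version are made up for illustration, not TFLite APIs):

#include <cassert>

// Made-up stand-in for tflite::TensorType; only the cases that matter for
// ReverseV2 versioning are listed.
enum class InputType { kFloat32, kInt32, kBool, kInt8 };

// Mirror of the GetBuiltinOperatorVersion branch above: newest capability
// checked first, so int8 models get version 3, bool gets 2, the rest get 1.
int GetReverseV2Version(InputType input_type) {
  if (input_type == InputType::kInt8) return 3;
  if (input_type == InputType::kBool) return 2;
  return 1;
}

int main() {
  assert(GetReverseV2Version(InputType::kInt8) == 3);
  assert(GetReverseV2Version(InputType::kBool) == 2);
  assert(GetReverseV2Version(InputType::kFloat32) == 1);
  return 0;
}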
@@ -332,6 +332,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_FILL, 2}, "2.3.0"},
           {{BuiltinOperator_REVERSE_V2, 1}, "1.14.0"},
           {{BuiltinOperator_REVERSE_V2, 2}, "2.2.0"},
+          {{BuiltinOperator_REVERSE_V2, 3}, kPendingReleaseVersion},
           {{BuiltinOperator_RANK, 1}, "1.14.0"},
           {{BuiltinOperator_WHILE, 1}, "1.15.0"},
           {{BuiltinOperator_CUMSUM, 1}, "2.4.0"},