Add int8 and int16x8 support for GATHER_ND operator
This commit is contained in:
parent
38a922c59f
commit
e8d3a48d09
@@ -1068,12 +1068,12 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [
   }];

   let arguments = (ins
-    TFL_TensorOf<[F32, I8, I64, I32, UI8, TFL_Str]>:$params,
+    TFL_TensorOf<[F32, I8, I16, I64, I32, UI8, TFL_Str]>:$params,
     TFL_I32OrI64Tensor:$indices
   );

   let results = (outs
-    TFL_TensorOf<[F32, I8, I64, I32, UI8, TFL_Str]>:$output
+    TFL_TensorOf<[F32, I8, I16, I64, I32, UI8, TFL_Str]>:$output
   );
 }
@@ -45,6 +45,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteFloat32:
     case kTfLiteUInt8:
     case kTfLiteInt8:
+    case kTfLiteInt16:
     case kTfLiteInt64:
     case kTfLiteInt32:
     case kTfLiteString:
@@ -129,6 +130,8 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params,
       return GatherNd<uint8_t, IndicesT>(params, indices, output);
     case kTfLiteInt8:
       return GatherNd<int8_t, IndicesT>(params, indices, output);
+    case kTfLiteInt16:
+      return GatherNd<int16_t, IndicesT>(params, indices, output);
     case kTfLiteInt32:
      return GatherNd<int32_t, IndicesT>(params, indices, output);
     case kTfLiteInt64:
@@ -294,6 +294,28 @@ TEST(GatherNdOpTest, Int8Int64) {
   EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({-2, 2, 2, 3, 3, -3}));
 }

+TEST(GatherNdOpTest, Int16Int32) {
+  GatherNdOpModel m({TensorType_INT16, {3, 2, 3}}, {TensorType_INT32, {2, 2}});
+  m.SetInput<int16_t>({1, -1, 1, -2, 2, 2,   //
+                       3, 3, -3, -4, -4, 4,  //
+                       5, -5, 5, 6, -6, 6});
+  m.SetPositions<int32_t>({0, 1, 1, 0});
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput<int16_t>(), ElementsAreArray({-2, 2, 2, 3, 3, -3}));
+}
+
+TEST(GatherNdOpTest, Int16Int64) {
+  GatherNdOpModel m({TensorType_INT16, {3, 2, 3}}, {TensorType_INT64, {2, 2}});
+  m.SetInput<int16_t>({1, -1, 1, -2, 2, 2,   //
+                       3, 3, -3, -4, -4, 4,  //
+                       5, -5, 5, 6, -6, 6});
+  m.SetPositions<int64_t>({0LL, 1LL, 1LL, 0LL});
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput<int16_t>(), ElementsAreArray({-2, 2, 2, 3, 3, -3}));
+}
+
 TEST(GatherNdOpTest, Int64Int32) {
   GatherNdOpModel m({TensorType_INT64, {3, 2, 3}}, {TensorType_INT32, {2, 2}});
   m.SetInput<int64_t>({1LL, -1LL, 1LL, -2LL, 2LL, 2LL,  //
@@ -281,7 +281,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N());
   AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND(),
              /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_WHERE, Register_WHERE());
   AddBuiltin(BuiltinOperator_ELU, Register_ELU());
   AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
@@ -443,7 +443,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
   AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N());
   AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND(),
              /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_WHERE, Register_WHERE());
   AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
   AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG());
@@ -322,6 +322,7 @@ tf_cc_test(
         "//tensorflow/lite/tools/optimize:testdata/concat.bin",
         "//tensorflow/lite/tools/optimize:testdata/fc.bin",
         "//tensorflow/lite/tools/optimize:testdata/fc_qat.bin",
+        "//tensorflow/lite/tools/optimize:testdata/gather_nd.bin",
        "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin",
        "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated2.bin",
        "//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin",
@@ -208,6 +208,12 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
       property.quantize_input_as_activations = true;
       property.version = 2;
       break;
+    case BuiltinOperator_GATHER_ND:
+      property.inputs = {{0, {}}};
+      property.outputs = {{0, {}}};
+      property.restrict_same_input_output_scale = true;
+      property.version = 3;
+      break;
     case BuiltinOperator_HARD_SWISH: {
       property.inputs = {{0, {}}};
       property.outputs = {{0, {}}};
@@ -1780,6 +1780,71 @@ TEST_P(QuantizeBroadcastToModelTest, VerifyBroadcastToQuantization) {
   EXPECT_EQ(model_.operator_codes[0]->version, 3);
 }

+class QuantizeGatherNDModelTest
+    : public QuantizeModelTest,
+      public testing::WithParamInterface<TensorType> {
+ protected:
+  QuantizeGatherNDModelTest() {
+    tensor_type_ = GetParam();
+    input_model_ = ReadModel(internal::kModelWithGatherNDOp);
+    readonly_model_ = input_model_->GetModel();
+    readonly_model_->UnPackTo(&model_);
+  }
+
+  TensorType tensor_type_;
+};
+
+INSTANTIATE_TEST_SUITE_P(QuantizeGatherNDModelTestInst,
+                         QuantizeGatherNDModelTest,
+                         testing::ValuesIn({TensorType_INT8,
+                                            TensorType_INT16}));
+
+TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) {
+  auto status =
+      QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
+                                false, tensor_type_, &error_reporter_);
+  EXPECT_EQ(status, kTfLiteOk);
+
+  // There is only one subgraph.
+  const int32_t subgraph_idx = 0;
+  const auto& subgraph = model_.subgraphs[subgraph_idx];
+  const auto& readonly_subgraph =
+      readonly_model_->subgraphs()->Get(subgraph_idx);
+
+  // There should be a single gather_nd op.
+  EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
+  EXPECT_EQ(subgraph->operators.size(), 1);
+  const auto& gather_nd = subgraph->operators[0];
+  EXPECT_EQ(model_.operator_codes[gather_nd->opcode_index]->builtin_code,
+            BuiltinOperator_GATHER_ND);
+
+  // There should be 3 tensors: input, output, and indices.
+  EXPECT_EQ(subgraph->tensors.size(), 3);
+
+  // Input Tensor
+  EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
+  EXPECT_EQ(subgraph->tensors[0]->name, "input");
+  EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
+  EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
+
+  // Output Tensor
+  EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
+  EXPECT_EQ(subgraph->tensors[2]->name, "output");
+  EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
+  EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
+
+  // The gather indices are of type INT32 and should not be quantized
+  EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT32);
+  EXPECT_EQ(subgraph->tensors[1]->name, "indices");
+  EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
+  EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);
+
+  // Check op and versioning.
+  EXPECT_EQ(model_.operator_codes.size(), 1);
+  EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_GATHER_ND);
+  EXPECT_EQ(model_.operator_codes[0]->version, 3);
+}
+
 }  // namespace
 }  // namespace optimize
 }  // namespace tflite
@@ -49,6 +49,8 @@ const char* kModelWithArgMaxOp = "argmax.bin";

 const char* kModelWithFCOp = "fc.bin";

+const char* kModelWithGatherNDOp = "gather_nd.bin";
+
 const char* kModelMixed = "mixed.bin";
 const char* kModelMixed16x8 = "mixed16x8.bin";
@@ -75,6 +75,9 @@ extern const char* kModelWithArgMaxOp;
 // Test model with a argmax op.
 extern const char* kModelWithFCOp;

+// Test model with a gather_nd op.
+extern const char* kModelWithGatherNDOp;
+
 // Test model with mixed quantizable and un-quantizable ops.
 // reshape->custom->custom->squeeze.
 extern const char* kModelMixed;
BIN  tensorflow/lite/tools/optimize/testdata/gather_nd.bin  (vendored, new file)
Binary file not shown.
@@ -513,6 +513,10 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       return 1;

     case BuiltinOperator_GATHER_ND:
+      if (!op_sig.input_types.empty() &&
+          (op_sig.input_types.at(0) == TensorType_INT16)) {
+        return 3;
+      }
       if (!op_sig.input_types.empty() &&
           op_sig.input_types.at(0) == TensorType_STRING) {
         return 2;
@@ -682,6 +682,13 @@ TEST(OpVersionTest, VersioningGatherNdOperatorTest) {
           std::vector<TensorType>{TensorType_STRING, TensorType_INT32},
   };
   EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+  fake_op_sig = {
+      .op = BuiltinOperator_GATHER_ND,
+      .input_types =
+          std::vector<TensorType>{TensorType_INT16, TensorType_INT32},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
 }

 TEST(OpVersionTest, VersioningDivTest) {
   OpSignature fake_op_sig = {
@@ -124,6 +124,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
               {{BuiltinOperator_GATHER, 4}, "2.4.0"},
               {{BuiltinOperator_GATHER_ND, 1}, "1.14.0"},
               {{BuiltinOperator_GATHER_ND, 2}, "2.3.0"},
+              {{BuiltinOperator_GATHER_ND, 3}, kPendingReleaseVersion},
               {{BuiltinOperator_HASHTABLE_LOOKUP, 1}, "1.5.0"},
               {{BuiltinOperator_SVDF, 1}, "1.5.0"},
               {{BuiltinOperator_SVDF, 2}, "1.14.0"},
Loading…
x
Reference in New Issue
Block a user