Add int8 and int16x8 support for GATHER_ND operator

Mohamed Nour Abouelseoud 2020-12-24 14:53:42 +00:00
parent 38a922c59f
commit e8d3a48d09
14 changed files with 118 additions and 4 deletions

@@ -1068,12 +1068,12 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [
}];
let arguments = (ins
TFL_TensorOf<[F32, I8, I64, I32, UI8, TFL_Str]>:$params,
TFL_TensorOf<[F32, I8, I16, I64, I32, UI8, TFL_Str]>:$params,
TFL_I32OrI64Tensor:$indices
);
let results = (outs
TFL_TensorOf<[F32, I8, I64, I32, UI8, TFL_Str]>:$output
TFL_TensorOf<[F32, I8, I16, I64, I32, UI8, TFL_Str]>:$output
);
}

@@ -45,6 +45,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteFloat32:
case kTfLiteUInt8:
case kTfLiteInt8:
case kTfLiteInt16:
case kTfLiteInt64:
case kTfLiteInt32:
case kTfLiteString:
@@ -129,6 +130,8 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params,
return GatherNd<uint8_t, IndicesT>(params, indices, output);
case kTfLiteInt8:
return GatherNd<int8_t, IndicesT>(params, indices, output);
case kTfLiteInt16:
return GatherNd<int16_t, IndicesT>(params, indices, output);
case kTfLiteInt32:
return GatherNd<int32_t, IndicesT>(params, indices, output);
case kTfLiteInt64:
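The dispatch above only gains new cases; the gathering itself is templated on the params type and is type-agnostic, which is why int16 support is this cheap. As a rough, self-contained sketch of what GatherNd<ParamsT, IndicesT> computes (a simplification for illustration, not the actual TFLite reference kernel):

#include <cstring>
#include <vector>

// Each row of `indices` (index_depth values) addresses one slice of `params`;
// the selected slices are concatenated into `output`. Shapes are row-major.
template <typename ParamsT, typename IndicesT>
void GatherNdSketch(const std::vector<int>& params_shape, const ParamsT* params,
                    int num_index_rows, int index_depth,
                    const IndicesT* indices, ParamsT* output) {
  // Elements in one gathered slice: product of the trailing, un-indexed dims.
  int slice_size = 1;
  for (int d = index_depth; d < static_cast<int>(params_shape.size()); ++d) {
    slice_size *= params_shape[d];
  }
  // Strides (in elements) of the indexed dimensions.
  std::vector<int> strides(index_depth, slice_size);
  for (int d = index_depth - 2; d >= 0; --d) {
    strides[d] = strides[d + 1] * params_shape[d + 1];
  }
  for (int row = 0; row < num_index_rows; ++row) {
    int offset = 0;
    for (int d = 0; d < index_depth; ++d) {
      offset += static_cast<int>(indices[row * index_depth + d]) * strides[d];
    }
    std::memcpy(output + row * slice_size, params + offset,
                slice_size * sizeof(ParamsT));
  }
}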

@@ -294,6 +294,28 @@ TEST(GatherNdOpTest, Int8Int64) {
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({-2, 2, 2, 3, 3, -3}));
}
TEST(GatherNdOpTest, Int16Int32) {
GatherNdOpModel m({TensorType_INT16, {3, 2, 3}}, {TensorType_INT32, {2, 2}});
m.SetInput<int16_t>({1, -1, 1, -2, 2, 2, //
3, 3, -3, -4, -4, 4, //
5, -5, 5, 6, -6, 6});
m.SetPositions<int32_t>({0, 1, 1, 0});
m.Invoke();
EXPECT_THAT(m.GetOutput<int16_t>(), ElementsAreArray({-2, 2, 2, 3, 3, -3}));
}
TEST(GatherNdOpTest, Int16Int64) {
GatherNdOpModel m({TensorType_INT16, {3, 2, 3}}, {TensorType_INT64, {2, 2}});
m.SetInput<int16_t>({1, -1, 1, -2, 2, 2, //
3, 3, -3, -4, -4, 4, //
5, -5, 5, 6, -6, 6});
m.SetPositions<int64_t>({0LL, 1LL, 1LL, 0LL});
m.Invoke();
EXPECT_THAT(m.GetOutput<int16_t>(), ElementsAreArray({-2, 2, 2, 3, 3, -3}));
}
TEST(GatherNdOpTest, Int64Int32) {
GatherNdOpModel m({TensorType_INT64, {3, 2, 3}}, {TensorType_INT32, {2, 2}});
m.SetInput<int64_t>({1LL, -1LL, 1LL, -2LL, 2LL, 2LL, //
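For reference, with params of shape {3, 2, 3} the index rows {0, 1} and {1, 0} address params[0][1] = {-2, 2, 2} and params[1][0] = {3, 3, -3}, so every variant of this test expects {-2, 2, 2, 3, 3, -3} regardless of the params or indices element type.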

@@ -281,7 +281,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N());
AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND(),
/* min_version = */ 1,
/* max_version = */ 2);
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_WHERE, Register_WHERE());
AddBuiltin(BuiltinOperator_ELU, Register_ELU());
AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
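For builds that assemble a trimmed-down resolver instead of using BuiltinOpResolver, the wider version range has to be requested explicitly. A sketch of that, assuming the usual MutableOpResolver and builtin kernel headers (the MakeGatherNdResolver wrapper is illustrative, not part of this commit):

#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h"

// Resolver covering GATHER_ND versions 1-3 so the int8/int16 kernels resolve.
tflite::MutableOpResolver MakeGatherNdResolver() {
  tflite::MutableOpResolver resolver;
  resolver.AddBuiltin(tflite::BuiltinOperator_GATHER_ND,
                      tflite::ops::builtin::Register_GATHER_ND(),
                      /* min_version = */ 1,
                      /* max_version = */ 3);
  return resolver;
}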

@@ -443,7 +443,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N());
AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND(),
/* min_version = */ 1,
/* max_version = */ 2);
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_WHERE, Register_WHERE());
AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG());

@@ -322,6 +322,7 @@ tf_cc_test(
"//tensorflow/lite/tools/optimize:testdata/concat.bin",
"//tensorflow/lite/tools/optimize:testdata/fc.bin",
"//tensorflow/lite/tools/optimize:testdata/fc_qat.bin",
"//tensorflow/lite/tools/optimize:testdata/gather_nd.bin",
"//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin",
"//tensorflow/lite/tools/optimize:testdata/lstm_calibrated2.bin",
"//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin",

@@ -208,6 +208,12 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
property.quantize_input_as_activations = true;
property.version = 2;
break;
case BuiltinOperator_GATHER_ND:
property.inputs = {{0, {}}};
property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true;
property.version = 3;
break;
case BuiltinOperator_HARD_SWISH: {
property.inputs = {{0, {}}};
property.outputs = {{0, {}}};
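The restrict_same_input_output_scale flag records that gather_nd merely copies values, so the quantizer must give the output the same scale and zero point as the input instead of requantizing it. A hypothetical check of that invariant on two per-tensor-quantized tensors (HasSameQuantization is illustrative, not part of this change):

#include "tensorflow/lite/c/common.h"

// Illustrative only: verifies that two tensors share per-tensor quantization
// parameters, the invariant restrict_same_input_output_scale expresses for
// gather_nd during quantization.
bool HasSameQuantization(const TfLiteTensor& a, const TfLiteTensor& b) {
  const auto* qa = reinterpret_cast<const TfLiteAffineQuantization*>(
      a.quantization.params);
  const auto* qb = reinterpret_cast<const TfLiteAffineQuantization*>(
      b.quantization.params);
  if (qa == nullptr || qb == nullptr) return qa == qb;
  return qa->scale->data[0] == qb->scale->data[0] &&
         qa->zero_point->data[0] == qb->zero_point->data[0];
}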

@@ -1780,6 +1780,71 @@ TEST_P(QuantizeBroadcastToModelTest, VerifyBroadcastToQuantization) {
EXPECT_EQ(model_.operator_codes[0]->version, 3);
}
class QuantizeGatherNDModelTest
: public QuantizeModelTest,
public testing::WithParamInterface<TensorType> {
protected:
QuantizeGatherNDModelTest() {
tensor_type_ = GetParam();
input_model_ = ReadModel(internal::kModelWithGatherNDOp);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
}
TensorType tensor_type_;
};
INSTANTIATE_TEST_SUITE_P(QuantizeGatherNDModelTestInst,
QuantizeGatherNDModelTest,
testing::ValuesIn({TensorType_INT8,
TensorType_INT16}));
TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) {
auto status =
QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
false, tensor_type_, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk);
// There is only one subgraph.
const int32_t subgraph_idx = 0;
const auto& subgraph = model_.subgraphs[subgraph_idx];
const auto& readonly_subgraph =
readonly_model_->subgraphs()->Get(subgraph_idx);
// There should be a single gather_nd op.
EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
EXPECT_EQ(subgraph->operators.size(), 1);
const auto& gather_nd = subgraph->operators[0];
EXPECT_EQ(model_.operator_codes[gather_nd->opcode_index]->builtin_code,
BuiltinOperator_GATHER_ND);
// There should be 3 tensors: input, output, and indices.
EXPECT_EQ(subgraph->tensors.size(), 3);
// Input Tensor
EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
EXPECT_EQ(subgraph->tensors[0]->name, "input");
EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
// Output Tensor
EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
EXPECT_EQ(subgraph->tensors[2]->name, "output");
EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
// The gather indices are of type INT32 and should not be quantized
EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT32);
EXPECT_EQ(subgraph->tensors[1]->name, "indices");
EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);
// Check op and versioning.
EXPECT_EQ(model_.operator_codes.size(), 1);
EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_GATHER_ND);
EXPECT_EQ(model_.operator_codes[0]->version, 3);
}
} // namespace
} // namespace optimize
} // namespace tflite
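The TensorType_INT16 parameterization exercises the 16x8 quantization mode referenced in the commit title (int16 activations with int8 weights), while TensorType_INT8 covers ordinary full-integer quantization.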

@@ -49,6 +49,8 @@ const char* kModelWithArgMaxOp = "argmax.bin";
const char* kModelWithFCOp = "fc.bin";
const char* kModelWithGatherNDOp = "gather_nd.bin";
const char* kModelMixed = "mixed.bin";
const char* kModelMixed16x8 = "mixed16x8.bin";

@@ -75,6 +75,9 @@ extern const char* kModelWithArgMaxOp;
// Test model with a argmax op.
extern const char* kModelWithFCOp;
// Test model with a gather_nd op.
extern const char* kModelWithGatherNDOp;
// Test model with mixed quantizable and un-quantizable ops.
// reshape->custom->custom->squeeze.
extern const char* kModelMixed;

Binary file not shown.

@@ -513,6 +513,10 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
return 1;
case BuiltinOperator_GATHER_ND:
if (!op_sig.input_types.empty() &&
(op_sig.input_types.at(0) == TensorType_INT16)) {
return 3;
}
if (!op_sig.input_types.empty() &&
op_sig.input_types.at(0) == TensorType_STRING) {
return 2;
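Pieced together, the versioning rule now resolves int16 inputs to version 3, string inputs to version 2, and everything else to version 1. A consolidated sketch (the wrapper function is illustrative, and the final return 1 is assumed because the hunk above is truncated):

#include <vector>

#include "tensorflow/lite/schema/schema_generated.h"

// Consolidated view of the GATHER_ND versioning rule shown above.
int GatherNdVersionSketch(const std::vector<tflite::TensorType>& input_types) {
  if (!input_types.empty() && input_types.at(0) == tflite::TensorType_INT16) {
    return 3;  // int16 params require the new kernel version.
  }
  if (!input_types.empty() && input_types.at(0) == tflite::TensorType_STRING) {
    return 2;  // String support arrived in version 2.
  }
  return 1;  // All other supported input types.
}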

@@ -682,6 +682,13 @@ TEST(OpVersionTest, VersioningGatherNdOperatorTest) {
std::vector<TensorType>{TensorType_STRING, TensorType_INT32},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
fake_op_sig = {
.op = BuiltinOperator_GATHER_ND,
.input_types =
std::vector<TensorType>{TensorType_INT16, TensorType_INT32},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
}
TEST(OpVersionTest, VersioningDivTest) {
OpSignature fake_op_sig = {

@@ -124,6 +124,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_GATHER, 4}, "2.4.0"},
{{BuiltinOperator_GATHER_ND, 1}, "1.14.0"},
{{BuiltinOperator_GATHER_ND, 2}, "2.3.0"},
{{BuiltinOperator_GATHER_ND, 3}, kPendingReleaseVersion},
{{BuiltinOperator_HASHTABLE_LOOKUP, 1}, "1.5.0"},
{{BuiltinOperator_SVDF, 1}, "1.5.0"},
{{BuiltinOperator_SVDF, 2}, "1.14.0"},
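Version 3 therefore maps to kPendingReleaseVersion until the next stable release, when the entry is expected to be replaced with a concrete version string. The table itself is just a lookup keyed by (builtin op, op version); a self-contained sketch of that pattern (the function name, string keys, and placeholder value are illustrative, not the library's API):

#include <map>
#include <string>
#include <utility>

// Minimal stand-in for the runtime-version table: the GATHER_ND rows mirror
// the entries above; the empty string stands in for kPendingReleaseVersion.
std::string MinRuntimeVersionSketch(const std::string& op, int op_version) {
  static const std::map<std::pair<std::string, int>, std::string> kTable = {
      {{"GATHER_ND", 1}, "1.14.0"},
      {{"GATHER_ND", 2}, "2.3.0"},
      {{"GATHER_ND", 3}, ""},  // Pending release.
  };
  const auto it = kTable.find({op, op_version});
  return it == kTable.end() ? "" : it->second;
}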