Add bool support for Gather.

PiperOrigin-RevId: 261222445
This commit is contained in:
Haoliang Zhang 2019-08-01 16:34:09 -07:00 committed by TensorFlower Gardener
parent 21488b7bca
commit a5167ca38b
4 changed files with 12 additions and 2 deletions

View File

@ -60,6 +60,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt8:
case kTfLiteInt64:
case kTfLiteInt32:
case kTfLiteBool:
break;
case kTfLiteString: {
// Only 1D input is supported.
@ -142,6 +143,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return Gather<int32_t, int32_t>(*params, input, positions, output);
case kTfLiteInt64:
return Gather<int64_t, int32_t>(*params, input, positions, output);
case kTfLiteBool:
return Gather<bool, int32_t>(*params, input, positions, output);
case kTfLiteString:
return GatherStrings<int32_t>(context, input, positions, output);
default:
@ -162,6 +165,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return Gather<int32_t, int64_t>(*params, input, positions, output);
case kTfLiteInt64:
return Gather<int64_t, int64_t>(*params, input, positions, output);
case kTfLiteBool:
return Gather<bool, int64_t>(*params, input, positions, output);
case kTfLiteString:
return GatherStrings<int64_t>(context, input, positions, output);
default:

View File

@ -251,7 +251,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version */ 2);
AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(),
/* min_version */ 1,
/* max_version */ 2);
/* max_version */ 3);
AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE(),
/* min_version */ 1,
/* max_version */ 3);

View File

@ -63,6 +63,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
{{OperatorType::kFullyConnected, 4}, "1.14.0"},
{{OperatorType::kGather, 1}, "1.6.0"},
{{OperatorType::kGather, 2}, "1.14.0"},
{{OperatorType::kGather, 3}, kPendingReleaseOpVersion},
{{OperatorType::kGatherNd, 1}, "1.14.0"},
{{OperatorType::kSvdf, 1}, "1.5.0"},
{{OperatorType::kSvdf, 2}, "1.14.0"},

View File

@ -540,7 +540,11 @@ class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name);
// If the op take int8 input, it is version 2.
// If the op takes bool input, it is version 3.
if (input_array.data_type == ArrayDataType::kBool) {
return 3;
}
// If the op takes int8 input, it is version 2.
if (input_array.data_type == ArrayDataType::kInt8) {
return 2;
}