Add bool support for Gather.
PiperOrigin-RevId: 261222445
This commit is contained in:
parent
21488b7bca
commit
a5167ca38b
@ -60,6 +60,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
case kTfLiteInt8:
|
case kTfLiteInt8:
|
||||||
case kTfLiteInt64:
|
case kTfLiteInt64:
|
||||||
case kTfLiteInt32:
|
case kTfLiteInt32:
|
||||||
|
case kTfLiteBool:
|
||||||
break;
|
break;
|
||||||
case kTfLiteString: {
|
case kTfLiteString: {
|
||||||
// Only 1D input is supported.
|
// Only 1D input is supported.
|
||||||
@ -142,6 +143,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
return Gather<int32_t, int32_t>(*params, input, positions, output);
|
return Gather<int32_t, int32_t>(*params, input, positions, output);
|
||||||
case kTfLiteInt64:
|
case kTfLiteInt64:
|
||||||
return Gather<int64_t, int32_t>(*params, input, positions, output);
|
return Gather<int64_t, int32_t>(*params, input, positions, output);
|
||||||
|
case kTfLiteBool:
|
||||||
|
return Gather<bool, int32_t>(*params, input, positions, output);
|
||||||
case kTfLiteString:
|
case kTfLiteString:
|
||||||
return GatherStrings<int32_t>(context, input, positions, output);
|
return GatherStrings<int32_t>(context, input, positions, output);
|
||||||
default:
|
default:
|
||||||
@ -162,6 +165,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
return Gather<int32_t, int64_t>(*params, input, positions, output);
|
return Gather<int32_t, int64_t>(*params, input, positions, output);
|
||||||
case kTfLiteInt64:
|
case kTfLiteInt64:
|
||||||
return Gather<int64_t, int64_t>(*params, input, positions, output);
|
return Gather<int64_t, int64_t>(*params, input, positions, output);
|
||||||
|
case kTfLiteBool:
|
||||||
|
return Gather<bool, int64_t>(*params, input, positions, output);
|
||||||
case kTfLiteString:
|
case kTfLiteString:
|
||||||
return GatherStrings<int64_t>(context, input, positions, output);
|
return GatherStrings<int64_t>(context, input, positions, output);
|
||||||
default:
|
default:
|
||||||
|
@ -251,7 +251,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
/* max_version */ 2);
|
/* max_version */ 2);
|
||||||
AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(),
|
AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(),
|
||||||
/* min_version */ 1,
|
/* min_version */ 1,
|
||||||
/* max_version */ 2);
|
/* max_version */ 3);
|
||||||
AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE(),
|
AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE(),
|
||||||
/* min_version */ 1,
|
/* min_version */ 1,
|
||||||
/* max_version */ 3);
|
/* max_version */ 3);
|
||||||
|
@ -63,6 +63,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
|
|||||||
{{OperatorType::kFullyConnected, 4}, "1.14.0"},
|
{{OperatorType::kFullyConnected, 4}, "1.14.0"},
|
||||||
{{OperatorType::kGather, 1}, "1.6.0"},
|
{{OperatorType::kGather, 1}, "1.6.0"},
|
||||||
{{OperatorType::kGather, 2}, "1.14.0"},
|
{{OperatorType::kGather, 2}, "1.14.0"},
|
||||||
|
{{OperatorType::kGather, 3}, kPendingReleaseOpVersion},
|
||||||
{{OperatorType::kGatherNd, 1}, "1.14.0"},
|
{{OperatorType::kGatherNd, 1}, "1.14.0"},
|
||||||
{{OperatorType::kSvdf, 1}, "1.5.0"},
|
{{OperatorType::kSvdf, 1}, "1.5.0"},
|
||||||
{{OperatorType::kSvdf, 2}, "1.14.0"},
|
{{OperatorType::kSvdf, 2}, "1.14.0"},
|
||||||
|
@ -540,7 +540,11 @@ class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
|
|||||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||||
const string& input_name = op_signature.op->inputs[0];
|
const string& input_name = op_signature.op->inputs[0];
|
||||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||||
// If the op take int8 input, it is version 2.
|
// If the op takes bool input, it is version 3.
|
||||||
|
if (input_array.data_type == ArrayDataType::kBool) {
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
// If the op takes int8 input, it is version 2.
|
||||||
if (input_array.data_type == ArrayDataType::kInt8) {
|
if (input_array.data_type == ArrayDataType::kInt8) {
|
||||||
return 2;
|
return 2;
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user