Add bool support for Gather.
PiperOrigin-RevId: 261222445
This commit is contained in:
parent
21488b7bca
commit
a5167ca38b
@ -60,6 +60,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteInt8:
|
||||
case kTfLiteInt64:
|
||||
case kTfLiteInt32:
|
||||
case kTfLiteBool:
|
||||
break;
|
||||
case kTfLiteString: {
|
||||
// Only 1D input is supported.
|
||||
@ -142,6 +143,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return Gather<int32_t, int32_t>(*params, input, positions, output);
|
||||
case kTfLiteInt64:
|
||||
return Gather<int64_t, int32_t>(*params, input, positions, output);
|
||||
case kTfLiteBool:
|
||||
return Gather<bool, int32_t>(*params, input, positions, output);
|
||||
case kTfLiteString:
|
||||
return GatherStrings<int32_t>(context, input, positions, output);
|
||||
default:
|
||||
@ -162,6 +165,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return Gather<int32_t, int64_t>(*params, input, positions, output);
|
||||
case kTfLiteInt64:
|
||||
return Gather<int64_t, int64_t>(*params, input, positions, output);
|
||||
case kTfLiteBool:
|
||||
return Gather<bool, int64_t>(*params, input, positions, output);
|
||||
case kTfLiteString:
|
||||
return GatherStrings<int64_t>(context, input, positions, output);
|
||||
default:
|
||||
|
@ -251,7 +251,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
/* max_version */ 2);
|
||||
AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(),
|
||||
/* min_version */ 1,
|
||||
/* max_version */ 2);
|
||||
/* max_version */ 3);
|
||||
AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE(),
|
||||
/* min_version */ 1,
|
||||
/* max_version */ 3);
|
||||
|
@ -63,6 +63,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
|
||||
{{OperatorType::kFullyConnected, 4}, "1.14.0"},
|
||||
{{OperatorType::kGather, 1}, "1.6.0"},
|
||||
{{OperatorType::kGather, 2}, "1.14.0"},
|
||||
{{OperatorType::kGather, 3}, kPendingReleaseOpVersion},
|
||||
{{OperatorType::kGatherNd, 1}, "1.14.0"},
|
||||
{{OperatorType::kSvdf, 1}, "1.5.0"},
|
||||
{{OperatorType::kSvdf, 2}, "1.14.0"},
|
||||
|
@ -540,7 +540,11 @@ class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const string& input_name = op_signature.op->inputs[0];
|
||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||
// If the op take int8 input, it is version 2.
|
||||
// If the op takes bool input, it is version 3.
|
||||
if (input_array.data_type == ArrayDataType::kBool) {
|
||||
return 3;
|
||||
}
|
||||
// If the op takes int8 input, it is version 2.
|
||||
if (input_array.data_type == ArrayDataType::kInt8) {
|
||||
return 2;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user