Support bool input type for StridedSlice.
PiperOrigin-RevId: 274651604
This commit is contained in:
parent
936e6cde7c
commit
1ad2fcb66a
@ -2543,7 +2543,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F32, I32, I64, I8, QI8, QUI8]>:$input,
|
||||
TensorOf<[F32, I32, I64, I8, QI8, QUI8, I1]>:$input,
|
||||
TensorOf<[I32]>:$begin,
|
||||
TensorOf<[I32]>:$end,
|
||||
TensorOf<[I32]>:$strides,
|
||||
@ -2556,7 +2556,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F32, I32, I64, I8, QI8, QUI8]>:$output
|
||||
TensorOf<[F32, I32, I64, I8, QI8, QUI8, I1]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
|
@ -155,7 +155,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
|
||||
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(),
|
||||
/* min_version */ 1,
|
||||
/* max_version */ 2);
|
||||
/* max_version */ 3);
|
||||
AddBuiltin(BuiltinOperator_EXP, Register_EXP());
|
||||
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(),
|
||||
/* min_version */ 1,
|
||||
|
@ -207,6 +207,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_STRIDED_SLICE(reference_ops, int8_t);
|
||||
}
|
||||
break;
|
||||
case kTfLiteBool:
|
||||
if (kernel_type == kReference) {
|
||||
TF_LITE_STRIDED_SLICE(reference_ops, bool);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
context->ReportError(context,
|
||||
"Type %d is currently not supported "
|
||||
|
@ -3232,7 +3232,7 @@ def make_strided_slice_tests(options):
|
||||
test_parameters = [
|
||||
# 4-D (basic cases with const/non-const indices).
|
||||
{
|
||||
"dtype": [tf.float32, tf.int32, tf.int64],
|
||||
"dtype": [tf.float32, tf.int32, tf.int64, tf.bool],
|
||||
"index_type": [tf.int32],
|
||||
"input_shape": [[12, 2, 2, 5]],
|
||||
"strides": [None, [2, 1, 3, 1]],
|
||||
|
@ -145,6 +145,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
|
||||
{{OperatorType::kSplitV, 1}, "1.13.1"},
|
||||
{{OperatorType::kStridedSlice, 1}, "1.6.0"},
|
||||
{{OperatorType::kStridedSlice, 2}, "1.14.0"},
|
||||
{{OperatorType::kStridedSlice, 3}, kPendingReleaseOpVersion},
|
||||
{{OperatorType::kTopK_V2, 1}, "1.7.0"},
|
||||
{{OperatorType::kTopK_V2, 2}, "1.14.0"},
|
||||
{{OperatorType::kArgMax, 1}, "1.9.0"},
|
||||
|
@ -243,6 +243,15 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
case BuiltinOperator_STRIDED_SLICE:
|
||||
// If the op takes bool input, it is version 3.
|
||||
if (op_sig.input_types.at(0) == TensorType_BOOL) {
|
||||
return 3;
|
||||
}
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
|
||||
case BuiltinOperator_AVERAGE_POOL_2D:
|
||||
case BuiltinOperator_ADD:
|
||||
@ -268,7 +277,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
case BuiltinOperator_TANH:
|
||||
case BuiltinOperator_LOGISTIC:
|
||||
case BuiltinOperator_LOG_SOFTMAX:
|
||||
case BuiltinOperator_STRIDED_SLICE:
|
||||
case BuiltinOperator_TOPK_V2:
|
||||
case BuiltinOperator_ARG_MAX:
|
||||
case BuiltinOperator_ARG_MIN:
|
||||
|
Loading…
Reference in New Issue
Block a user