Support bool input type for StridedSlice.

PiperOrigin-RevId: 274651604
This commit is contained in:
Haoliang Zhang 2019-10-14 13:47:26 -07:00 committed by TensorFlower Gardener
parent 936e6cde7c
commit 1ad2fcb66a
6 changed files with 19 additions and 5 deletions

View File

@ -2543,7 +2543,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
}];
let arguments = (ins
TensorOf<[F32, I32, I64, I8, QI8, QUI8]>:$input,
TensorOf<[F32, I32, I64, I8, QI8, QUI8, I1]>:$input,
TensorOf<[I32]>:$begin,
TensorOf<[I32]>:$end,
TensorOf<[I32]>:$strides,
@ -2556,7 +2556,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice",
);
let results = (outs
TensorOf<[F32, I32, I64, I8, QI8, QUI8]>:$output
TensorOf<[F32, I32, I64, I8, QI8, QUI8, I1]>:$output
);
let hasOptions = 1;

View File

@ -155,7 +155,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE(),
/* min_version */ 1,
/* max_version */ 2);
/* max_version */ 3);
AddBuiltin(BuiltinOperator_EXP, Register_EXP());
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2(),
/* min_version */ 1,

View File

@ -207,6 +207,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_STRIDED_SLICE(reference_ops, int8_t);
}
break;
case kTfLiteBool:
if (kernel_type == kReference) {
TF_LITE_STRIDED_SLICE(reference_ops, bool);
}
break;
default:
context->ReportError(context,
"Type %d is currently not supported "

View File

@ -3232,7 +3232,7 @@ def make_strided_slice_tests(options):
test_parameters = [
# 4-D (basic cases with const/non-const indices).
{
"dtype": [tf.float32, tf.int32, tf.int64],
"dtype": [tf.float32, tf.int32, tf.int64, tf.bool],
"index_type": [tf.int32],
"input_shape": [[12, 2, 2, 5]],
"strides": [None, [2, 1, 3, 1]],

View File

@ -145,6 +145,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
{{OperatorType::kSplitV, 1}, "1.13.1"},
{{OperatorType::kStridedSlice, 1}, "1.6.0"},
{{OperatorType::kStridedSlice, 2}, "1.14.0"},
{{OperatorType::kStridedSlice, 3}, kPendingReleaseOpVersion},
{{OperatorType::kTopK_V2, 1}, "1.7.0"},
{{OperatorType::kTopK_V2, 2}, "1.14.0"},
{{OperatorType::kArgMax, 1}, "1.9.0"},

View File

@ -243,6 +243,15 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
return 2;
}
return 1;
case BuiltinOperator_STRIDED_SLICE:
// If the op takes bool input, it is version 3.
if (op_sig.input_types.at(0) == TensorType_BOOL) {
return 3;
}
if (op_sig.input_types.at(0) == TensorType_INT8) {
return 2;
}
return 1;
case BuiltinOperator_AVERAGE_POOL_2D:
case BuiltinOperator_ADD:
@ -268,7 +277,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
case BuiltinOperator_TANH:
case BuiltinOperator_LOGISTIC:
case BuiltinOperator_LOG_SOFTMAX:
case BuiltinOperator_STRIDED_SLICE:
case BuiltinOperator_TOPK_V2:
case BuiltinOperator_ARG_MAX:
case BuiltinOperator_ARG_MIN: