Add rank check to Sub op delegation to NNAPI

PiperOrigin-RevId: 307821863
Change-Id: Ib98448d67e9948576e6c9fb43a98d364ab434e37
This commit is contained in:
Lev Proleev 2020-04-22 08:28:48 -07:00 committed by TensorFlower Gardener
parent 254cf1cb8e
commit 0e3574d39c

View File

@ -1799,7 +1799,7 @@ bool NNAPIDelegateKernel::Validate(
" NNAPI only support float tanh.", &val_ctx); " NNAPI only support float tanh.", &val_ctx);
} break; } break;
case kTfLiteBuiltinSub: { case kTfLiteBuiltinSub: {
ExpectMaxOpVersion(version, 2, &val_ctx); ExpectMaxOpVersion(version, 3, &val_ctx);
const TfLiteType input_type = const TfLiteType input_type =
context->tensors[node->inputs->data[0]].type; context->tensors[node->inputs->data[0]].type;
Expect((android_sdk_version >= kMinSdkVersionForNNAPI11 && Expect((android_sdk_version >= kMinSdkVersionForNNAPI11 &&
@ -1808,6 +1808,13 @@ bool NNAPIDelegateKernel::Validate(
IsQuantized(input_type)), IsQuantized(input_type)),
NNAPIValidationFailureType::kUnsupportedInputType, NNAPIValidationFailureType::kUnsupportedInputType,
"NNAPI only support float sub.", &val_ctx); "NNAPI only support float sub.", &val_ctx);
const int input0_rank =
context->tensors[node->inputs->data[0]].dims->size;
const int input1_rank =
context->tensors[node->inputs->data[1]].dims->size;
Expect(input0_rank <= 4 && input1_rank <= 4,
NNAPIValidationFailureType::kUnsupportedOperandRank,
"Input rank must be <= 4", &val_ctx);
} break; } break;
case kTfLiteBuiltinDiv: { case kTfLiteBuiltinDiv: {
ExpectOpVersion(version, 1, &val_ctx); ExpectOpVersion(version, 1, &val_ctx);
@ -2327,7 +2334,7 @@ bool NNAPIDelegateKernel::Validate(
"Unsupported operation type.", &val_ctx); "Unsupported operation type.", &val_ctx);
} }
return val_ctx.is_valid; return val_ctx.is_valid;
} } // NOLINT(readability/fn_size)
TfLiteStatus NNAPIDelegateKernel::Map( TfLiteStatus NNAPIDelegateKernel::Map(
TfLiteContext* context, int builtin_code, int version, TfLiteContext* context, int builtin_code, int version,