diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc index 0cb4d9efcc5..012da4a1d9b 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc @@ -1799,7 +1799,7 @@ bool NNAPIDelegateKernel::Validate( " NNAPI only support float tanh.", &val_ctx); } break; case kTfLiteBuiltinSub: { - ExpectMaxOpVersion(version, 2, &val_ctx); + ExpectMaxOpVersion(version, 3, &val_ctx); const TfLiteType input_type = context->tensors[node->inputs->data[0]].type; Expect((android_sdk_version >= kMinSdkVersionForNNAPI11 && @@ -1808,6 +1808,13 @@ bool NNAPIDelegateKernel::Validate( IsQuantized(input_type)), NNAPIValidationFailureType::kUnsupportedInputType, "NNAPI only support float sub.", &val_ctx); + const int input0_rank = + context->tensors[node->inputs->data[0]].dims->size; + const int input1_rank = + context->tensors[node->inputs->data[1]].dims->size; + Expect(input0_rank <= 4 && input1_rank <= 4, + NNAPIValidationFailureType::kUnsupportedOperandRank, + "Input rank must be <= 4", &val_ctx); } break; case kTfLiteBuiltinDiv: { ExpectOpVersion(version, 1, &val_ctx); @@ -2327,7 +2334,7 @@ bool NNAPIDelegateKernel::Validate( "Unsupported operation type.", &val_ctx); } return val_ctx.is_valid; -} +} // NOLINT(readability/fn_size) TfLiteStatus NNAPIDelegateKernel::Map( TfLiteContext* context, int builtin_code, int version,