diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index 07289a2d181..33ac862ed54 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -277,6 +277,7 @@ bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code,
    case kTfLiteBuiltinGreaterEqual:
    case kTfLiteBuiltinHardSwish:
    case kTfLiteBuiltinL2Normalization:
+    case kTfLiteBuiltinLeakyRelu:
    case kTfLiteBuiltinLess:
    case kTfLiteBuiltinLessEqual:
    case kTfLiteBuiltinLogistic:
@@ -288,6 +289,7 @@ bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code,
    case kTfLiteBuiltinNotEqual:
    case kTfLiteBuiltinPad:
    case kTfLiteBuiltinPadv2:
+    case kTfLiteBuiltinPrelu:
    case kTfLiteBuiltinReduceMax:
    case kTfLiteBuiltinReduceMin:
    case kTfLiteBuiltinRelu:
@@ -1874,8 +1876,14 @@ bool NNAPIDelegateKernel::Validate(
    } break;
    case kTfLiteBuiltinSoftmax: {
      ExpectOpVersion(version, 2, &val_ctx);
-      const auto& input = context->tensors[node->outputs->data[0]];
      ExpectIsFloatOrQuant8Operator(context, node, &val_ctx);
+      const auto& output = context->tensors[node->outputs->data[0]];
+      ExpectTypeIn(output.type, {kTfLiteFloat32, kTfLiteUInt8, kTfLiteInt8},
+                   NNAPIValidationFailureType::kUnsupportedOutputType,
+                   "Output type should be one of kTfLiteFloat32, kTfLiteUInt8, "
+                   "kTfLiteInt8.",
+                   &val_ctx);
+      const auto& input = context->tensors[node->inputs->data[0]];
      const int input_rank = input.dims->size;
      Expect(input_rank <= 4,
             NNAPIValidationFailureType::kUnsupportedOperandRank,
@@ -2695,6 +2703,10 @@
      ExpectOpVersion(version, 1, &val_ctx);
      ExpectMinAndroidSdkVersion(android_sdk_version, kMinSdkVersionForNNAPI13,
                                 &val_ctx);
+      const auto input_type = context->tensors[node->inputs->data[0]].type;
+      Expect(input_type == kTfLiteFloat32,
+             NNAPIValidationFailureType::kUnsupportedInputType,
+             "NNAPI only supports floating point input.", &val_ctx);
    } break;
    case kTfLiteBuiltinFill: {
      ExpectOpVersion(version, 1, &val_ctx);
@@ -3423,6 +3435,14 @@
        mapping_args.builder->AddNewInputConstantTensor(
            ANEURALNETWORKS_TENSOR_FLOAT32, kTfLiteFloat32, alpha_tensor.dims,
            alpha_value, alpha_tensor.params, &new_tensor_index);
+      } else if (input_type == kTfLiteInt8 &&
+                 android_sdk_version >= kMinSdkVersionForNNAPI13) {
+        alpha_tensor.params.scale = builtin->alpha;
+        std::vector<int8_t> alpha_value = {1};
+        mapping_args.builder->AddNewInputConstantTensor(
+            ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED, kTfLiteInt8,
+            alpha_tensor.dims, alpha_value, alpha_tensor.params,
+            &new_tensor_index);
      } else {
        alpha_tensor.params.scale = builtin->alpha;
        std::vector<uint8_t> alpha_value = {1};
@@ -4321,8 +4341,10 @@
          reg->builtin_code == kTfLiteBuiltinConcatenation ||
          reg->builtin_code == kTfLiteBuiltinMaximum ||
          reg->builtin_code == kTfLiteBuiltinMinimum ||
+          reg->builtin_code == kTfLiteBuiltinLeakyRelu ||
          reg->builtin_code == kTfLiteBuiltinLess ||
          reg->builtin_code == kTfLiteBuiltinLessEqual ||
+          reg->builtin_code == kTfLiteBuiltinPrelu ||
          reg->builtin_code == kTfLiteBuiltinGreater ||
          reg->builtin_code == kTfLiteBuiltinGreaterEqual ||
          reg->builtin_code == kTfLiteBuiltinEqual ||