diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index 28efbd19a2b..30f754156a4 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -123,13 +123,12 @@ bool IsFloatOrUint8Operator(const TfLiteContext* context,
 // Check if the operation requires explict conversion from int8 to uint8 values.
 bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code,
                         const TfLiteNode* node) {
+  const int input_id = node->inputs->data[0];
+  const TfLiteType input_type = context->tensors[input_id].type;
   switch (builtin_code) {
     case kTfLiteBuiltinConv2d:
     case kTfLiteBuiltinDepthwiseConv2d:
-    case kTfLiteBuiltinFullyConnected:
-    case kTfLiteBuiltinL2Normalization: {
-      const int input_id = node->inputs->data[0];
-      const TfLiteType input_type = context->tensors[input_id].type;
+    case kTfLiteBuiltinFullyConnected: {
       if (input_type == kTfLiteInt8) {
         const int weights_id = node->inputs->data[1];
         const auto& weights_tensor = context->tensors[weights_id];
@@ -141,6 +140,11 @@ bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code,
       }
       return false;
     }
+    case kTfLiteBuiltinL2Normalization:
+    case kTfLiteBuiltinSub:
+    case kTfLiteBuiltinTanh: {
+      return input_type == kTfLiteInt8;
+    }
     default:
       return false;
   }
@@ -1379,23 +1383,34 @@ class NNAPIDelegateKernel {
       break;
     case kTfLiteBuiltinTanh:
       // TODO(miaowang): add additional checks for the parameters.
-      if (version == 1 &&
-          context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) {
-        // NNAPI only support float tanh.
-        return BasicMappingFn<ANEURALNETWORKS_TANH>;
+      if (version == 1) {
+        const TfLiteType input_type =
+            context->tensors[node->inputs->data[0]].type;
+        if (IsFloat(input_type) ||
+            (IsQuantized(input_type) &&
+             android_sdk_version >= kMinSdkVersionForNNAPI12)) {
+          // NNAPI only support float tanh.
+          return BasicMappingFn<ANEURALNETWORKS_TANH>;
+        }
       }
       break;
     case kTfLiteBuiltinSub:
-      if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11 &&
-          context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) {
-        // NNAPI only support float sub.
-        return [](const NNAPIOpMappingArgs& mapping_args)
-                   -> ANeuralNetworksOperationType {
-          auto builtin = reinterpret_cast<TfLiteSubParams*>(
-              mapping_args.node->builtin_data);
-          mapping_args.builder->AddScalarInt32Operand(builtin->activation);
-          return ANEURALNETWORKS_SUB;
-        };
+      if (version == 1) {
+        const TfLiteType input_type =
+            context->tensors[node->inputs->data[0]].type;
+        if ((android_sdk_version >= kMinSdkVersionForNNAPI11 &&
+             IsFloat(input_type)) ||
+            (android_sdk_version >= kMinSdkVersionForNNAPI12 &&
+             IsQuantized(input_type))) {
+          // NNAPI only support float sub.
+          return [](const NNAPIOpMappingArgs& mapping_args)
+                     -> ANeuralNetworksOperationType {
+            auto builtin = reinterpret_cast<TfLiteSubParams*>(
+                mapping_args.node->builtin_data);
+            mapping_args.builder->AddScalarInt32Operand(builtin->activation);
+            return ANEURALNETWORKS_SUB;
+          };
+        }
       }
       break;
     case kTfLiteBuiltinDiv:
@@ -2355,7 +2370,8 @@ class NNAPIDelegateKernel {
        const auto input_index = node->inputs->data[input_pos];
        if (need_int8_conversion &&
            (input_pos == 0 ||
-            reg->builtin_code == kTfLiteBuiltinFullyConnected)) {
+            reg->builtin_code == kTfLiteBuiltinFullyConnected ||
+            reg->builtin_code == kTfLiteBuiltinSub)) {
          // Only selected inputs require int8 conversion.
          TF_LITE_ENSURE_STATUS(builder.AddTensorInput(
              input_index, hybrid_op,