Update NNAPI delegate support for SUB, TANH

PiperOrigin-RevId: 257515288
This commit is contained in:
A. Unique TensorFlower 2019-07-10 17:20:28 -07:00 committed by TensorFlower Gardener
parent 08370f31ad
commit daac7c7101

View File

@ -123,13 +123,12 @@ bool IsFloatOrUint8Operator(const TfLiteContext* context,
// Check if the operation requires explicit conversion from int8 to uint8 values.
bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code,
const TfLiteNode* node) {
const int input_id = node->inputs->data[0];
const TfLiteType input_type = context->tensors[input_id].type;
switch (builtin_code) {
case kTfLiteBuiltinConv2d:
case kTfLiteBuiltinDepthwiseConv2d:
case kTfLiteBuiltinFullyConnected:
case kTfLiteBuiltinL2Normalization: {
const int input_id = node->inputs->data[0];
const TfLiteType input_type = context->tensors[input_id].type;
case kTfLiteBuiltinFullyConnected: {
if (input_type == kTfLiteInt8) {
const int weights_id = node->inputs->data[1];
const auto& weights_tensor = context->tensors[weights_id];
@ -141,6 +140,11 @@ bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code,
}
return false;
}
case kTfLiteBuiltinL2Normalization:
case kTfLiteBuiltinSub:
case kTfLiteBuiltinTanh: {
return input_type == kTfLiteInt8;
}
default:
return false;
}
@ -1379,23 +1383,34 @@ class NNAPIDelegateKernel {
break;
case kTfLiteBuiltinTanh:
// TODO(miaowang): add additional checks for the parameters.
if (version == 1 &&
context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) {
// NNAPI only support float tanh.
return BasicMappingFn<ANEURALNETWORKS_TANH>;
if (version == 1) {
const TfLiteType input_type =
context->tensors[node->inputs->data[0]].type;
if (IsFloat(input_type) ||
(IsQuantized(input_type) &&
android_sdk_version >= kMinSdkVersionForNNAPI12)) {
// NNAPI supports float tanh on all levels; quantized tanh requires NNAPI 1.2+.
return BasicMappingFn<ANEURALNETWORKS_TANH>;
}
}
break;
case kTfLiteBuiltinSub:
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11 &&
context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) {
// NNAPI only support float sub.
return [](const NNAPIOpMappingArgs& mapping_args)
-> ANeuralNetworksOperationType {
auto builtin = reinterpret_cast<TfLiteSubParams*>(
mapping_args.node->builtin_data);
mapping_args.builder->AddScalarInt32Operand(builtin->activation);
return ANEURALNETWORKS_SUB;
};
if (version == 1) {
const TfLiteType input_type =
context->tensors[node->inputs->data[0]].type;
if ((android_sdk_version >= kMinSdkVersionForNNAPI11 &&
IsFloat(input_type)) ||
(android_sdk_version >= kMinSdkVersionForNNAPI12 &&
IsQuantized(input_type))) {
// NNAPI supports float sub on 1.1+; quantized sub requires NNAPI 1.2+.
return [](const NNAPIOpMappingArgs& mapping_args)
-> ANeuralNetworksOperationType {
auto builtin = reinterpret_cast<TfLiteSubParams*>(
mapping_args.node->builtin_data);
mapping_args.builder->AddScalarInt32Operand(builtin->activation);
return ANEURALNETWORKS_SUB;
};
}
}
break;
case kTfLiteBuiltinDiv:
@ -2355,7 +2370,8 @@ class NNAPIDelegateKernel {
const auto input_index = node->inputs->data[input_pos];
if (need_int8_conversion &&
(input_pos == 0 ||
reg->builtin_code == kTfLiteBuiltinFullyConnected)) {
reg->builtin_code == kTfLiteBuiltinFullyConnected ||
reg->builtin_code == kTfLiteBuiltinSub)) {
// Only selected inputs require int8 conversion.
TF_LITE_ENSURE_STATUS(builder.AddTensorInput(
input_index, hybrid_op,