Add NNAPI delegate support for GREATER/GREATER_EQUAL/LESS/LESS_EQUAL/EQUAL/NOT_EQUAL.

PiperOrigin-RevId: 252106583
This commit is contained in:
Haoliang Zhang 2019-06-07 13:20:21 -07:00 committed by TensorFlower Gardener
parent a2ab778f39
commit 17a2326611
3 changed files with 61 additions and 0 deletions

View File

@ -93,6 +93,12 @@ bool IsScalarInputSupported(int builtin_code) {
case kTfLiteBuiltinMul:
case kTfLiteBuiltinSub:
case kTfLiteBuiltinDiv:
case kTfLiteBuiltinEqual:
case kTfLiteBuiltinNotEqual:
case kTfLiteBuiltinGreater:
case kTfLiteBuiltinGreaterEqual:
case kTfLiteBuiltinLess:
case kTfLiteBuiltinLessEqual:
return true;
default:
return false;
@ -1459,6 +1465,54 @@ class NNAPIDelegateKernel {
return BasicMappingFn<ANEURALNETWORKS_LOGICAL_NOT>;
}
} break;
case kTfLiteBuiltinLess: {
const auto input_type = context->tensors[node->inputs->data[0]].type;
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
(input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
return BasicMappingFn<ANEURALNETWORKS_LESS>;
}
} break;
case kTfLiteBuiltinLessEqual: {
const auto input_type = context->tensors[node->inputs->data[0]].type;
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
(input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
return BasicMappingFn<ANEURALNETWORKS_LESS_EQUAL>;
}
} break;
case kTfLiteBuiltinGreater: {
const auto input_type = context->tensors[node->inputs->data[0]].type;
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
(input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
return BasicMappingFn<ANEURALNETWORKS_GREATER>;
}
} break;
case kTfLiteBuiltinGreaterEqual: {
const auto input_type = context->tensors[node->inputs->data[0]].type;
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
(input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
return BasicMappingFn<ANEURALNETWORKS_GREATER_EQUAL>;
}
} break;
case kTfLiteBuiltinEqual: {
const auto input_type = context->tensors[node->inputs->data[0]].type;
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
(input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
return BasicMappingFn<ANEURALNETWORKS_EQUAL>;
}
} break;
case kTfLiteBuiltinNotEqual: {
const auto input_type = context->tensors[node->inputs->data[0]].type;
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
(input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
return BasicMappingFn<ANEURALNETWORKS_NOT_EQUAL>;
}
} break;
default:
// All other operators are not mapped.
return nullptr;

View File

@ -1268,6 +1268,7 @@ cc_test(
srcs = [
"comparisons_test.cc",
],
tags = ["tflite_nnapi"],
deps = [
":builtin_ops",
":test_main",

View File

@ -89,13 +89,19 @@ enum {
ANEURALNETWORKS_SUB = 36,
ANEURALNETWORKS_TRANSPOSE = 37,
ANEURALNETWORKS_ABS = 38,
ANEURALNETWORKS_EQUAL = 48,
ANEURALNETWORKS_EXP = 49,
ANEURALNETWORKS_GREATER = 53,
ANEURALNETWORKS_GREATER_EQUAL = 54,
ANEURALNETWORKS_LESS = 58,
ANEURALNETWORKS_LESS_EQUAL = 59,
ANEURALNETWORKS_LOG = 60,
ANEURALNETWORKS_LOGICAL_AND = 61,
ANEURALNETWORKS_LOGICAL_NOT = 62,
ANEURALNETWORKS_LOGICAL_OR = 63,
ANEURALNETWORKS_MAXIMUM = 65,
ANEURALNETWORKS_MINIMUM = 66,
ANEURALNETWORKS_NOT_EQUAL = 68,
ANEURALNETWORKS_PAD_V2 = 69,
ANEURALNETWORKS_PRELU = 71,
ANEURALNETWORKS_RSQRT = 83,