Add NNAPI delegate support for GREATER/GREATER_EQUAL/LESS/LESS_EQUAL/EQUAL/NOT_EQUAL.
PiperOrigin-RevId: 252106583
This commit is contained in:
parent
a2ab778f39
commit
17a2326611
@ -93,6 +93,12 @@ bool IsScalarInputSupported(int builtin_code) {
|
||||
case kTfLiteBuiltinMul:
|
||||
case kTfLiteBuiltinSub:
|
||||
case kTfLiteBuiltinDiv:
|
||||
case kTfLiteBuiltinEqual:
|
||||
case kTfLiteBuiltinNotEqual:
|
||||
case kTfLiteBuiltinGreater:
|
||||
case kTfLiteBuiltinGreaterEqual:
|
||||
case kTfLiteBuiltinLess:
|
||||
case kTfLiteBuiltinLessEqual:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@ -1459,6 +1465,54 @@ class NNAPIDelegateKernel {
|
||||
return BasicMappingFn<ANEURALNETWORKS_LOGICAL_NOT>;
|
||||
}
|
||||
} break;
|
||||
case kTfLiteBuiltinLess: {
|
||||
const auto input_type = context->tensors[node->inputs->data[0]].type;
|
||||
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
|
||||
(input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
|
||||
input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
|
||||
return BasicMappingFn<ANEURALNETWORKS_LESS>;
|
||||
}
|
||||
} break;
|
||||
case kTfLiteBuiltinLessEqual: {
|
||||
const auto input_type = context->tensors[node->inputs->data[0]].type;
|
||||
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
|
||||
(input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
|
||||
input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
|
||||
return BasicMappingFn<ANEURALNETWORKS_LESS_EQUAL>;
|
||||
}
|
||||
} break;
|
||||
case kTfLiteBuiltinGreater: {
|
||||
const auto input_type = context->tensors[node->inputs->data[0]].type;
|
||||
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
|
||||
(input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
|
||||
input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
|
||||
return BasicMappingFn<ANEURALNETWORKS_GREATER>;
|
||||
}
|
||||
} break;
|
||||
case kTfLiteBuiltinGreaterEqual: {
|
||||
const auto input_type = context->tensors[node->inputs->data[0]].type;
|
||||
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
|
||||
(input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
|
||||
input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
|
||||
return BasicMappingFn<ANEURALNETWORKS_GREATER_EQUAL>;
|
||||
}
|
||||
} break;
|
||||
case kTfLiteBuiltinEqual: {
|
||||
const auto input_type = context->tensors[node->inputs->data[0]].type;
|
||||
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
|
||||
(input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
|
||||
input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
|
||||
return BasicMappingFn<ANEURALNETWORKS_EQUAL>;
|
||||
}
|
||||
} break;
|
||||
case kTfLiteBuiltinNotEqual: {
|
||||
const auto input_type = context->tensors[node->inputs->data[0]].type;
|
||||
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
|
||||
(input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
|
||||
input_type == kTfLiteBool || input_type == kTfLiteInt32)) {
|
||||
return BasicMappingFn<ANEURALNETWORKS_NOT_EQUAL>;
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
// All other operators are not mapped.
|
||||
return nullptr;
|
||||
|
@ -1268,6 +1268,7 @@ cc_test(
|
||||
srcs = [
|
||||
"comparisons_test.cc",
|
||||
],
|
||||
tags = ["tflite_nnapi"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
|
@ -89,13 +89,19 @@ enum {
|
||||
ANEURALNETWORKS_SUB = 36,
|
||||
ANEURALNETWORKS_TRANSPOSE = 37,
|
||||
ANEURALNETWORKS_ABS = 38,
|
||||
ANEURALNETWORKS_EQUAL = 48,
|
||||
ANEURALNETWORKS_EXP = 49,
|
||||
ANEURALNETWORKS_GREATER = 53,
|
||||
ANEURALNETWORKS_GREATER_EQUAL = 54,
|
||||
ANEURALNETWORKS_LESS = 58,
|
||||
ANEURALNETWORKS_LESS_EQUAL = 59,
|
||||
ANEURALNETWORKS_LOG = 60,
|
||||
ANEURALNETWORKS_LOGICAL_AND = 61,
|
||||
ANEURALNETWORKS_LOGICAL_NOT = 62,
|
||||
ANEURALNETWORKS_LOGICAL_OR = 63,
|
||||
ANEURALNETWORKS_MAXIMUM = 65,
|
||||
ANEURALNETWORKS_MINIMUM = 66,
|
||||
ANEURALNETWORKS_NOT_EQUAL = 68,
|
||||
ANEURALNETWORKS_PAD_V2 = 69,
|
||||
ANEURALNETWORKS_PRELU = 71,
|
||||
ANEURALNETWORKS_RSQRT = 83,
|
||||
|
Loading…
Reference in New Issue
Block a user