Update NNAPI delegate support for SUB, TANH
PiperOrigin-RevId: 257515288
This commit is contained in:
parent
08370f31ad
commit
daac7c7101
@ -123,13 +123,12 @@ bool IsFloatOrUint8Operator(const TfLiteContext* context,
|
||||
// Check if the operation requires explict conversion from int8 to uint8 values.
|
||||
bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code,
|
||||
const TfLiteNode* node) {
|
||||
const int input_id = node->inputs->data[0];
|
||||
const TfLiteType input_type = context->tensors[input_id].type;
|
||||
switch (builtin_code) {
|
||||
case kTfLiteBuiltinConv2d:
|
||||
case kTfLiteBuiltinDepthwiseConv2d:
|
||||
case kTfLiteBuiltinFullyConnected:
|
||||
case kTfLiteBuiltinL2Normalization: {
|
||||
const int input_id = node->inputs->data[0];
|
||||
const TfLiteType input_type = context->tensors[input_id].type;
|
||||
case kTfLiteBuiltinFullyConnected: {
|
||||
if (input_type == kTfLiteInt8) {
|
||||
const int weights_id = node->inputs->data[1];
|
||||
const auto& weights_tensor = context->tensors[weights_id];
|
||||
@ -141,6 +140,11 @@ bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code,
|
||||
}
|
||||
return false;
|
||||
}
|
||||
case kTfLiteBuiltinL2Normalization:
|
||||
case kTfLiteBuiltinSub:
|
||||
case kTfLiteBuiltinTanh: {
|
||||
return input_type == kTfLiteInt8;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@ -1379,23 +1383,34 @@ class NNAPIDelegateKernel {
|
||||
break;
|
||||
case kTfLiteBuiltinTanh:
|
||||
// TODO(miaowang): add additional checks for the parameters.
|
||||
if (version == 1 &&
|
||||
context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) {
|
||||
// NNAPI only support float tanh.
|
||||
return BasicMappingFn<ANEURALNETWORKS_TANH>;
|
||||
if (version == 1) {
|
||||
const TfLiteType input_type =
|
||||
context->tensors[node->inputs->data[0]].type;
|
||||
if (IsFloat(input_type) ||
|
||||
(IsQuantized(input_type) &&
|
||||
android_sdk_version >= kMinSdkVersionForNNAPI12)) {
|
||||
// NNAPI only support float tanh.
|
||||
return BasicMappingFn<ANEURALNETWORKS_TANH>;
|
||||
}
|
||||
}
|
||||
break;
|
||||
case kTfLiteBuiltinSub:
|
||||
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11 &&
|
||||
context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) {
|
||||
// NNAPI only support float sub.
|
||||
return [](const NNAPIOpMappingArgs& mapping_args)
|
||||
-> ANeuralNetworksOperationType {
|
||||
auto builtin = reinterpret_cast<TfLiteSubParams*>(
|
||||
mapping_args.node->builtin_data);
|
||||
mapping_args.builder->AddScalarInt32Operand(builtin->activation);
|
||||
return ANEURALNETWORKS_SUB;
|
||||
};
|
||||
if (version == 1) {
|
||||
const TfLiteType input_type =
|
||||
context->tensors[node->inputs->data[0]].type;
|
||||
if ((android_sdk_version >= kMinSdkVersionForNNAPI11 &&
|
||||
IsFloat(input_type)) ||
|
||||
(android_sdk_version >= kMinSdkVersionForNNAPI12 &&
|
||||
IsQuantized(input_type))) {
|
||||
// NNAPI only support float sub.
|
||||
return [](const NNAPIOpMappingArgs& mapping_args)
|
||||
-> ANeuralNetworksOperationType {
|
||||
auto builtin = reinterpret_cast<TfLiteSubParams*>(
|
||||
mapping_args.node->builtin_data);
|
||||
mapping_args.builder->AddScalarInt32Operand(builtin->activation);
|
||||
return ANEURALNETWORKS_SUB;
|
||||
};
|
||||
}
|
||||
}
|
||||
break;
|
||||
case kTfLiteBuiltinDiv:
|
||||
@ -2355,7 +2370,8 @@ class NNAPIDelegateKernel {
|
||||
const auto input_index = node->inputs->data[input_pos];
|
||||
if (need_int8_conversion &&
|
||||
(input_pos == 0 ||
|
||||
reg->builtin_code == kTfLiteBuiltinFullyConnected)) {
|
||||
reg->builtin_code == kTfLiteBuiltinFullyConnected ||
|
||||
reg->builtin_code == kTfLiteBuiltinSub)) {
|
||||
// Only selected inputs require int8 conversion.
|
||||
TF_LITE_ENSURE_STATUS(builder.AddTensorInput(
|
||||
input_index, hybrid_op,
|
||||
|
Loading…
x
Reference in New Issue
Block a user