Add split support to the NNAPI delegate
PiperOrigin-RevId: 256448757
This commit is contained in:
parent
1a216e8946
commit
a18d0fc601
@ -1920,6 +1920,28 @@ class NNAPIDelegateKernel {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case kTfLiteBuiltinSplit: {
|
||||||
|
// Tensor indices: split_dim: 0, value: 1
|
||||||
|
const TfLiteTensor& axis = context->tensors[node->inputs->data[0]];
|
||||||
|
const TfLiteTensor& input = context->tensors[node->inputs->data[1]];
|
||||||
|
if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI12 &&
|
||||||
|
(input.type == kTfLiteFloat32 || input.type == kTfLiteUInt8 ||
|
||||||
|
input.type == kTfLiteInt32) &&
|
||||||
|
(axis.type == kTfLiteInt32 &&
|
||||||
|
axis.allocation_type == kTfLiteMmapRo)) {
|
||||||
|
return [](const NNAPIOpMappingArgs& mapping_args)
|
||||||
|
-> ANeuralNetworksOperationType {
|
||||||
|
const TfLiteTensor& axis =
|
||||||
|
mapping_args.context
|
||||||
|
->tensors[mapping_args.node->inputs->data[0]];
|
||||||
|
auto builtin = reinterpret_cast<TfLiteSplitParams*>(
|
||||||
|
mapping_args.node->builtin_data);
|
||||||
|
mapping_args.builder->AddScalarInt32Operand(*axis.data.i32);
|
||||||
|
mapping_args.builder->AddScalarInt32Operand(builtin->num_splits);
|
||||||
|
return ANEURALNETWORKS_SPLIT;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
// All other operators are not mapped.
|
// All other operators are not mapped.
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -2335,6 +2357,14 @@ class NNAPIDelegateKernel {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ((reg->builtin_code == kTfLiteBuiltinSplit) &&
|
||||||
|
(input_index == node->inputs->data[0])) {
|
||||||
|
// Skip the axis input tensor; it will be added as a scalar operand
|
||||||
|
// by the Map() mapping.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// Pad and Padv2 have an optional parameter for a pad value which has
|
// Pad and Padv2 have an optional parameter for a pad value which has
|
||||||
// to be converted to a scalar type in NN API.
|
// to be converted to a scalar type in NN API.
|
||||||
if ((reg->builtin_code == kTfLiteBuiltinPadv2 ||
|
if ((reg->builtin_code == kTfLiteBuiltinPadv2 ||
|
||||||
|
@ -1231,6 +1231,7 @@ cc_test(
|
|||||||
name = "split_test",
|
name = "split_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["split_test.cc"],
|
srcs = ["split_test.cc"],
|
||||||
|
tags = ["tflite_nnapi"],
|
||||||
deps = [
|
deps = [
|
||||||
":builtin_ops",
|
":builtin_ops",
|
||||||
":test_main",
|
":test_main",
|
||||||
|
@ -115,6 +115,7 @@ enum {
|
|||||||
ANEURALNETWORKS_RSQRT = 83,
|
ANEURALNETWORKS_RSQRT = 83,
|
||||||
ANEURALNETWORKS_SELECT = 84,
|
ANEURALNETWORKS_SELECT = 84,
|
||||||
ANEURALNETWORKS_SIN = 85,
|
ANEURALNETWORKS_SIN = 85,
|
||||||
|
ANEURALNETWORKS_SPLIT = 87,
|
||||||
ANEURALNETWORKS_SQRT = 88,
|
ANEURALNETWORKS_SQRT = 88,
|
||||||
ANEURALNETWORKS_TILE = 89,
|
ANEURALNETWORKS_TILE = 89,
|
||||||
ANEURALNETWORKS_TOPK_V2 = 90,
|
ANEURALNETWORKS_TOPK_V2 = 90,
|
||||||
|
Loading…
Reference in New Issue
Block a user