From f6c97840e2e87d02906b7cbbf808febedc50a027 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 24 Jul 2019 18:14:52 -0700 Subject: [PATCH] Add delegate support for BATCH_TO_SPACE_ND PiperOrigin-RevId: 259858930 --- .../lite/delegates/nnapi/nnapi_delegate.cc | 18 ++++++++++++++++++ tensorflow/lite/kernels/BUILD | 1 + 2 files changed, 19 insertions(+) diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc index 837ae62f2bd..87c89dde4fc 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc @@ -141,6 +141,7 @@ bool NeedInt8Conversion(const TfLiteContext* context, int builtin_code, } return false; } + case kTfLiteBuiltinBatchToSpaceNd: case kTfLiteBuiltinL2Normalization: case kTfLiteBuiltinSub: case kTfLiteBuiltinTanh: @@ -1501,6 +1502,18 @@ class NNAPIDelegateKernel { return BasicMappingFn; } break; + case kTfLiteBuiltinBatchToSpaceNd: + if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11) { + auto crops = context->tensors[node->inputs->data[2]]; + auto crops_data = crops.data.i32; + // Check if all crops are 0. + if (!crops_data || crops.bytes != 16 || crops_data[0] != 0 || + crops_data[1] != 0 || crops_data[2] != 0 || crops_data[3] != 0) { + return nullptr; + } + return BasicMappingFn; + } + break; case kTfLiteBuiltinStridedSlice: if (version == 1 && android_sdk_version >= kMinSdkVersionForNNAPI11) { return [](const NNAPIOpMappingArgs& mapping_args) @@ -2636,6 +2649,11 @@ class NNAPIDelegateKernel { input_pos == 1) { // The axis param is added during Map continue; + } else if (reg->builtin_code == kTfLiteBuiltinBatchToSpaceNd && + input_pos == 2) { + // NNAPI does not support crops. + // The Map fucntion will check if all crops are zero. + continue; } else if (reg->builtin_code == kTfLiteBuiltinArgMin || reg->builtin_code == kTfLiteBuiltinArgMax) { // The first input tensor is added as is. The second one, specifying diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 9afe0c8a4e6..bca715a8ce5 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -708,6 +708,7 @@ cc_test( name = "batch_to_space_nd_test", size = "small", srcs = ["batch_to_space_nd_test.cc"], + tags = ["tflite_nnapi"], deps = [ ":builtin_ops", ":test_main",