From e48577945d527ca07e61bc8e66428be64b62a540 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 15 Jul 2019 01:33:56 -0700 Subject: [PATCH] Add guard to avoid acceleration of L2 Normalization with input rank != 4 PiperOrigin-RevId: 258114185 --- tensorflow/lite/delegates/nnapi/nnapi_delegate.cc | 3 ++- .../lite/delegates/nnapi/nnapi_delegate_test.cc | 14 -------------- tensorflow/lite/kernels/l2norm_test.cc | 8 ++++++++ 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc index 2388b75efa5..609809ac68c 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc @@ -1263,8 +1263,9 @@ class NNAPIDelegateKernel { break; case kTfLiteBuiltinL2Normalization: { if (version == 1) { + const auto& input = context->tensors[node->inputs->data[0]]; if (android_sdk_version < kMinSdkVersionForNNAPI12 && - !IsFloatOperator(context, node)) { + (!IsFloatOperator(context, node) || input.dims->size != 4)) { return nullptr; } auto builtin = diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc index 957deed274b..c8e9e00d86a 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -1966,20 +1966,6 @@ TEST(NNAPIDelegate, Relu6) { })); } -TEST(NNAPIDelegate, Tanh) { - FloatActivationsOpModel m(BuiltinOperator_TANH, - /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); - m.SetInput({ - 0, -6, 2, 4, // - 3, -2, 10, 1, // - }); - m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 0, -0.9999877, 0.9640275, 0.999329, // - 0.99505475, -0.9640275, 1, 0.7615941, // - }))); -} - TEST(NNAPIDelegate, LogisticFloat) { FloatActivationsOpModel m(BuiltinOperator_LOGISTIC, /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); diff --git a/tensorflow/lite/kernels/l2norm_test.cc b/tensorflow/lite/kernels/l2norm_test.cc index 19f1053db30..bd259e42f33 100644 --- a/tensorflow/lite/kernels/l2norm_test.cc +++ b/tensorflow/lite/kernels/l2norm_test.cc @@ -77,6 +77,14 @@ TEST(L2NormOpTest, SimpleFloatTest) { ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})); } +TEST(L2NormOpTest, SimpleFloatWithRankLessThanFourTest) { + L2NormOpModel m({1, 6}, TensorType_FLOAT32, ActivationFunctionType_NONE); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})); +} + TEST(L2NormOpTest, MultipleBatchFloatTest) { L2NormOpModel m({3, 1, 1, 6}, TensorType_FLOAT32, ActivationFunctionType_NONE);