Add guard to avoid acceleration of L2 Normalization with input rank != 4

PiperOrigin-RevId: 258114185
This commit is contained in:
A. Unique TensorFlower 2019-07-15 01:33:56 -07:00 committed by TensorFlower Gardener
parent 37eafe0e74
commit e48577945d
3 changed files with 10 additions and 15 deletions

View File

@ -1263,8 +1263,9 @@ class NNAPIDelegateKernel {
break;
case kTfLiteBuiltinL2Normalization: {
if (version == 1) {
const auto& input = context->tensors[node->inputs->data[0]];
if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
!IsFloatOperator(context, node)) {
(!IsFloatOperator(context, node) || input.dims->size != 4)) {
return nullptr;
}
auto builtin =

View File

@ -1966,20 +1966,6 @@ TEST(NNAPIDelegate, Relu6) {
}));
}
TEST(NNAPIDelegate, Tanh) {
FloatActivationsOpModel m(BuiltinOperator_TANH,
/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
m.SetInput({
0, -6, 2, 4, //
3, -2, 10, 1, //
});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
0, -0.9999877, 0.9640275, 0.999329, //
0.99505475, -0.9640275, 1, 0.7615941, //
})));
}
TEST(NNAPIDelegate, LogisticFloat) {
FloatActivationsOpModel m(BuiltinOperator_LOGISTIC,
/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});

View File

@ -77,6 +77,14 @@ TEST(L2NormOpTest, SimpleFloatTest) {
ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
}
TEST(L2NormOpTest, SimpleFloatWithRankLessThanFourTest) {
L2NormOpModel m({1, 6}, TensorType_FLOAT32, ActivationFunctionType_NONE);
m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
m.Invoke();
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
}
TEST(L2NormOpTest, MultipleBatchFloatTest) {
L2NormOpModel m({3, 1, 1, 6}, TensorType_FLOAT32,
ActivationFunctionType_NONE);