Add guard to avoid acceleration of L2 Normalization with input rank != 4
PiperOrigin-RevId: 258114185
parent 37eafe0e74
commit e48577945d
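A minimal illustration of the new guard (a sketch, not code from this commit; the helper name and parameters below are hypothetical stand-ins for the delegate internals): before NNAPI 1.2 the delegate only accelerates L2_NORMALIZATION for float inputs of rank 4, and returning nullptr from the op mapping leaves the node on the TFLite CPU kernel.

    // Sketch only: mirrors the condition added under kTfLiteBuiltinL2Normalization.
    bool CanAccelerateL2Norm(int android_sdk_version, int min_sdk_for_nnapi12,
                             bool is_float_input, int input_rank) {
      // Older NNAPI versions require float data and a 4-D input tensor.
      if (android_sdk_version < min_sdk_for_nnapi12 &&
          (!is_float_input || input_rank != 4)) {
        return false;  // nullptr mapping -> op stays on the CPU kernel
      }
      return true;
    }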
@@ -1263,8 +1263,9 @@ class NNAPIDelegateKernel {
         break;
       case kTfLiteBuiltinL2Normalization: {
         if (version == 1) {
+          const auto& input = context->tensors[node->inputs->data[0]];
           if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
-              !IsFloatOperator(context, node)) {
+              (!IsFloatOperator(context, node) || input.dims->size != 4)) {
             return nullptr;
           }
           auto builtin =
@@ -1966,20 +1966,6 @@ TEST(NNAPIDelegate, Relu6) {
                              }));
 }
 
-TEST(NNAPIDelegate, Tanh) {
-  FloatActivationsOpModel m(BuiltinOperator_TANH,
-                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
-  m.SetInput({
-      0, -6, 2, 4,   //
-      3, -2, 10, 1,  //
-  });
-  m.Invoke();
-  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
-                  0, -0.9999877, 0.9640275, 0.999329,     //
-                  0.99505475, -0.9640275, 1, 0.7615941,   //
-              })));
-}
-
 TEST(NNAPIDelegate, LogisticFloat) {
   FloatActivationsOpModel m(BuiltinOperator_LOGISTIC,
                             /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
@@ -77,6 +77,14 @@ TEST(L2NormOpTest, SimpleFloatTest) {
               ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
 }
 
+TEST(L2NormOpTest, SimpleFloatWithRankLessThanFourTest) {
+  L2NormOpModel m({1, 6}, TensorType_FLOAT32, ActivationFunctionType_NONE);
+  m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<float>(),
+              ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
+}
+
 TEST(L2NormOpTest, MultipleBatchFloatTest) {
   L2NormOpModel m({3, 1, 1, 6}, TensorType_FLOAT32,
                   ActivationFunctionType_NONE);
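For reference, the expected output in the new SimpleFloatWithRankLessThanFourTest follows directly from the L2 norm of the input: sqrt((-1.1)^2 + 0.6^2 + 0.7^2 + 1.2^2 + (-0.7)^2 + 0.1^2) = sqrt(4.0) = 2.0, so every element is halved (-1.1/2 = -0.55, 0.6/2 = 0.3, ..., 0.1/2 = 0.05), matching the same values checked by the rank-4 SimpleFloatTest above.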