Handle l2_norm kernel for input all zeros scenario.
PiperOrigin-RevId: 291086946 Change-Id: I1f63ccc0041cb4056631439fd9f2db38d71f1aee
This commit is contained in:
parent
3385b61080
commit
e4fe5890d7
@ -1068,7 +1068,8 @@ cc_test(
|
|||||||
name = "l2norm_test",
|
name = "l2norm_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["l2norm_test.cc"],
|
srcs = ["l2norm_test.cc"],
|
||||||
tags = ["tflite_nnapi"],
|
# TODO(b/143912164): Re-enable the NNAPI test once nnapi is fixed.
|
||||||
|
# tags = ["tflite_nnapi"],
|
||||||
deps = [
|
deps = [
|
||||||
":builtin_ops",
|
":builtin_ops",
|
||||||
":test_main",
|
":test_main",
|
||||||
|
@ -1591,7 +1591,7 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
|||||||
const RuntimeShape& input_shape,
|
const RuntimeShape& input_shape,
|
||||||
const float* input_data,
|
const float* input_data,
|
||||||
const RuntimeShape& output_shape,
|
const RuntimeShape& output_shape,
|
||||||
float* output_data) {
|
float* output_data, float epsilon = 1e-6) {
|
||||||
ruy::profiler::ScopeLabel label("L2Normalization");
|
ruy::profiler::ScopeLabel label("L2Normalization");
|
||||||
const int trailing_dim = input_shape.DimensionsCount() - 1;
|
const int trailing_dim = input_shape.DimensionsCount() - 1;
|
||||||
const int outer_size =
|
const int outer_size =
|
||||||
@ -1604,7 +1604,8 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
|||||||
const float val = input_data[c];
|
const float val = input_data[c];
|
||||||
squared_l2_norm += val * val;
|
squared_l2_norm += val * val;
|
||||||
}
|
}
|
||||||
const float l2_norm = std::sqrt(squared_l2_norm);
|
float l2_norm = std::sqrt(squared_l2_norm);
|
||||||
|
l2_norm = std::max(l2_norm, epsilon);
|
||||||
for (int c = 0; c < depth; ++c) {
|
for (int c = 0; c < depth; ++c) {
|
||||||
*output_data = *input_data / l2_norm;
|
*output_data = *input_data / l2_norm;
|
||||||
++output_data;
|
++output_data;
|
||||||
|
@ -295,7 +295,7 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
|||||||
const RuntimeShape& input_shape,
|
const RuntimeShape& input_shape,
|
||||||
const float* input_data,
|
const float* input_data,
|
||||||
const RuntimeShape& output_shape,
|
const RuntimeShape& output_shape,
|
||||||
float* output_data) {
|
float* output_data, float epsilon = 1e-6) {
|
||||||
const int trailing_dim = input_shape.DimensionsCount() - 1;
|
const int trailing_dim = input_shape.DimensionsCount() - 1;
|
||||||
const int outer_size =
|
const int outer_size =
|
||||||
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
|
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
|
||||||
@ -307,7 +307,8 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
|||||||
const float val = input_data[depth * i + c];
|
const float val = input_data[depth * i + c];
|
||||||
squared_l2_norm += val * val;
|
squared_l2_norm += val * val;
|
||||||
}
|
}
|
||||||
const float l2_norm = std::sqrt(squared_l2_norm);
|
float l2_norm = std::sqrt(squared_l2_norm);
|
||||||
|
l2_norm = std::max(l2_norm, epsilon);
|
||||||
for (int c = 0; c < depth; ++c) {
|
for (int c = 0; c < depth; ++c) {
|
||||||
output_data[depth * i + c] = input_data[depth * i + c] / l2_norm;
|
output_data[depth * i + c] = input_data[depth * i + c] / l2_norm;
|
||||||
}
|
}
|
||||||
|
@ -74,13 +74,27 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||||
|
|
||||||
|
// TODO(b/143912164): instead of hardcode the epsilon here, we should read it
|
||||||
|
// from tensorflow, i.e., adding a params.
|
||||||
|
// We don't compute epsilon for quantized kernel:
|
||||||
|
//
|
||||||
|
// epsilon_float = (epsilon_quant - zp) * scale
|
||||||
|
// so
|
||||||
|
// epsilon_quant = epsilon_float / scale + zp
|
||||||
|
// We know epsilon_float is just a very small number to avoid division by
|
||||||
|
// zero error, and scale is > 1, so the integer value of epsilon for quant
|
||||||
|
// is just dominated by the zero point.
|
||||||
|
// Also, GetInvSqrtQuantizedMultiplierExp handles the scenario where the sum
|
||||||
|
// of input value squared is zero case well.
|
||||||
|
// So we don't even need to handle the epsilon for the quantized kernel case.
|
||||||
|
const float epsilon = 1e-6f;
|
||||||
if (output->type == kTfLiteFloat32) {
|
if (output->type == kTfLiteFloat32) {
|
||||||
#define TF_LITE_L2NORM(type) \
|
#define TF_LITE_L2NORM(type) \
|
||||||
tflite::L2NormalizationParams op_params; \
|
tflite::L2NormalizationParams op_params; \
|
||||||
op_params.input_zero_point = 0; \
|
op_params.input_zero_point = 0; \
|
||||||
type::L2Normalization(op_params, GetTensorShape(input), \
|
type::L2Normalization(op_params, GetTensorShape(input), \
|
||||||
GetTensorData<float>(input), GetTensorShape(output), \
|
GetTensorData<float>(input), GetTensorShape(output), \
|
||||||
GetTensorData<float>(output))
|
GetTensorData<float>(output), epsilon)
|
||||||
|
|
||||||
if (kernel_type == kReference) {
|
if (kernel_type == kReference) {
|
||||||
TF_LITE_L2NORM(reference_ops);
|
TF_LITE_L2NORM(reference_ops);
|
||||||
|
@ -77,6 +77,15 @@ TEST(L2NormOpTest, SimpleFloatTest) {
|
|||||||
ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
|
ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(L2NormOpTest, ZerosVectorFloatTest) {
|
||||||
|
L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32,
|
||||||
|
ActivationFunctionType_NONE);
|
||||||
|
m.SetInput({0, 0, 0, 0, 0, 0});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<float>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear({0, 0, 0, 0, 0, 0})));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(L2NormOpTest, SimpleFloatWithRankLessThanFourTest) {
|
TEST(L2NormOpTest, SimpleFloatWithRankLessThanFourTest) {
|
||||||
L2NormOpModel m({1, 6}, TensorType_FLOAT32, ActivationFunctionType_NONE);
|
L2NormOpModel m({1, 6}, TensorType_FLOAT32, ActivationFunctionType_NONE);
|
||||||
m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
|
m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
|
||||||
@ -102,6 +111,17 @@ TEST(L2NormOpTest, MultipleBatchFloatTest) {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(L2NormOpTest, ZerosVectorUint8Test) {
|
||||||
|
L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);
|
||||||
|
|
||||||
|
m.QuantizeAndPopulate<uint8_t>(m.input(), {0, 0, 0, 0, 0, 0});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<uint8_t>(),
|
||||||
|
ElementsAreArray({128, 128, 128, 128, 128, 128}));
|
||||||
|
EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear({0, 0, 0, 0, 0, 0}, 0.1)));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(L2NormOpTest, SimpleUint8Test) {
|
TEST(L2NormOpTest, SimpleUint8Test) {
|
||||||
L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);
|
L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);
|
||||||
|
|
||||||
@ -127,6 +147,17 @@ TEST(L2NormOpTest, SimpleInt8Test) {
|
|||||||
ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
|
ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(L2NormOpTest, ZerosVectorInt8Test) {
|
||||||
|
L2NormOpModel m({1, 1, 1, 6}, TensorType_INT8, ActivationFunctionType_NONE);
|
||||||
|
|
||||||
|
m.QuantizeAndPopulate<int8_t>(m.input(), {0, 0, 0, 0, 0, 0});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({0, 0, 0, 0, 0, 0}));
|
||||||
|
|
||||||
|
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear({0, 0, 0, 0, 0, 0}, 0.1)));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(L2NormOpTest, MultipleBatchUint8Test) {
|
TEST(L2NormOpTest, MultipleBatchUint8Test) {
|
||||||
L2NormOpModel m({3, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);
|
L2NormOpModel m({3, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user