Handle l2_norm kernel for input all zeros scenario.
PiperOrigin-RevId: 291086946 Change-Id: I1f63ccc0041cb4056631439fd9f2db38d71f1aee
This commit is contained in:
parent
3385b61080
commit
e4fe5890d7
@ -1068,7 +1068,8 @@ cc_test(
|
||||
name = "l2norm_test",
|
||||
size = "small",
|
||||
srcs = ["l2norm_test.cc"],
|
||||
tags = ["tflite_nnapi"],
|
||||
# TODO(b/143912164): Enable NNAPI test when fix nnapi.
|
||||
# tags = ["tflite_nnapi"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_main",
|
||||
|
@ -1591,7 +1591,7 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
||||
const RuntimeShape& input_shape,
|
||||
const float* input_data,
|
||||
const RuntimeShape& output_shape,
|
||||
float* output_data) {
|
||||
float* output_data, float epsilon = 1e-6) {
|
||||
ruy::profiler::ScopeLabel label("L2Normalization");
|
||||
const int trailing_dim = input_shape.DimensionsCount() - 1;
|
||||
const int outer_size =
|
||||
@ -1604,7 +1604,8 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
||||
const float val = input_data[c];
|
||||
squared_l2_norm += val * val;
|
||||
}
|
||||
const float l2_norm = std::sqrt(squared_l2_norm);
|
||||
float l2_norm = std::sqrt(squared_l2_norm);
|
||||
l2_norm = std::max(l2_norm, epsilon);
|
||||
for (int c = 0; c < depth; ++c) {
|
||||
*output_data = *input_data / l2_norm;
|
||||
++output_data;
|
||||
|
@ -295,7 +295,7 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
||||
const RuntimeShape& input_shape,
|
||||
const float* input_data,
|
||||
const RuntimeShape& output_shape,
|
||||
float* output_data) {
|
||||
float* output_data, float epsilon = 1e-6) {
|
||||
const int trailing_dim = input_shape.DimensionsCount() - 1;
|
||||
const int outer_size =
|
||||
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
|
||||
@ -307,7 +307,8 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
||||
const float val = input_data[depth * i + c];
|
||||
squared_l2_norm += val * val;
|
||||
}
|
||||
const float l2_norm = std::sqrt(squared_l2_norm);
|
||||
float l2_norm = std::sqrt(squared_l2_norm);
|
||||
l2_norm = std::max(l2_norm, epsilon);
|
||||
for (int c = 0; c < depth; ++c) {
|
||||
output_data[depth * i + c] = input_data[depth * i + c] / l2_norm;
|
||||
}
|
||||
|
@ -74,13 +74,27 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
// TODO(b/143912164): instead of hardcoding the epsilon here, we should read it
|
||||
// from tensorflow, i.e., adding a params.
|
||||
// We don't compute epsilon for quantized kernel:
|
||||
//
|
||||
// epsilon_float = (epsilon_quant - zp) * scale
|
||||
// so
|
||||
// epsilon_quant = epsilon_float / scale + zp
|
||||
// We know epsilon_float is just a very small number to avoid division by
|
||||
// zero error, and scale is > 1, so the integer value of epsilon for quant
|
||||
// is just dominated by the zero point.
|
||||
// Also, GetInvSqrtQuantizedMultiplierExp handles the scenario where the sum
|
||||
// of squared input values is zero well.
|
||||
// So we don't even need to handle the epsilon for the quantized kernel case.
|
||||
const float epsilon = 1e-6f;
|
||||
if (output->type == kTfLiteFloat32) {
|
||||
#define TF_LITE_L2NORM(type) \
|
||||
tflite::L2NormalizationParams op_params; \
|
||||
op_params.input_zero_point = 0; \
|
||||
type::L2Normalization(op_params, GetTensorShape(input), \
|
||||
GetTensorData<float>(input), GetTensorShape(output), \
|
||||
GetTensorData<float>(output))
|
||||
GetTensorData<float>(output), epsilon)
|
||||
|
||||
if (kernel_type == kReference) {
|
||||
TF_LITE_L2NORM(reference_ops);
|
||||
|
@ -77,6 +77,15 @@ TEST(L2NormOpTest, SimpleFloatTest) {
|
||||
ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
|
||||
}
|
||||
|
||||
TEST(L2NormOpTest, ZerosVectorFloatTest) {
|
||||
L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32,
|
||||
ActivationFunctionType_NONE);
|
||||
m.SetInput({0, 0, 0, 0, 0, 0});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput<float>(),
|
||||
ElementsAreArray(ArrayFloatNear({0, 0, 0, 0, 0, 0})));
|
||||
}
|
||||
|
||||
TEST(L2NormOpTest, SimpleFloatWithRankLessThanFourTest) {
|
||||
L2NormOpModel m({1, 6}, TensorType_FLOAT32, ActivationFunctionType_NONE);
|
||||
m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
|
||||
@ -102,6 +111,17 @@ TEST(L2NormOpTest, MultipleBatchFloatTest) {
|
||||
}));
|
||||
}
|
||||
|
||||
TEST(L2NormOpTest, ZerosVectorUint8Test) {
|
||||
L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);
|
||||
|
||||
m.QuantizeAndPopulate<uint8_t>(m.input(), {0, 0, 0, 0, 0, 0});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput<uint8_t>(),
|
||||
ElementsAreArray({128, 128, 128, 128, 128, 128}));
|
||||
EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
|
||||
ElementsAreArray(ArrayFloatNear({0, 0, 0, 0, 0, 0}, 0.1)));
|
||||
}
|
||||
|
||||
TEST(L2NormOpTest, SimpleUint8Test) {
|
||||
L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);
|
||||
|
||||
@ -127,6 +147,17 @@ TEST(L2NormOpTest, SimpleInt8Test) {
|
||||
ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
|
||||
}
|
||||
|
||||
TEST(L2NormOpTest, ZerosVectorInt8Test) {
|
||||
L2NormOpModel m({1, 1, 1, 6}, TensorType_INT8, ActivationFunctionType_NONE);
|
||||
|
||||
m.QuantizeAndPopulate<int8_t>(m.input(), {0, 0, 0, 0, 0, 0});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({0, 0, 0, 0, 0, 0}));
|
||||
|
||||
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
|
||||
ElementsAreArray(ArrayFloatNear({0, 0, 0, 0, 0, 0}, 0.1)));
|
||||
}
|
||||
|
||||
TEST(L2NormOpTest, MultipleBatchUint8Test) {
|
||||
L2NormOpModel m({3, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user