add L2 Normalization op to micro
PiperOrigin-RevId: 307643391 Change-Id: Ib6497a74e1199d53a82c73d56346706dfbb6bcbd
This commit is contained in:
parent
d515127fdd
commit
e2bb4b2acd
@ -15,6 +15,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_L2NORMALIZATION_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_L2NORMALIZATION_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
@ -76,7 +77,9 @@ inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
|
||||
int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
||||
128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
|
||||
int32 unclamped_output_val = 128 + rescaled_diff;
|
||||
int32 output_val = std::min(255, std::max(0, unclamped_output_val));
|
||||
int32 output_val =
|
||||
std::min(static_cast<int32>(255),
|
||||
std::max(static_cast<int32>(0), unclamped_output_val));
|
||||
output_data[depth * i + c] = static_cast<uint8>(output_val);
|
||||
}
|
||||
}
|
||||
|
@ -34,6 +34,7 @@ cc_library(
|
||||
"dequantize.cc",
|
||||
"elementwise.cc",
|
||||
"floor.cc",
|
||||
"l2norm.cc",
|
||||
"logical.cc",
|
||||
"logistic.cc",
|
||||
"maximum_minimum.cc",
|
||||
@ -132,6 +133,7 @@ cc_library(
|
||||
"elementwise.cc",
|
||||
"floor.cc",
|
||||
"fully_connected.cc",
|
||||
"l2norm.cc",
|
||||
"logical.cc",
|
||||
"logistic.cc",
|
||||
"maximum_minimum.cc",
|
||||
@ -669,3 +671,16 @@ tflite_micro_cc_test(
|
||||
"//tensorflow/lite/micro/testing:micro_test",
|
||||
],
|
||||
)
|
||||
|
||||
tflite_micro_cc_test(
|
||||
name = "l2norm_test",
|
||||
srcs = [
|
||||
"l2norm_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":all_ops_resolver",
|
||||
":micro_ops",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/micro/testing:micro_test",
|
||||
],
|
||||
)
|
||||
|
@ -75,6 +75,7 @@ AllOpsResolver::AllOpsResolver() {
|
||||
Register_RESIZE_NEAREST_NEIGHBOR(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 2);
|
||||
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION());
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
150
tensorflow/lite/micro/kernels/l2norm.cc
Normal file
150
tensorflow/lite/micro/kernels/l2norm.cc
Normal file
@ -0,0 +1,150 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/l2normalization.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace micro {
|
||||
namespace l2norm {
|
||||
|
||||
// This file has two implementation of L2Norm.
|
||||
enum KernelType {
|
||||
kReference,
|
||||
kGenericOptimized,
|
||||
};
|
||||
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
#if defined(DEBUG)
|
||||
auto* params = reinterpret_cast<TfLiteL2NormParams*>(node->builtin_data);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
|
||||
|
||||
TF_LITE_ENSURE(context, output->type == kTfLiteFloat32 ||
|
||||
output->type == kTfLiteUInt8 ||
|
||||
output->type == kTfLiteInt8);
|
||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||
|
||||
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.scale, (1. / 128.));
|
||||
if (output->type == kTfLiteUInt8) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 128);
|
||||
}
|
||||
if (output->type == kTfLiteInt8) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(ahentz): For some reason our implementations don't support
|
||||
// activations.
|
||||
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
|
||||
#endif
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
// TODO(b/143912164): instead of hardcode the epsilon here, we should read it
|
||||
// from tensorflow, i.e., adding a params.
|
||||
// We don't compute epsilon for quantized kernel:
|
||||
//
|
||||
// epsilon_float = (epsilon_quant - zp) * scale
|
||||
// so
|
||||
// espsilon_quant = epsilon_float / scale + zp
|
||||
// We know epsilon_float is just a very small number to avoid division by
|
||||
// zero error, and scale is > 1, so the integer value of epsilon for quant
|
||||
// is just dominated by the zero point.
|
||||
// Also, GetInvSqrtQuantizedMultiplierExp handles the scenario where the sum
|
||||
// of input value squared is zero case well.
|
||||
// So we don't even need to do handle the epsilon for quantized kernel case.
|
||||
const float epsilon = 1e-6f;
|
||||
if (output->type == kTfLiteFloat32) {
|
||||
#define TF_LITE_L2NORM(type) \
|
||||
tflite::L2NormalizationParams op_params; \
|
||||
op_params.input_zero_point = 0; \
|
||||
type::L2Normalization(op_params, GetTensorShape(input), \
|
||||
GetTensorData<float>(input), GetTensorShape(output), \
|
||||
GetTensorData<float>(output), epsilon)
|
||||
|
||||
TF_LITE_L2NORM(reference_ops);
|
||||
#undef TF_LITE_L2NORM
|
||||
} else if (output->type == kTfLiteUInt8) {
|
||||
#define TF_LITE_L2NORM(type) \
|
||||
tflite::L2NormalizationParams op_params; \
|
||||
op_params.input_zero_point = input->params.zero_point; \
|
||||
type::L2Normalization(op_params, GetTensorShape(input), \
|
||||
GetTensorData<uint8>(input), GetTensorShape(output), \
|
||||
GetTensorData<uint8>(output))
|
||||
|
||||
TF_LITE_L2NORM(reference_ops);
|
||||
#undef TF_LITE_L2NORM
|
||||
} else if (output->type == kTfLiteInt8) {
|
||||
const auto input_shape = GetTensorShape(input);
|
||||
const auto output_shape = GetTensorShape(output);
|
||||
const int trailing_dim = input_shape.DimensionsCount() - 1;
|
||||
const int depth =
|
||||
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
|
||||
const int outer_size =
|
||||
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
|
||||
reference_integer_ops::L2Normalization(input->params.zero_point, outer_size,
|
||||
depth, GetTensorData<int8>(input),
|
||||
GetTensorData<int8>(output));
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context, "Output type is %d, requires float.",
|
||||
output->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace l2norm
|
||||
|
||||
TfLiteRegistration* Register_L2NORM_REF() {
|
||||
static TfLiteRegistration r = {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/l2norm::Prepare,
|
||||
/*invoke=*/l2norm::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_L2_NORMALIZATION() {
|
||||
return Register_L2NORM_REF();
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
332
tensorflow/lite/micro/kernels/l2norm_test.cc
Normal file
332
tensorflow/lite/micro/kernels/l2norm_test.cc
Normal file
@ -0,0 +1,332 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/micro/kernels/all_ops_resolver.h"
|
||||
#include "tensorflow/lite/micro/testing/micro_test.h"
|
||||
#include "tensorflow/lite/micro/testing/test_utils.h"
|
||||
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
namespace {
|
||||
|
||||
// used to set the quantization parameters for the int8 and uint8 tests
|
||||
constexpr float kInputMin = -2.0;
|
||||
constexpr float kInputMax = 2.0;
|
||||
constexpr float kOutputMin = -1.0;
|
||||
constexpr float kOutputMax = 127.0 / 128.0;
|
||||
|
||||
|
||||
void QuantizeInputData(const float input_data[], int length,
|
||||
uint8_t* quantized_data) {
|
||||
for (int i=0; i < 6; i++) {
|
||||
quantized_data[i] = tflite::testing::F2Q(input_data[i],
|
||||
tflite::testing::kInputMin,
|
||||
tflite::testing::kInputMax);
|
||||
}
|
||||
}
|
||||
|
||||
void QuantizeInputData(const float input_data[], int length,
|
||||
int8_t* quantized_data) {
|
||||
for (int i=0; i < 6; i++) {
|
||||
quantized_data[i] = tflite::testing::F2QS(input_data[i],
|
||||
tflite::testing::kInputMin,
|
||||
tflite::testing::kInputMax);
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteTensor CreateL2NormTensor(const float* data, TfLiteIntArray* dims,
|
||||
const char* name, bool is_input) {
|
||||
return CreateFloatTensor(data, dims, name);
|
||||
}
|
||||
|
||||
TfLiteTensor CreateL2NormTensor(const uint8* data, TfLiteIntArray* dims,
|
||||
const char* name, bool is_input) {
|
||||
TfLiteTensor tensor;
|
||||
|
||||
if (is_input) {
|
||||
tensor = CreateQuantizedTensor(data, dims, name, kInputMin, kInputMax);
|
||||
} else {
|
||||
tensor = CreateQuantizedTensor(data, dims, name, kOutputMin, kOutputMax);
|
||||
}
|
||||
|
||||
tensor.quantization.type = kTfLiteAffineQuantization;
|
||||
return tensor;
|
||||
}
|
||||
|
||||
TfLiteTensor CreateL2NormTensor(const int8* data, TfLiteIntArray* dims,
|
||||
const char* name, bool is_input) {
|
||||
TfLiteTensor tensor;
|
||||
|
||||
if (is_input) {
|
||||
tensor = CreateQuantizedTensor(data, dims, name, kInputMin, kInputMax);
|
||||
} else {
|
||||
tensor = CreateQuantizedTensor(data, dims, name, kOutputMin, kOutputMax);
|
||||
}
|
||||
|
||||
tensor.quantization.type = kTfLiteAffineQuantization;
|
||||
return tensor;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline float Dequantize(const T data, float scale, int32_t zero_point) {
|
||||
return scale * (data - zero_point);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void TestL2Normalization(const int* input_dims_data,
|
||||
const T* input_data,
|
||||
const float* expected_output_data,
|
||||
T* output_data, float variance) {
|
||||
TfLiteIntArray* dims = IntArrayFromInts(input_dims_data);
|
||||
|
||||
const int output_dims_count = ElementCount(*dims);
|
||||
|
||||
constexpr int tensors_size = 2;
|
||||
TfLiteTensor tensors[tensors_size] = {
|
||||
CreateL2NormTensor(input_data, dims, "input_tensor", true),
|
||||
CreateL2NormTensor(output_data, dims, "output_tensor", false),
|
||||
};
|
||||
|
||||
TfLiteContext context;
|
||||
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
|
||||
::tflite::ops::micro::AllOpsResolver resolver;
|
||||
const TfLiteRegistration* registration =
|
||||
resolver.FindOp(tflite::BuiltinOperator_L2_NORMALIZATION, 1);
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
|
||||
|
||||
TfLiteL2NormParams builtin_data = {
|
||||
.activation = kTfLiteActNone,
|
||||
};
|
||||
|
||||
int inputs_array_data[] = {1, 0};
|
||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||
int outputs_array_data[] = {1, 1};
|
||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||
int temporaries_array_data[] = {0};
|
||||
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
|
||||
|
||||
TfLiteNode node;
|
||||
node.inputs = inputs_array;
|
||||
node.outputs = outputs_array;
|
||||
node.temporaries = temporaries_array;
|
||||
node.user_data = nullptr;
|
||||
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
|
||||
node.custom_initial_data = nullptr;
|
||||
node.custom_initial_data_size = 0;
|
||||
node.delegate = nullptr;
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
||||
|
||||
// Compare the results from dequantization and expected outputs, and make
|
||||
// sure the difference is within a threshold.
|
||||
if (tensors[1].quantization.type != kTfLiteNoQuantization) {
|
||||
TfLiteTensor* output_tensor = &tensors[1];
|
||||
int32_t zero_point = output_tensor->params.zero_point;
|
||||
float scale = output_tensor->params.scale;
|
||||
|
||||
for (int i = 0; i < output_dims_count; ++i) {
|
||||
float output_val = Dequantize(output_data[i], scale, zero_point);
|
||||
|
||||
TF_LITE_MICRO_EXPECT_LE(expected_output_data[i] - variance, output_val);
|
||||
TF_LITE_MICRO_EXPECT_GE(expected_output_data[i] + variance, output_val);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < output_dims_count; ++i) {
|
||||
float output_val = static_cast<float>(output_data[i]);
|
||||
TF_LITE_MICRO_EXPECT_LE(expected_output_data[i] - variance, output_val);
|
||||
TF_LITE_MICRO_EXPECT_GE(expected_output_data[i] + variance, output_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
||||
|
||||
TF_LITE_MICRO_TESTS_BEGIN
|
||||
|
||||
TF_LITE_MICRO_TEST(SimpleFloatTest) {
|
||||
const int input_dims[] = {4, 1, 1, 1, 6};
|
||||
constexpr int data_length = 6;
|
||||
const float input_data[data_length] = {
|
||||
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1
|
||||
};
|
||||
const float expected_output_data[data_length] = {
|
||||
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05
|
||||
};
|
||||
float output_data[data_length];
|
||||
|
||||
tflite::testing::TestL2Normalization<float>(input_dims, input_data,
|
||||
expected_output_data, output_data, 0);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(ZerosVectorFloatTest) {
|
||||
const int input_dims[] = {4, 1, 1, 1, 6};
|
||||
constexpr int data_length = 6;
|
||||
const float input_data[data_length] = {0, 0, 0, 0, 0, 0};
|
||||
const float expected_output_data[data_length] = {0, 0, 0, 0, 0, 0};
|
||||
float output_data[data_length];
|
||||
|
||||
tflite::testing::TestL2Normalization<float>(input_dims, input_data,
|
||||
expected_output_data, output_data, 0);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(SimpleFloatWithRankLessThanFourTest) {
|
||||
const int input_dims[] = {4, 1, 1, 1, 6};
|
||||
constexpr int data_length = 6;
|
||||
const float input_data[data_length] = {
|
||||
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1
|
||||
};
|
||||
const float expected_output_data[data_length] = {
|
||||
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05
|
||||
};
|
||||
float output_data[data_length];
|
||||
|
||||
tflite::testing::TestL2Normalization<float>(input_dims, input_data,
|
||||
expected_output_data, output_data, 0);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(MultipleBatchFloatTest) {
|
||||
const int input_dims[] = {4, 3, 1, 1, 6};
|
||||
constexpr int data_length = 18;
|
||||
const float input_data[data_length] = {
|
||||
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1
|
||||
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2
|
||||
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3
|
||||
};
|
||||
const float expected_output_data[data_length] = {
|
||||
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1
|
||||
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2
|
||||
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3
|
||||
};
|
||||
float output_data[data_length];
|
||||
|
||||
tflite::testing::TestL2Normalization<float>(input_dims, input_data,
|
||||
expected_output_data, output_data, 0);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(ZerosVectorUint8Test) {
|
||||
const int input_dims[] = {4, 1, 1, 1, 6};
|
||||
constexpr int data_length = 6;
|
||||
const float input_data[data_length] = {0};
|
||||
const float expected_output_data[data_length] = {0};
|
||||
uint8_t quantized_input[data_length];
|
||||
uint8_t output_data[data_length];
|
||||
|
||||
tflite::testing::QuantizeInputData(input_data, data_length, quantized_input);
|
||||
|
||||
tflite::testing::TestL2Normalization<uint8_t>(input_dims, quantized_input,
|
||||
expected_output_data, output_data, .1);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(SimpleUint8Test) {
|
||||
const int input_dims[] = {4, 1, 1, 1, 6};
|
||||
constexpr int data_length = 6;
|
||||
float input_data[data_length] = {
|
||||
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1
|
||||
};
|
||||
float expected_output[data_length] = {
|
||||
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05
|
||||
};
|
||||
uint8_t quantized_input[data_length];
|
||||
uint8_t output_data[data_length];
|
||||
|
||||
tflite::testing::QuantizeInputData(input_data, data_length, quantized_input);
|
||||
|
||||
tflite::testing::TestL2Normalization<uint8_t>(input_dims, quantized_input,
|
||||
expected_output, output_data, .1);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(SimpleInt8Test) {
|
||||
const int input_dims[] = {4, 1, 1, 1, 6};
|
||||
constexpr int data_length = 6;
|
||||
float input_data[data_length] = {
|
||||
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1
|
||||
};
|
||||
float expected_output[data_length] = {
|
||||
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05
|
||||
};
|
||||
int8_t quantized_input[data_length];
|
||||
int8_t output_data[data_length];
|
||||
|
||||
tflite::testing::QuantizeInputData(input_data, data_length, quantized_input);
|
||||
|
||||
tflite::testing::TestL2Normalization<int8_t>(input_dims, quantized_input,
|
||||
expected_output, output_data, .1);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(ZerosVectorInt8Test) {
|
||||
const int input_dims[] = {4, 1, 1, 1, 6};
|
||||
constexpr int data_length = 6;
|
||||
const float input_data[data_length] = {0};
|
||||
const float expected_output_data[data_length] = {0};
|
||||
int8_t quantized_input[data_length];
|
||||
int8_t output_data[data_length];
|
||||
|
||||
tflite::testing::QuantizeInputData(input_data, data_length, quantized_input);
|
||||
|
||||
tflite::testing::TestL2Normalization<int8_t>(input_dims, quantized_input,
|
||||
expected_output_data, output_data, .1);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(MultipleBatchUint8Test) {
|
||||
const int input_dims[] = {4, 1, 1, 1, 6};
|
||||
constexpr int data_length = 18;
|
||||
float input_data[data_length] = {
|
||||
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1
|
||||
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2
|
||||
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3
|
||||
};
|
||||
float expected_output[data_length] = {
|
||||
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1
|
||||
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2
|
||||
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3
|
||||
};
|
||||
uint8_t quantized_input[data_length];
|
||||
uint8_t output_data[data_length];
|
||||
|
||||
tflite::testing::QuantizeInputData(input_data, data_length, quantized_input);
|
||||
|
||||
tflite::testing::TestL2Normalization<uint8_t>(input_dims, quantized_input,
|
||||
expected_output, output_data, .1);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(MultipleBatchInt8Test) {
|
||||
const int input_dims[] = {4, 1, 1, 1, 6};
|
||||
constexpr int data_length = 18;
|
||||
float input_data[data_length] = {
|
||||
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1
|
||||
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2
|
||||
-1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3
|
||||
};
|
||||
float expected_output[data_length] = {
|
||||
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1
|
||||
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2
|
||||
-0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3
|
||||
};
|
||||
int8_t quantized_input[data_length];
|
||||
int8_t output_data[data_length];
|
||||
|
||||
tflite::testing::QuantizeInputData(input_data, data_length, quantized_input);
|
||||
|
||||
tflite::testing::TestL2Normalization<int8_t>(input_dims, quantized_input,
|
||||
expected_output, output_data, .1);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TESTS_END
|
@ -80,6 +80,7 @@ TfLiteRegistration* Register_STRIDED_SLICE();
|
||||
TfLiteRegistration* Register_SUB();
|
||||
TfLiteRegistration* Register_SVDF();
|
||||
TfLiteRegistration* Register_UNPACK();
|
||||
TfLiteRegistration* Register_L2_NORMALIZATION();
|
||||
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
|
@ -159,7 +159,9 @@ tensorflow/lite/kernels/internal/reference/integer_ops/conv.h \
|
||||
tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h \
|
||||
tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h \
|
||||
tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h \
|
||||
tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h \
|
||||
tensorflow/lite/kernels/internal/reference/integer_ops/mul.h \
|
||||
tensorflow/lite/kernels/internal/reference/l2normalization.h \
|
||||
tensorflow/lite/kernels/internal/reference/maximum_minimum.h \
|
||||
tensorflow/lite/kernels/internal/reference/mul.h \
|
||||
tensorflow/lite/kernels/internal/reference/neg.h \
|
||||
|
Loading…
Reference in New Issue
Block a user