Port the l2norm kernel to the new TfLiteEvalTensor API.
PiperOrigin-RevId: 323039627 Change-Id: I1c64000d5e83b5aff8875f1b354439f352c13edd
This commit is contained in:
parent
0c1b3ee3a0
commit
3098c7a84d
@ -619,6 +619,7 @@ tflite_micro_cc_test(
|
||||
"l2norm_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":kernel_runner",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/micro:op_resolvers",
|
||||
"//tensorflow/lite/micro/testing:micro_test",
|
||||
|
@ -18,12 +18,15 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/reference/l2normalization.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace micro {
|
||||
namespace l2norm {
|
||||
|
||||
namespace {
|
||||
|
||||
// This file has two implementation of L2Norm.
|
||||
enum KernelType {
|
||||
kReference,
|
||||
@ -33,9 +36,15 @@ enum KernelType {
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
#if defined(DEBUG)
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
|
||||
auto* params = reinterpret_cast<TfLiteL2NormParams*>(node->builtin_data);
|
||||
L2NormalizationParams* data =
|
||||
static_cast<L2NormalizationParams*>(node->user_data);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
@ -51,26 +60,33 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
|
||||
|
||||
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.scale, (1. / 128.));
|
||||
if (output->type == kTfLiteUInt8) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 128);
|
||||
}
|
||||
if (output->type == kTfLiteInt8) {
|
||||
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
|
||||
}
|
||||
data->input_zero_point = input->params.zero_point;
|
||||
} else if (output->type == kTfLiteFloat32) {
|
||||
data->input_zero_point = 0;
|
||||
}
|
||||
|
||||
// TODO(ahentz): For some reason our implementations don't support
|
||||
// activations.
|
||||
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
|
||||
#endif
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||
return context->AllocatePersistentBuffer(context,
|
||||
sizeof(L2NormalizationParams));
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
TFLITE_DCHECK(node->user_data != nullptr);
|
||||
const L2NormalizationParams& data =
|
||||
*(static_cast<const L2NormalizationParams*>(node->user_data));
|
||||
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensor);
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
|
||||
// TODO(b/143912164): instead of hardcode the epsilon here, we should read it
|
||||
// from tensorflow, i.e., adding a params.
|
||||
@ -87,36 +103,29 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
// So we don't even need to do handle the epsilon for quantized kernel case.
|
||||
const float epsilon = 1e-6f;
|
||||
if (output->type == kTfLiteFloat32) {
|
||||
#define TF_LITE_L2NORM(type) \
|
||||
tflite::L2NormalizationParams op_params; \
|
||||
op_params.input_zero_point = 0; \
|
||||
type::L2Normalization(op_params, GetTensorShape(input), \
|
||||
GetTensorData<float>(input), GetTensorShape(output), \
|
||||
GetTensorData<float>(output), epsilon)
|
||||
|
||||
TF_LITE_L2NORM(reference_ops);
|
||||
#undef TF_LITE_L2NORM
|
||||
reference_ops::L2Normalization(data, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<float>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<float>(output),
|
||||
epsilon);
|
||||
} else if (output->type == kTfLiteUInt8) {
|
||||
#define TF_LITE_L2NORM(type) \
|
||||
tflite::L2NormalizationParams op_params; \
|
||||
op_params.input_zero_point = input->params.zero_point; \
|
||||
type::L2Normalization(op_params, GetTensorShape(input), \
|
||||
GetTensorData<uint8_t>(input), GetTensorShape(output), \
|
||||
GetTensorData<uint8_t>(output))
|
||||
|
||||
TF_LITE_L2NORM(reference_ops);
|
||||
#undef TF_LITE_L2NORM
|
||||
reference_ops::L2Normalization(
|
||||
data, tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorData<uint8_t>(input),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<uint8_t>(output));
|
||||
} else if (output->type == kTfLiteInt8) {
|
||||
const auto input_shape = GetTensorShape(input);
|
||||
const auto output_shape = GetTensorShape(output);
|
||||
const auto input_shape = tflite::micro::GetTensorShape(input);
|
||||
const auto output_shape = tflite::micro::GetTensorShape(output);
|
||||
const int trailing_dim = input_shape.DimensionsCount() - 1;
|
||||
const int depth =
|
||||
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
|
||||
const int outer_size =
|
||||
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
|
||||
reference_integer_ops::L2Normalization(input->params.zero_point, outer_size,
|
||||
depth, GetTensorData<int8_t>(input),
|
||||
GetTensorData<int8_t>(output));
|
||||
reference_integer_ops::L2Normalization(
|
||||
data.input_zero_point, outer_size, depth,
|
||||
tflite::micro::GetTensorData<int8_t>(input),
|
||||
tflite::micro::GetTensorData<int8_t>(output));
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context, "Output type is %s, requires float.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
@ -129,7 +138,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
} // namespace l2norm
|
||||
|
||||
TfLiteRegistration Register_L2NORM_REF() {
|
||||
return {/*init=*/nullptr,
|
||||
return {/*init=*/l2norm::Init,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/l2norm::Prepare,
|
||||
/*invoke=*/l2norm::Eval,
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/micro/all_ops_resolver.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
|
||||
#include "tensorflow/lite/micro/testing/micro_test.h"
|
||||
#include "tensorflow/lite/micro/testing/test_utils.h"
|
||||
|
||||
@ -97,31 +98,23 @@ void TestL2Normalization(const int* input_dims_data, const T* input_data,
|
||||
CreateL2NormTensor(output_data, dims, false),
|
||||
};
|
||||
|
||||
TfLiteContext context;
|
||||
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
|
||||
::tflite::AllOpsResolver resolver;
|
||||
const TfLiteRegistration* registration =
|
||||
resolver.FindOp(tflite::BuiltinOperator_L2_NORMALIZATION);
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
|
||||
|
||||
TfLiteL2NormParams builtin_data = {
|
||||
.activation = kTfLiteActNone,
|
||||
};
|
||||
|
||||
int inputs_array_data[] = {1, 0};
|
||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||
int outputs_array_data[] = {1, 1};
|
||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||
|
||||
TfLiteNode node;
|
||||
node.inputs = inputs_array;
|
||||
node.outputs = outputs_array;
|
||||
node.user_data = nullptr;
|
||||
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
|
||||
node.custom_initial_data = nullptr;
|
||||
node.custom_initial_data_size = 0;
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
||||
TfLiteL2NormParams builtin_data = {
|
||||
.activation = kTfLiteActNone,
|
||||
};
|
||||
|
||||
const TfLiteRegistration registration =
|
||||
ops::micro::Register_L2_NORMALIZATION();
|
||||
micro::KernelRunner runner(
|
||||
registration, tensors, tensors_size, inputs_array, outputs_array,
|
||||
reinterpret_cast<void*>(&builtin_data), micro_test::reporter);
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
|
||||
|
||||
// Compare the results from dequantization and expected outputs, and make
|
||||
// sure the difference is within a threshold.
|
||||
|
Loading…
x
Reference in New Issue
Block a user