diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 8d7e5ff7354..7b9ec5dd8bb 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -619,6 +619,7 @@ tflite_micro_cc_test( "l2norm_test.cc", ], deps = [ + ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro/testing:micro_test", diff --git a/tensorflow/lite/micro/kernels/l2norm.cc b/tensorflow/lite/micro/kernels/l2norm.cc index ab4067058a4..f864efa271c 100644 --- a/tensorflow/lite/micro/kernels/l2norm.cc +++ b/tensorflow/lite/micro/kernels/l2norm.cc @@ -18,12 +18,15 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/l2normalization.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { namespace ops { namespace micro { namespace l2norm { +namespace { + // This file has two implementation of L2Norm. enum KernelType { kReference, @@ -33,9 +36,15 @@ enum KernelType { constexpr int kInputTensor = 0; constexpr int kOutputTensor = 0; +} // namespace + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { -#if defined(DEBUG) + TFLITE_DCHECK(node->user_data != nullptr); + TFLITE_DCHECK(node->builtin_data != nullptr); + auto* params = reinterpret_cast(node->builtin_data); + L2NormalizationParams* data = + static_cast(node->user_data); TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -51,26 +60,33 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) { - TF_LITE_ENSURE_EQ(context, output->params.scale, (1. / 128.)); - if (output->type == kTfLiteUInt8) { - TF_LITE_ENSURE_EQ(context, output->params.zero_point, 128); - } - if (output->type == kTfLiteInt8) { - TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); - } + data->input_zero_point = input->params.zero_point; + } else if (output->type == kTfLiteFloat32) { + data->input_zero_point = 0; } // TODO(ahentz): For some reason our implementations don't support // activations. TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); -#endif return kTfLiteOk; } +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + return context->AllocatePersistentBuffer(context, + sizeof(L2NormalizationParams)); +} + TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TFLITE_DCHECK(node->user_data != nullptr); + const L2NormalizationParams& data = + *(static_cast(node->user_data)); + + const TfLiteEvalTensor* input = + tflite::micro::GetEvalInput(context, node, kInputTensor); + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kOutputTensor); // TODO(b/143912164): instead of hardcode the epsilon here, we should read it // from tensorflow, i.e., adding a params. @@ -87,36 +103,29 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // So we don't even need to do handle the epsilon for quantized kernel case. const float epsilon = 1e-6f; if (output->type == kTfLiteFloat32) { -#define TF_LITE_L2NORM(type) \ - tflite::L2NormalizationParams op_params; \ - op_params.input_zero_point = 0; \ - type::L2Normalization(op_params, GetTensorShape(input), \ - GetTensorData(input), GetTensorShape(output), \ - GetTensorData(output), epsilon) - - TF_LITE_L2NORM(reference_ops); -#undef TF_LITE_L2NORM + reference_ops::L2Normalization(data, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output), + epsilon); } else if (output->type == kTfLiteUInt8) { -#define TF_LITE_L2NORM(type) \ - tflite::L2NormalizationParams op_params; \ - op_params.input_zero_point = input->params.zero_point; \ - type::L2Normalization(op_params, GetTensorShape(input), \ - GetTensorData(input), GetTensorShape(output), \ - GetTensorData(output)) - - TF_LITE_L2NORM(reference_ops); -#undef TF_LITE_L2NORM + reference_ops::L2Normalization( + data, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); } else if (output->type == kTfLiteInt8) { - const auto input_shape = GetTensorShape(input); - const auto output_shape = GetTensorShape(output); + const auto input_shape = tflite::micro::GetTensorShape(input); + const auto output_shape = tflite::micro::GetTensorShape(output); const int trailing_dim = input_shape.DimensionsCount() - 1; const int depth = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); - reference_integer_ops::L2Normalization(input->params.zero_point, outer_size, - depth, GetTensorData(input), - GetTensorData(output)); + reference_integer_ops::L2Normalization( + data.input_zero_point, outer_size, depth, + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorData(output)); } else { TF_LITE_KERNEL_LOG(context, "Output type is %s, requires float.", TfLiteTypeGetName(output->type)); @@ -129,7 +138,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace l2norm TfLiteRegistration Register_L2NORM_REF() { - return {/*init=*/nullptr, + return {/*init=*/l2norm::Init, /*free=*/nullptr, /*prepare=*/l2norm::Prepare, /*invoke=*/l2norm::Eval, diff --git a/tensorflow/lite/micro/kernels/l2norm_test.cc b/tensorflow/lite/micro/kernels/l2norm_test.cc index 89029bb260a..791f9036c56 100644 --- a/tensorflow/lite/micro/kernels/l2norm_test.cc +++ b/tensorflow/lite/micro/kernels/l2norm_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" +#include "tensorflow/lite/micro/kernels/kernel_runner.h" #include "tensorflow/lite/micro/testing/micro_test.h" #include "tensorflow/lite/micro/testing/test_utils.h" @@ -97,31 +98,23 @@ void TestL2Normalization(const int* input_dims_data, const T* input_data, CreateL2NormTensor(output_data, dims, false), }; - TfLiteContext context; - PopulateContext(tensors, tensors_size, micro_test::reporter, &context); - ::tflite::AllOpsResolver resolver; - const TfLiteRegistration* registration = - resolver.FindOp(tflite::BuiltinOperator_L2_NORMALIZATION); - TF_LITE_MICRO_EXPECT_NE(nullptr, registration); - - TfLiteL2NormParams builtin_data = { - .activation = kTfLiteActNone, - }; - int inputs_array_data[] = {1, 0}; TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); int outputs_array_data[] = {1, 1}; TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); - TfLiteNode node; - node.inputs = inputs_array; - node.outputs = outputs_array; - node.user_data = nullptr; - node.builtin_data = reinterpret_cast(&builtin_data); - node.custom_initial_data = nullptr; - node.custom_initial_data_size = 0; - TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + TfLiteL2NormParams builtin_data = { + .activation = kTfLiteActNone, + }; + + const TfLiteRegistration registration = + ops::micro::Register_L2_NORMALIZATION(); + micro::KernelRunner runner( + registration, tensors, tensors_size, inputs_array, outputs_array, + reinterpret_cast(&builtin_data), micro_test::reporter); + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); // Compare the results from dequantization and expected outputs, and make // sure the difference is within a threshold.