From 06eb028030a0a6f9c3fb0ec46a5ded4e3c0ed03e Mon Sep 17 00:00:00 2001 From: Nick Kreeger Date: Fri, 24 Jul 2020 10:07:48 -0700 Subject: [PATCH] Port the elementwise kernel to the new TfLiteEvalTensor API. PiperOrigin-RevId: 323017688 Change-Id: I4362d0e3d70dd4449a9e45f0bf8289b0f8824235 --- tensorflow/lite/micro/kernels/BUILD | 1 + tensorflow/lite/micro/kernels/elementwise.cc | 15 +-- .../lite/micro/kernels/elementwise_test.cc | 110 +++++++----------- 3 files changed, 48 insertions(+), 78 deletions(-) diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 5ff1121fedb..5f79f7c0c62 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -133,6 +133,7 @@ tflite_micro_cc_test( name = "elementwise_test", srcs = ["elementwise_test.cc"], deps = [ + ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:debug_log", "//tensorflow/lite/micro:op_resolvers", diff --git a/tensorflow/lite/micro/kernels/elementwise.cc b/tensorflow/lite/micro/kernels/elementwise.cc index cb1fd852812..64880344664 100644 --- a/tensorflow/lite/micro/kernels/elementwise.cc +++ b/tensorflow/lite/micro/kernels/elementwise.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_utils.h" namespace tflite { namespace ops { @@ -52,13 +54,13 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { template inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, T func(T), TfLiteType expected_type) { - const TfLiteTensor* input = GetInput(context, node, 0); - TfLiteTensor* output = GetOutput(context, node, 0); + const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0); + TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0); TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type); - const int64_t num_elements = NumElements(input); - const T* in_data = GetTensorData(input); - T* out_data = GetTensorData(output); - for (int64_t i = 0; i < num_elements; ++i) { + const size_t num_elements = ElementCount(*input->dims); + const T* in_data = tflite::micro::GetTensorData(input); + T* out_data = tflite::micro::GetTensorData(output); + for (size_t i = 0; i < num_elements; ++i) { out_data[i] = func(in_data[i]); } return kTfLiteOk; @@ -106,7 +108,6 @@ TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) { return EvalLogical(context, node, [](bool v) { return !v; }); } - } // namespace } // namespace elementwise diff --git a/tensorflow/lite/micro/kernels/elementwise_test.cc b/tensorflow/lite/micro/kernels/elementwise_test.cc index 8f028b1f451..b7094cbd445 100644 --- a/tensorflow/lite/micro/kernels/elementwise_test.cc +++ b/tensorflow/lite/micro/kernels/elementwise_test.cc @@ -16,13 +16,14 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/debug_log.h" +#include "tensorflow/lite/micro/kernels/kernel_runner.h" #include "tensorflow/lite/micro/testing/micro_test.h" #include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { -void TestElementwiseFloat(tflite::BuiltinOperator op, +void TestElementwiseFloat(const TfLiteRegistration& registration, const int* input_dims_data, const float* input_data, const int* output_dims_data, const float* expected_output_data, @@ -43,45 +44,26 @@ void TestElementwiseFloat(tflite::BuiltinOperator op, output_data[i] = 23; } - TfLiteContext context; - PopulateContext(tensors, tensors_size, micro_test::reporter, &context); - tflite::AllOpsResolver resolver; - const TfLiteRegistration* registration = resolver.FindOp(op); - TF_LITE_MICRO_EXPECT_NE(nullptr, registration); - - void* user_data = nullptr; - if (registration->init) { - user_data = registration->init(&context, nullptr, 0); - } static int inputs_array_data[] = {1, 0}; TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); static int outputs_array_data[] = {1, 1}; TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); - TfLiteNode node; - node.inputs = inputs_array; - node.outputs = outputs_array; - node.user_data = user_data; - node.builtin_data = nullptr; - node.custom_initial_data = nullptr; - node.custom_initial_data_size = 0; + micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, + outputs_array, + /*builtin_data=*/nullptr, micro_test::reporter); - if (registration->prepare) { - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); - } - TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); - if (registration->free) { - registration->free(&context, user_data); - } for (int i = 0; i < output_dims_count; ++i) { TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i], 1e-5f); } } -void TestElementwiseBool(tflite::BuiltinOperator op, const int* input_dims_data, - const bool* input_data, const int* output_dims_data, +void TestElementwiseBool(const TfLiteRegistration& registration, + const int* input_dims_data, const bool* input_data, + const int* output_dims_data, const bool* expected_output_data, bool* output_data) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); @@ -99,39 +81,18 @@ void TestElementwiseBool(tflite::BuiltinOperator op, const int* input_dims_data, output_data[i] = false; } - TfLiteContext context; - PopulateContext(tensors, tensors_size, micro_test::reporter, &context); - tflite::AllOpsResolver resolver; - const TfLiteRegistration* registration = resolver.FindOp(op); - TF_LITE_MICRO_EXPECT_NE(nullptr, registration); - - void* user_data = nullptr; - if (registration->init) { - user_data = registration->init(&context, nullptr, 0); - } - const int inputs_array_data[] = {1, 0}; TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); const int outputs_array_data[] = {1, 1}; TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); - TfLiteNode node; - node.inputs = inputs_array; - node.outputs = outputs_array; - node.user_data = user_data; - node.builtin_data = nullptr; - node.custom_initial_data = nullptr; - node.custom_initial_data_size = 0; + micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, + outputs_array, + /*builtin_data=*/nullptr, micro_test::reporter); - if (registration->prepare) { - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); - } - TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); - if (registration->free) { - registration->free(&context, user_data); - } for (int i = 0; i < output_dims_count; ++i) { TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]); } @@ -148,8 +109,9 @@ TF_LITE_MICRO_TEST(Abs) { const float input[] = {0.01, -0.01, 10, -10}; const float golden[] = {0.01, 0.01, 10, 10}; float output_data[output_dims_count]; - tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_ABS, shape, - input, shape, golden, output_data); + tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_ABS(), + shape, input, shape, golden, + output_data); } TF_LITE_MICRO_TEST(Sin) { @@ -158,8 +120,9 @@ TF_LITE_MICRO_TEST(Sin) { const float input[] = {0, 3.1415926, -3.1415926, 1}; const float golden[] = {0, 0, 0, 0.84147}; float output_data[output_dims_count]; - tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_SIN, shape, - input, shape, golden, output_data); + tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_SIN(), + shape, input, shape, golden, + output_data); } TF_LITE_MICRO_TEST(Cos) { @@ -168,8 +131,9 @@ TF_LITE_MICRO_TEST(Cos) { const float input[] = {0, 3.1415926, -3.1415926, 1}; const float golden[] = {1, -1, -1, 0.54030}; float output_data[output_dims_count]; - tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_COS, shape, - input, shape, golden, output_data); + tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_COS(), + shape, input, shape, golden, + output_data); } TF_LITE_MICRO_TEST(Log) { @@ -178,8 +142,9 @@ TF_LITE_MICRO_TEST(Log) { const float input[] = {1, 2.7182818, 0.5, 2}; const float golden[] = {0, 1, -0.6931472, 0.6931472}; float output_data[output_dims_count]; - tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_LOG, shape, - input, shape, golden, output_data); + tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_LOG(), + shape, input, shape, golden, + output_data); } TF_LITE_MICRO_TEST(Sqrt) { @@ -188,8 +153,9 @@ TF_LITE_MICRO_TEST(Sqrt) { const float input[] = {0, 1, 2, 4}; const float golden[] = {0, 1, 1.41421, 2}; float output_data[output_dims_count]; - tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_SQRT, shape, - input, shape, golden, output_data); + tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_SQRT(), + shape, input, shape, golden, + output_data); } TF_LITE_MICRO_TEST(Rsqrt) { @@ -198,8 +164,9 @@ TF_LITE_MICRO_TEST(Rsqrt) { const float input[] = {1, 2, 4, 9}; const float golden[] = {1, 0.7071, 0.5, 0.33333}; float output_data[output_dims_count]; - tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_RSQRT, shape, - input, shape, golden, output_data); + tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_RSQRT(), + shape, input, shape, golden, + output_data); } TF_LITE_MICRO_TEST(Square) { @@ -208,8 +175,9 @@ TF_LITE_MICRO_TEST(Square) { const float input[] = {1, 2, 0.5, -3.0}; const float golden[] = {1, 4.0, 0.25, 9.0}; float output_data[output_dims_count]; - tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_SQUARE, shape, - input, shape, golden, output_data); + tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_SQUARE(), + shape, input, shape, golden, + output_data); } TF_LITE_MICRO_TEST(LogicalNot) { @@ -218,9 +186,9 @@ TF_LITE_MICRO_TEST(LogicalNot) { const bool input[] = {true, false, false, true}; const bool golden[] = {false, true, true, false}; bool output_data[output_dims_count]; - tflite::testing::TestElementwiseBool(tflite::BuiltinOperator_LOGICAL_NOT, - shape, input, shape, golden, - output_data); + tflite::testing::TestElementwiseBool( + tflite::ops::micro::Register_LOGICAL_NOT(), shape, input, shape, golden, + output_data); } TF_LITE_MICRO_TESTS_END