From 710845056711c8d0606da21d28fb1cb7adaac0e6 Mon Sep 17 00:00:00 2001 From: rsun Date: Mon, 1 Feb 2021 16:40:39 -0800 Subject: [PATCH] Implement lite/micro/kernels/zeros_like.cc and its test code --- tensorflow/lite/micro/kernels/BUILD | 1 + tensorflow/lite/micro/kernels/micro_ops.h | 1 + tensorflow/lite/micro/kernels/zeros_like.cc | 55 +++--- .../lite/micro/kernels/zeros_like_test.cc | 156 ++++++++++++++++++ .../lite/micro/micro_mutable_op_resolver.h | 5 + tensorflow/lite/micro/tools/make/Makefile | 4 +- 6 files changed, 195 insertions(+), 27 deletions(-) create mode 100644 tensorflow/lite/micro/kernels/zeros_like_test.cc diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 85835c8fdb8..758fcbdb937 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -144,6 +144,7 @@ cc_library( "tanh.cc", "transpose_conv.cc", "unpack.cc", + "zeros_like.cc", ] + select({ "//conditions:default": [ "conv.cc", diff --git a/tensorflow/lite/micro/kernels/micro_ops.h b/tensorflow/lite/micro/kernels/micro_ops.h index 1537fca3973..6164b24c395 100644 --- a/tensorflow/lite/micro/kernels/micro_ops.h +++ b/tensorflow/lite/micro/kernels/micro_ops.h @@ -42,6 +42,7 @@ TfLiteRegistration Register_SOFTMAX(); TfLiteRegistration Register_SPACE_TO_BATCH_ND(); TfLiteRegistration Register_SVDF(); TfLiteRegistration Register_TRANSPOSE_CONV_2D(); +TfLiteRegistration Register_ZEROS_LIKE(); namespace ops { namespace micro { diff --git a/tensorflow/lite/micro/kernels/zeros_like.cc b/tensorflow/lite/micro/kernels/zeros_like.cc index a231104978c..752745bcf83 100644 --- a/tensorflow/lite/micro/kernels/zeros_like.cc +++ b/tensorflow/lite/micro/kernels/zeros_like.cc @@ -13,18 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include - #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { -namespace ops { -namespace micro { -namespace zeros_like { +namespace { constexpr int kInputTensor = 0; constexpr int kOutputTensor = 0; @@ -39,26 +34,32 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { GetOutputSafe(context, node, kOutputTensor, &output)); output->type = input->type; - return context->ResizeTensor(context, output, - TfLiteIntArrayCopy(input->dims)); + return kTfLiteOk; +} + +template +void resetZeros(T* out, int num_elements) { + for (int i = 0; i < num_elements; ++i) { + out[i] = static_cast(0); + } } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const TfLiteTensor* input; - TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); - TfLiteTensor* output; - TF_LITE_ENSURE_OK(context, - GetOutputSafe(context, node, kOutputTensor, &output)); - const int num_elements = NumElements(input); + const TfLiteEvalTensor* input = + tflite::micro::GetEvalInput(context, node, kInputTensor); + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kOutputTensor); + int flat_size = MatchingFlatSize(tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorShape(output)); switch (input->type) { case kTfLiteInt64: - memset(GetTensorData(output), 0, num_elements * sizeof(int64_t)); + resetZeros(tflite::micro::GetTensorData(output), flat_size); break; case kTfLiteInt32: - memset(GetTensorData(output), 0, num_elements * sizeof(int32_t)); + resetZeros(tflite::micro::GetTensorData(output), flat_size); break; case kTfLiteFloat32: - memset(GetTensorData(output), 0, num_elements * sizeof(float)); + resetZeros(tflite::micro::GetTensorData(output), flat_size); break; default: TF_LITE_KERNEL_LOG(context, @@ -69,15 +70,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } return kTfLiteOk; } +} // namespace -} // namespace zeros_like - -TfLiteRegistration* Register_ZEROS_LIKE() { - static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, - zeros_like::Prepare, zeros_like::Eval}; - return &r; +TfLiteRegistration Register_ZEROS_LIKE() { + return {/*init=*/nullptr, + /*free=*/nullptr, + /*prepare=*/Prepare, + /*invoke=*/Eval, + /*profiling_string=*/nullptr, + /*builtin_code=*/0, + /*custom_name=*/nullptr, + /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/zeros_like_test.cc b/tensorflow/lite/micro/kernels/zeros_like_test.cc new file mode 100644 index 00000000000..5e1a326c268 --- /dev/null +++ b/tensorflow/lite/micro/kernels/zeros_like_test.cc @@ -0,0 +1,156 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/all_ops_resolver.h" +#include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" +#include "tensorflow/lite/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { +namespace { + +void TestZerosLikeFloat(const int* input_dims_data, const float* input_data, + const float* expected_output_data, + float* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(input_dims_data); + const int output_dims_count = ElementCount(*output_dims); + constexpr int inputs_size = 1; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims), + }; + + int inputs_array_data[] = {1, 0}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 1}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + + const TfLiteRegistration registration = Register_ZEROS_LIKE(); + micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, + outputs_array, + /*builtin_data=*/nullptr, micro_test::reporter); + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); + + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]); + } +} + +void TestZerosLikeInt32(const int* input_dims_data, const int32_t* input_data, + const int32_t* expected_output_data, + int32_t* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(input_dims_data); + const int output_dims_count = ElementCount(*output_dims); + constexpr int inputs_size = 1; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims), + }; + + int inputs_array_data[] = {1, 0}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 1}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + + const TfLiteRegistration registration = Register_ZEROS_LIKE(); + micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, + outputs_array, + /*builtin_data=*/nullptr, micro_test::reporter); + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); + + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]); + } +} + +void TestZerosLikeInt64(const int* input_dims_data, const int64_t* input_data, + const int64_t* expected_output_data, + int64_t* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(input_dims_data); + const int output_dims_count = ElementCount(*output_dims); + constexpr int inputs_size = 1; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateTensor(input_data, input_dims), + CreateTensor(output_data, output_dims), + }; + + int inputs_array_data[] = {1, 0}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 1}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + + const TfLiteRegistration registration = Register_ZEROS_LIKE(); + micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, + outputs_array, + /*builtin_data=*/nullptr, micro_test::reporter); + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); + + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]); + } +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestZerosLikeFloat) { + float output_data[6]; + const int input_dims[] = {2, 3}; + const float input_values[] = {-2.0, -1.0, 0.0, 1.0, 2.0, 3.0}; + const float golden[] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + tflite::testing::TestZerosLikeFloat(input_dims, input_values, golden, + output_data); +} + +TF_LITE_MICRO_TEST(TestZerosLikeInt32) { + int32_t output_data[4]; + const int input_dims[] = {1, 2, 2, 1}; + const int32_t input_values[] = {-2, -1, 0, 3}; + const int32_t golden[] = {0, 0, 0, 0}; + tflite::testing::TestZerosLikeInt32(input_dims, input_values, golden, + output_data); +} + +TF_LITE_MICRO_TEST(TestZerosLikeInt64) { + int64_t output_data[4]; + const int input_dims[] = {1, 2, 2, 1}; + const int64_t input_values[] = {-2, -1, 0, 3}; + const int64_t golden[] = {0, 0, 0, 0}; + tflite::testing::TestZerosLikeInt64(input_dims, input_values, golden, + output_data); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index 30f6ffde44c..3d4ed575b83 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -438,6 +438,11 @@ class MicroMutableOpResolver : public MicroOpResolver { tflite::ops::micro::Register_UNPACK(), ParseUnpack); } + TfLiteStatus AddZerosLike() { + return AddBuiltin(BuiltinOperator_ZEROS_LIKE, + Register_ZEROS_LIKE(), ParseZerosLike); + } + unsigned int GetRegistrationLength() { return registrations_len_; } private: diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index 43bf9f1705e..a44895281df 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -301,6 +301,7 @@ tensorflow/lite/micro/kernels/svdf_test.cc \ tensorflow/lite/micro/kernels/tanh_test.cc \ tensorflow/lite/micro/kernels/transpose_conv_test.cc \ tensorflow/lite/micro/kernels/unpack_test.cc \ +tensorflow/lite/micro/kernels/zeros_like_test.cc \ tensorflow/lite/micro/memory_planner/greedy_memory_planner_test.cc \ tensorflow/lite/micro/memory_planner/linear_memory_planner_test.cc @@ -357,7 +358,8 @@ tensorflow/lite/micro/kernels/svdf.cc \ tensorflow/lite/micro/kernels/svdf_common.cc \ tensorflow/lite/micro/kernels/tanh.cc \ tensorflow/lite/micro/kernels/transpose_conv.cc \ -tensorflow/lite/micro/kernels/unpack.cc +tensorflow/lite/micro/kernels/unpack.cc \ +tensorflow/lite/micro/kernels/zeros_like.cc MICROLITE_TEST_HDRS := \ $(wildcard tensorflow/lite/micro/testing/*.h)