Implement lite/micro/kernels/zeros_like.cc and its test code
This commit is contained in:
parent
7128ad8140
commit
7108450567
@ -144,6 +144,7 @@ cc_library(
|
||||
"tanh.cc",
|
||||
"transpose_conv.cc",
|
||||
"unpack.cc",
|
||||
"zeros_like.cc",
|
||||
] + select({
|
||||
"//conditions:default": [
|
||||
"conv.cc",
|
||||
|
||||
@ -42,6 +42,7 @@ TfLiteRegistration Register_SOFTMAX();
|
||||
TfLiteRegistration Register_SPACE_TO_BATCH_ND();
|
||||
TfLiteRegistration Register_SVDF();
|
||||
TfLiteRegistration Register_TRANSPOSE_CONV_2D();
|
||||
TfLiteRegistration Register_ZEROS_LIKE();
|
||||
|
||||
namespace ops {
|
||||
namespace micro {
|
||||
|
||||
@ -13,18 +13,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace micro {
|
||||
namespace zeros_like {
|
||||
namespace {
|
||||
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kOutputTensor = 0;
|
||||
@ -39,26 +34,32 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
output->type = input->type;
|
||||
|
||||
return context->ResizeTensor(context, output,
|
||||
TfLiteIntArrayCopy(input->dims));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void resetZeros(T* out, int num_elements) {
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
out[i] = static_cast<T>(0);
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input;
|
||||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
GetOutputSafe(context, node, kOutputTensor, &output));
|
||||
const int num_elements = NumElements(input);
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensor);
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
int flat_size = MatchingFlatSize(tflite::micro::GetTensorShape(input),
|
||||
tflite::micro::GetTensorShape(output));
|
||||
switch (input->type) {
|
||||
case kTfLiteInt64:
|
||||
memset(GetTensorData<int64_t>(output), 0, num_elements * sizeof(int64_t));
|
||||
resetZeros(tflite::micro::GetTensorData<int64_t>(output), flat_size);
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
memset(GetTensorData<int32_t>(output), 0, num_elements * sizeof(int32_t));
|
||||
resetZeros(tflite::micro::GetTensorData<int32_t>(output), flat_size);
|
||||
break;
|
||||
case kTfLiteFloat32:
|
||||
memset(GetTensorData<float>(output), 0, num_elements * sizeof(float));
|
||||
resetZeros(tflite::micro::GetTensorData<float>(output), flat_size);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
@ -69,15 +70,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
} // namespace zeros_like
|
||||
|
||||
TfLiteRegistration* Register_ZEROS_LIKE() {
|
||||
static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
|
||||
zeros_like::Prepare, zeros_like::Eval};
|
||||
return &r;
|
||||
TfLiteRegistration Register_ZEROS_LIKE() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/Prepare,
|
||||
/*invoke=*/Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
||||
156
tensorflow/lite/micro/kernels/zeros_like_test.cc
Normal file
156
tensorflow/lite/micro/kernels/zeros_like_test.cc
Normal file
@ -0,0 +1,156 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/micro/all_ops_resolver.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
|
||||
#include "tensorflow/lite/micro/test_helpers.h"
|
||||
#include "tensorflow/lite/micro/testing/micro_test.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
namespace {
|
||||
|
||||
void TestZerosLikeFloat(const int* input_dims_data, const float* input_data,
|
||||
const float* expected_output_data,
|
||||
float* output_data) {
|
||||
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
|
||||
TfLiteIntArray* output_dims = IntArrayFromInts(input_dims_data);
|
||||
const int output_dims_count = ElementCount(*output_dims);
|
||||
constexpr int inputs_size = 1;
|
||||
constexpr int outputs_size = 1;
|
||||
constexpr int tensors_size = inputs_size + outputs_size;
|
||||
TfLiteTensor tensors[tensors_size] = {
|
||||
CreateTensor(input_data, input_dims),
|
||||
CreateTensor(output_data, output_dims),
|
||||
};
|
||||
|
||||
int inputs_array_data[] = {1, 0};
|
||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||
int outputs_array_data[] = {1, 1};
|
||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||
|
||||
const TfLiteRegistration registration = Register_ZEROS_LIKE();
|
||||
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
|
||||
outputs_array,
|
||||
/*builtin_data=*/nullptr, micro_test::reporter);
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
|
||||
|
||||
for (int i = 0; i < output_dims_count; ++i) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void TestZerosLikeInt32(const int* input_dims_data, const int32_t* input_data,
|
||||
const int32_t* expected_output_data,
|
||||
int32_t* output_data) {
|
||||
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
|
||||
TfLiteIntArray* output_dims = IntArrayFromInts(input_dims_data);
|
||||
const int output_dims_count = ElementCount(*output_dims);
|
||||
constexpr int inputs_size = 1;
|
||||
constexpr int outputs_size = 1;
|
||||
constexpr int tensors_size = inputs_size + outputs_size;
|
||||
TfLiteTensor tensors[tensors_size] = {
|
||||
CreateTensor(input_data, input_dims),
|
||||
CreateTensor(output_data, output_dims),
|
||||
};
|
||||
|
||||
int inputs_array_data[] = {1, 0};
|
||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||
int outputs_array_data[] = {1, 1};
|
||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||
|
||||
const TfLiteRegistration registration = Register_ZEROS_LIKE();
|
||||
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
|
||||
outputs_array,
|
||||
/*builtin_data=*/nullptr, micro_test::reporter);
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
|
||||
|
||||
for (int i = 0; i < output_dims_count; ++i) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void TestZerosLikeInt64(const int* input_dims_data, const int64_t* input_data,
|
||||
const int64_t* expected_output_data,
|
||||
int64_t* output_data) {
|
||||
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
|
||||
TfLiteIntArray* output_dims = IntArrayFromInts(input_dims_data);
|
||||
const int output_dims_count = ElementCount(*output_dims);
|
||||
constexpr int inputs_size = 1;
|
||||
constexpr int outputs_size = 1;
|
||||
constexpr int tensors_size = inputs_size + outputs_size;
|
||||
TfLiteTensor tensors[tensors_size] = {
|
||||
CreateTensor(input_data, input_dims),
|
||||
CreateTensor(output_data, output_dims),
|
||||
};
|
||||
|
||||
int inputs_array_data[] = {1, 0};
|
||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||
int outputs_array_data[] = {1, 1};
|
||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||
|
||||
const TfLiteRegistration registration = Register_ZEROS_LIKE();
|
||||
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
|
||||
outputs_array,
|
||||
/*builtin_data=*/nullptr, micro_test::reporter);
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
|
||||
|
||||
for (int i = 0; i < output_dims_count; ++i) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
||||
TF_LITE_MICRO_TESTS_BEGIN
|
||||
|
||||
TF_LITE_MICRO_TEST(TestZerosLikeFloat) {
|
||||
float output_data[6];
|
||||
const int input_dims[] = {2, 3};
|
||||
const float input_values[] = {-2.0, -1.0, 0.0, 1.0, 2.0, 3.0};
|
||||
const float golden[] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
|
||||
tflite::testing::TestZerosLikeFloat(input_dims, input_values, golden,
|
||||
output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(TestZerosLikeInt32) {
|
||||
int32_t output_data[4];
|
||||
const int input_dims[] = {1, 2, 2, 1};
|
||||
const int32_t input_values[] = {-2, -1, 0, 3};
|
||||
const int32_t golden[] = {0, 0, 0, 0};
|
||||
tflite::testing::TestZerosLikeInt32(input_dims, input_values, golden,
|
||||
output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(TestZerosLikeInt64) {
|
||||
int64_t output_data[4];
|
||||
const int input_dims[] = {1, 2, 2, 1};
|
||||
const int64_t input_values[] = {-2, -1, 0, 3};
|
||||
const int64_t golden[] = {0, 0, 0, 0};
|
||||
tflite::testing::TestZerosLikeInt64(input_dims, input_values, golden,
|
||||
output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TESTS_END
|
||||
@ -438,6 +438,11 @@ class MicroMutableOpResolver : public MicroOpResolver {
|
||||
tflite::ops::micro::Register_UNPACK(), ParseUnpack);
|
||||
}
|
||||
|
||||
TfLiteStatus AddZerosLike() {
|
||||
return AddBuiltin(BuiltinOperator_ZEROS_LIKE,
|
||||
Register_ZEROS_LIKE(), ParseZerosLike);
|
||||
}
|
||||
|
||||
unsigned int GetRegistrationLength() { return registrations_len_; }
|
||||
|
||||
private:
|
||||
|
||||
@ -301,6 +301,7 @@ tensorflow/lite/micro/kernels/svdf_test.cc \
|
||||
tensorflow/lite/micro/kernels/tanh_test.cc \
|
||||
tensorflow/lite/micro/kernels/transpose_conv_test.cc \
|
||||
tensorflow/lite/micro/kernels/unpack_test.cc \
|
||||
tensorflow/lite/micro/kernels/zeros_like_test.cc \
|
||||
tensorflow/lite/micro/memory_planner/greedy_memory_planner_test.cc \
|
||||
tensorflow/lite/micro/memory_planner/linear_memory_planner_test.cc
|
||||
|
||||
@ -357,7 +358,8 @@ tensorflow/lite/micro/kernels/svdf.cc \
|
||||
tensorflow/lite/micro/kernels/svdf_common.cc \
|
||||
tensorflow/lite/micro/kernels/tanh.cc \
|
||||
tensorflow/lite/micro/kernels/transpose_conv.cc \
|
||||
tensorflow/lite/micro/kernels/unpack.cc
|
||||
tensorflow/lite/micro/kernels/unpack.cc \
|
||||
tensorflow/lite/micro/kernels/zeros_like.cc
|
||||
|
||||
MICROLITE_TEST_HDRS := \
|
||||
$(wildcard tensorflow/lite/micro/testing/*.h)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user