Port the elementwise kernel to the new TfLiteEvalTensor API.
PiperOrigin-RevId: 323017688 Change-Id: I4362d0e3d70dd4449a9e45f0bf8289b0f8824235
This commit is contained in:
parent
6f57dbfffa
commit
06eb028030
@ -133,6 +133,7 @@ tflite_micro_cc_test(
|
||||
name = "elementwise_test",
|
||||
srcs = ["elementwise_test.cc"],
|
||||
deps = [
|
||||
":kernel_runner",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/micro:debug_log",
|
||||
"//tensorflow/lite/micro:op_resolvers",
|
||||
|
@ -18,6 +18,8 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/micro_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
@ -52,13 +54,13 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
template <typename T>
|
||||
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
|
||||
T func(T), TfLiteType expected_type) {
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
TfLiteTensor* output = GetOutput(context, node, 0);
|
||||
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
|
||||
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
|
||||
TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
|
||||
const int64_t num_elements = NumElements(input);
|
||||
const T* in_data = GetTensorData<T>(input);
|
||||
T* out_data = GetTensorData<T>(output);
|
||||
for (int64_t i = 0; i < num_elements; ++i) {
|
||||
const size_t num_elements = ElementCount(*input->dims);
|
||||
const T* in_data = tflite::micro::GetTensorData<T>(input);
|
||||
T* out_data = tflite::micro::GetTensorData<T>(output);
|
||||
for (size_t i = 0; i < num_elements; ++i) {
|
||||
out_data[i] = func(in_data[i]);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
@ -106,7 +108,6 @@ TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return EvalLogical(context, node, [](bool v) { return !v; });
|
||||
}
|
||||
|
||||
|
||||
} // namespace
|
||||
} // namespace elementwise
|
||||
|
||||
|
@ -16,13 +16,14 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/micro/all_ops_resolver.h"
|
||||
#include "tensorflow/lite/micro/debug_log.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
|
||||
#include "tensorflow/lite/micro/testing/micro_test.h"
|
||||
#include "tensorflow/lite/micro/testing/test_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
|
||||
void TestElementwiseFloat(tflite::BuiltinOperator op,
|
||||
void TestElementwiseFloat(const TfLiteRegistration& registration,
|
||||
const int* input_dims_data, const float* input_data,
|
||||
const int* output_dims_data,
|
||||
const float* expected_output_data,
|
||||
@ -43,45 +44,26 @@ void TestElementwiseFloat(tflite::BuiltinOperator op,
|
||||
output_data[i] = 23;
|
||||
}
|
||||
|
||||
TfLiteContext context;
|
||||
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
|
||||
tflite::AllOpsResolver resolver;
|
||||
const TfLiteRegistration* registration = resolver.FindOp(op);
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
|
||||
|
||||
void* user_data = nullptr;
|
||||
if (registration->init) {
|
||||
user_data = registration->init(&context, nullptr, 0);
|
||||
}
|
||||
static int inputs_array_data[] = {1, 0};
|
||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||
static int outputs_array_data[] = {1, 1};
|
||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||
|
||||
TfLiteNode node;
|
||||
node.inputs = inputs_array;
|
||||
node.outputs = outputs_array;
|
||||
node.user_data = user_data;
|
||||
node.builtin_data = nullptr;
|
||||
node.custom_initial_data = nullptr;
|
||||
node.custom_initial_data_size = 0;
|
||||
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
|
||||
outputs_array,
|
||||
/*builtin_data=*/nullptr, micro_test::reporter);
|
||||
|
||||
if (registration->prepare) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
|
||||
}
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
||||
if (registration->free) {
|
||||
registration->free(&context, user_data);
|
||||
}
|
||||
for (int i = 0; i < output_dims_count; ++i) {
|
||||
TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i], 1e-5f);
|
||||
}
|
||||
}
|
||||
|
||||
void TestElementwiseBool(tflite::BuiltinOperator op, const int* input_dims_data,
|
||||
const bool* input_data, const int* output_dims_data,
|
||||
void TestElementwiseBool(const TfLiteRegistration& registration,
|
||||
const int* input_dims_data, const bool* input_data,
|
||||
const int* output_dims_data,
|
||||
const bool* expected_output_data, bool* output_data) {
|
||||
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
|
||||
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
|
||||
@ -99,39 +81,18 @@ void TestElementwiseBool(tflite::BuiltinOperator op, const int* input_dims_data,
|
||||
output_data[i] = false;
|
||||
}
|
||||
|
||||
TfLiteContext context;
|
||||
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
|
||||
tflite::AllOpsResolver resolver;
|
||||
const TfLiteRegistration* registration = resolver.FindOp(op);
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
|
||||
|
||||
void* user_data = nullptr;
|
||||
if (registration->init) {
|
||||
user_data = registration->init(&context, nullptr, 0);
|
||||
}
|
||||
|
||||
const int inputs_array_data[] = {1, 0};
|
||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||
const int outputs_array_data[] = {1, 1};
|
||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||
|
||||
TfLiteNode node;
|
||||
node.inputs = inputs_array;
|
||||
node.outputs = outputs_array;
|
||||
node.user_data = user_data;
|
||||
node.builtin_data = nullptr;
|
||||
node.custom_initial_data = nullptr;
|
||||
node.custom_initial_data_size = 0;
|
||||
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
|
||||
outputs_array,
|
||||
/*builtin_data=*/nullptr, micro_test::reporter);
|
||||
|
||||
if (registration->prepare) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
|
||||
}
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
||||
if (registration->free) {
|
||||
registration->free(&context, user_data);
|
||||
}
|
||||
for (int i = 0; i < output_dims_count; ++i) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
|
||||
}
|
||||
@ -148,8 +109,9 @@ TF_LITE_MICRO_TEST(Abs) {
|
||||
const float input[] = {0.01, -0.01, 10, -10};
|
||||
const float golden[] = {0.01, 0.01, 10, 10};
|
||||
float output_data[output_dims_count];
|
||||
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_ABS, shape,
|
||||
input, shape, golden, output_data);
|
||||
tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_ABS(),
|
||||
shape, input, shape, golden,
|
||||
output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(Sin) {
|
||||
@ -158,8 +120,9 @@ TF_LITE_MICRO_TEST(Sin) {
|
||||
const float input[] = {0, 3.1415926, -3.1415926, 1};
|
||||
const float golden[] = {0, 0, 0, 0.84147};
|
||||
float output_data[output_dims_count];
|
||||
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_SIN, shape,
|
||||
input, shape, golden, output_data);
|
||||
tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_SIN(),
|
||||
shape, input, shape, golden,
|
||||
output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(Cos) {
|
||||
@ -168,8 +131,9 @@ TF_LITE_MICRO_TEST(Cos) {
|
||||
const float input[] = {0, 3.1415926, -3.1415926, 1};
|
||||
const float golden[] = {1, -1, -1, 0.54030};
|
||||
float output_data[output_dims_count];
|
||||
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_COS, shape,
|
||||
input, shape, golden, output_data);
|
||||
tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_COS(),
|
||||
shape, input, shape, golden,
|
||||
output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(Log) {
|
||||
@ -178,8 +142,9 @@ TF_LITE_MICRO_TEST(Log) {
|
||||
const float input[] = {1, 2.7182818, 0.5, 2};
|
||||
const float golden[] = {0, 1, -0.6931472, 0.6931472};
|
||||
float output_data[output_dims_count];
|
||||
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_LOG, shape,
|
||||
input, shape, golden, output_data);
|
||||
tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_LOG(),
|
||||
shape, input, shape, golden,
|
||||
output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(Sqrt) {
|
||||
@ -188,8 +153,9 @@ TF_LITE_MICRO_TEST(Sqrt) {
|
||||
const float input[] = {0, 1, 2, 4};
|
||||
const float golden[] = {0, 1, 1.41421, 2};
|
||||
float output_data[output_dims_count];
|
||||
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_SQRT, shape,
|
||||
input, shape, golden, output_data);
|
||||
tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_SQRT(),
|
||||
shape, input, shape, golden,
|
||||
output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(Rsqrt) {
|
||||
@ -198,8 +164,9 @@ TF_LITE_MICRO_TEST(Rsqrt) {
|
||||
const float input[] = {1, 2, 4, 9};
|
||||
const float golden[] = {1, 0.7071, 0.5, 0.33333};
|
||||
float output_data[output_dims_count];
|
||||
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_RSQRT, shape,
|
||||
input, shape, golden, output_data);
|
||||
tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_RSQRT(),
|
||||
shape, input, shape, golden,
|
||||
output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(Square) {
|
||||
@ -208,8 +175,9 @@ TF_LITE_MICRO_TEST(Square) {
|
||||
const float input[] = {1, 2, 0.5, -3.0};
|
||||
const float golden[] = {1, 4.0, 0.25, 9.0};
|
||||
float output_data[output_dims_count];
|
||||
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_SQUARE, shape,
|
||||
input, shape, golden, output_data);
|
||||
tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_SQUARE(),
|
||||
shape, input, shape, golden,
|
||||
output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(LogicalNot) {
|
||||
@ -218,9 +186,9 @@ TF_LITE_MICRO_TEST(LogicalNot) {
|
||||
const bool input[] = {true, false, false, true};
|
||||
const bool golden[] = {false, true, true, false};
|
||||
bool output_data[output_dims_count];
|
||||
tflite::testing::TestElementwiseBool(tflite::BuiltinOperator_LOGICAL_NOT,
|
||||
shape, input, shape, golden,
|
||||
output_data);
|
||||
tflite::testing::TestElementwiseBool(
|
||||
tflite::ops::micro::Register_LOGICAL_NOT(), shape, input, shape, golden,
|
||||
output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TESTS_END
|
||||
|
Loading…
Reference in New Issue
Block a user