Port the elementwise kernel to the new TfLiteEvalTensor API.

PiperOrigin-RevId: 323017688
Change-Id: I4362d0e3d70dd4449a9e45f0bf8289b0f8824235
This commit is contained in:
Nick Kreeger 2020-07-24 10:07:48 -07:00 committed by TensorFlower Gardener
parent 6f57dbfffa
commit 06eb028030
3 changed files with 48 additions and 78 deletions

View File

@ -133,6 +133,7 @@ tflite_micro_cc_test(
name = "elementwise_test",
srcs = ["elementwise_test.cc"],
deps = [
":kernel_runner",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:debug_log",
"//tensorflow/lite/micro:op_resolvers",

View File

@ -18,6 +18,8 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace ops {
@ -52,13 +54,13 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
T func(T), TfLiteType expected_type) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
const int64_t num_elements = NumElements(input);
const T* in_data = GetTensorData<T>(input);
T* out_data = GetTensorData<T>(output);
for (int64_t i = 0; i < num_elements; ++i) {
const size_t num_elements = ElementCount(*input->dims);
const T* in_data = tflite::micro::GetTensorData<T>(input);
T* out_data = tflite::micro::GetTensorData<T>(output);
for (size_t i = 0; i < num_elements; ++i) {
out_data[i] = func(in_data[i]);
}
return kTfLiteOk;
@ -106,7 +108,6 @@ TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
return EvalLogical(context, node, [](bool v) { return !v; });
}
} // namespace
} // namespace elementwise

View File

@ -16,13 +16,14 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/debug_log.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
#include "tensorflow/lite/micro/testing/test_utils.h"
namespace tflite {
namespace testing {
void TestElementwiseFloat(tflite::BuiltinOperator op,
void TestElementwiseFloat(const TfLiteRegistration& registration,
const int* input_dims_data, const float* input_data,
const int* output_dims_data,
const float* expected_output_data,
@ -43,45 +44,26 @@ void TestElementwiseFloat(tflite::BuiltinOperator op,
output_data[i] = 23;
}
TfLiteContext context;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
tflite::AllOpsResolver resolver;
const TfLiteRegistration* registration = resolver.FindOp(op);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
void* user_data = nullptr;
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
static int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
static int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
outputs_array,
/*builtin_data=*/nullptr, micro_test::reporter);
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
if (registration->free) {
registration->free(&context, user_data);
}
for (int i = 0; i < output_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i], 1e-5f);
}
}
void TestElementwiseBool(tflite::BuiltinOperator op, const int* input_dims_data,
const bool* input_data, const int* output_dims_data,
void TestElementwiseBool(const TfLiteRegistration& registration,
const int* input_dims_data, const bool* input_data,
const int* output_dims_data,
const bool* expected_output_data, bool* output_data) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
@ -99,39 +81,18 @@ void TestElementwiseBool(tflite::BuiltinOperator op, const int* input_dims_data,
output_data[i] = false;
}
TfLiteContext context;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
tflite::AllOpsResolver resolver;
const TfLiteRegistration* registration = resolver.FindOp(op);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
void* user_data = nullptr;
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
const int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
const int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
outputs_array,
/*builtin_data=*/nullptr, micro_test::reporter);
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
if (registration->free) {
registration->free(&context, user_data);
}
for (int i = 0; i < output_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
}
@ -148,8 +109,9 @@ TF_LITE_MICRO_TEST(Abs) {
const float input[] = {0.01, -0.01, 10, -10};
const float golden[] = {0.01, 0.01, 10, 10};
float output_data[output_dims_count];
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_ABS, shape,
input, shape, golden, output_data);
tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_ABS(),
shape, input, shape, golden,
output_data);
}
TF_LITE_MICRO_TEST(Sin) {
@ -158,8 +120,9 @@ TF_LITE_MICRO_TEST(Sin) {
const float input[] = {0, 3.1415926, -3.1415926, 1};
const float golden[] = {0, 0, 0, 0.84147};
float output_data[output_dims_count];
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_SIN, shape,
input, shape, golden, output_data);
tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_SIN(),
shape, input, shape, golden,
output_data);
}
TF_LITE_MICRO_TEST(Cos) {
@ -168,8 +131,9 @@ TF_LITE_MICRO_TEST(Cos) {
const float input[] = {0, 3.1415926, -3.1415926, 1};
const float golden[] = {1, -1, -1, 0.54030};
float output_data[output_dims_count];
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_COS, shape,
input, shape, golden, output_data);
tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_COS(),
shape, input, shape, golden,
output_data);
}
TF_LITE_MICRO_TEST(Log) {
@ -178,8 +142,9 @@ TF_LITE_MICRO_TEST(Log) {
const float input[] = {1, 2.7182818, 0.5, 2};
const float golden[] = {0, 1, -0.6931472, 0.6931472};
float output_data[output_dims_count];
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_LOG, shape,
input, shape, golden, output_data);
tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_LOG(),
shape, input, shape, golden,
output_data);
}
TF_LITE_MICRO_TEST(Sqrt) {
@ -188,8 +153,9 @@ TF_LITE_MICRO_TEST(Sqrt) {
const float input[] = {0, 1, 2, 4};
const float golden[] = {0, 1, 1.41421, 2};
float output_data[output_dims_count];
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_SQRT, shape,
input, shape, golden, output_data);
tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_SQRT(),
shape, input, shape, golden,
output_data);
}
TF_LITE_MICRO_TEST(Rsqrt) {
@ -198,8 +164,9 @@ TF_LITE_MICRO_TEST(Rsqrt) {
const float input[] = {1, 2, 4, 9};
const float golden[] = {1, 0.7071, 0.5, 0.33333};
float output_data[output_dims_count];
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_RSQRT, shape,
input, shape, golden, output_data);
tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_RSQRT(),
shape, input, shape, golden,
output_data);
}
TF_LITE_MICRO_TEST(Square) {
@ -208,8 +175,9 @@ TF_LITE_MICRO_TEST(Square) {
const float input[] = {1, 2, 0.5, -3.0};
const float golden[] = {1, 4.0, 0.25, 9.0};
float output_data[output_dims_count];
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_SQUARE, shape,
input, shape, golden, output_data);
tflite::testing::TestElementwiseFloat(tflite::ops::micro::Register_SQUARE(),
shape, input, shape, golden,
output_data);
}
TF_LITE_MICRO_TEST(LogicalNot) {
@ -218,9 +186,9 @@ TF_LITE_MICRO_TEST(LogicalNot) {
const bool input[] = {true, false, false, true};
const bool golden[] = {false, true, true, false};
bool output_data[output_dims_count];
tflite::testing::TestElementwiseBool(tflite::BuiltinOperator_LOGICAL_NOT,
shape, input, shape, golden,
output_data);
tflite::testing::TestElementwiseBool(
tflite::ops::micro::Register_LOGICAL_NOT(), shape, input, shape, golden,
output_data);
}
TF_LITE_MICRO_TESTS_END