Port the split, strided_slice and sub kernels to the new TfLiteEvalTensor API, and remove initializer lists from their tests.

PiperOrigin-RevId: 323453073
Change-Id: I3ecbaa4716870618c2f052826f4bde178fe38366
Authored by Nat Jeffries on 2020-07-27 15:24:16 -07:00; committed by TensorFlower Gardener
parent a9b6a48489
commit 20cd718248
7 changed files with 1187 additions and 1352 deletions
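
All three kernels follow the same migration: eval-time tensor access moves from the full TfLiteTensor API (GetInput, GetOutput, GetTensorData from kernels/kernel_util.h) to the lighter TfLiteEvalTensor API in micro/kernels/kernel_util.h. A minimal before/after sketch of the pattern, condensed rather than copied from this commit:

// Before: Eval pulls full TfLiteTensors, which also carry quantization
// parameters, allocation metadata, and other fields Eval never touches.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, 0);
  TfLiteTensor* output = GetOutput(context, node, 0);
  const float* in = GetTensorData<float>(input);
  float* out = GetTensorData<float>(output);
  // ... per-element work ...
  return kTfLiteOk;
}

// After: TfLiteEvalTensor carries only type, dims, and data. Checks that
// need the full tensor (e.g. IsConstantTensor) move into Prepare.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
  TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
  const float* in = tflite::micro::GetTensorData<float>(input);
  float* out = tflite::micro::GetTensorData<float>(output);
  // ... per-element work ...
  return kTfLiteOk;
}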

tensorflow/lite/micro/kernels/BUILD

@@ -317,6 +317,7 @@ tflite_micro_cc_test(
"sub_test.cc",
],
deps = [
":kernel_runner",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro/testing:micro_test",
@@ -380,6 +381,7 @@ tflite_micro_cc_test(
"strided_slice_test.cc",
],
deps = [
":kernel_runner",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro/testing:micro_test",
@@ -420,6 +422,7 @@ tflite_micro_cc_test(
"split_test.cc",
],
deps = [
":kernel_runner",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:debug_log",
"//tensorflow/lite/micro:op_resolvers",

tensorflow/lite/micro/kernels/split.cc

@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace ops {
@@ -25,10 +26,11 @@ namespace split {
template <typename T>
TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input, int axis_value) {
const TfLiteEvalTensor* input, int axis_value) {
const int output_count = NumOutputs(node);
const TfLiteIntArray* input_dims = input->dims;
const TfLiteTensor* output0 = GetOutput(context, node, 0);
const TfLiteEvalTensor* output0 =
tflite::micro::GetEvalOutput(context, node, 0);
const TfLiteIntArray* output_dims = output0->dims;
const int split_dimensions = input_dims->size;
@@ -50,11 +52,11 @@ TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node,
base_inner_size *= input_dims->data[i];
}
const T* input_ptr = GetTensorData<T>(input);
const T* input_ptr = tflite::micro::GetTensorData<T>(input);
for (int k = 0; k < outer_size; ++k) {
for (int i = 0; i < output_count; ++i) {
TfLiteTensor* t = GetOutput(context, node, i);
T* output_data = GetTensorData<T>(t);
TfLiteEvalTensor* t = tflite::micro::GetEvalOutput(context, node, i);
T* output_data = tflite::micro::GetTensorData<T>(t);
const int copy_size = output_dims->data[axis] * base_inner_size;
T* output_ptr = output_data + k * copy_size;
for (int j = 0; j < copy_size; ++j) output_ptr[j] = input_ptr[j];
@@ -65,23 +67,28 @@ TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* axis = GetInput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 1);
// Dynamic output tensors are needed if axis tensor is not constant.
// But Micro doesn't support dynamic memory allocation, so we only support
// constant axis tensor for now.
TF_LITE_ENSURE_MSG(context, IsConstantTensor(axis),
"Non constant axis tensor not supported");
return kTfLiteOk;
}
int axis_value = GetTensorData<int32_t>(axis)[0];
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 0);
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 1);
int axis_value = tflite::micro::GetTensorData<int32_t>(axis)[0];
if (axis_value < 0) {
axis_value += NumDimensions(input);
axis_value += input->dims->size;
}
TF_LITE_ENSURE(context, axis_value >= 0);
TF_LITE_ENSURE(context, axis_value < NumDimensions(input));
TF_LITE_ENSURE(context, axis_value < input->dims->size);
switch (input->type) {
case kTfLiteFloat32: {
@@ -114,7 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteRegistration Register_SPLIT() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*prepare=*/split::Prepare,
/*invoke=*/split::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
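
SplitImpl above treats the input as outer_size contiguous blocks and, within each block, hands every output copy_size = output_dims->data[axis] * base_inner_size elements. A standalone sketch of that arithmetic for the axis-one case exercised by the tests below; the values are illustrative, not kernel code:

#include <cstdio>

int main() {
  // Input shape [2, 2, 2, 2] split in two along axis 1, matching the
  // TwoSplitFourDimensionalAxisOne test.
  const int input_dims[] = {2, 2, 2, 2};
  const int axis = 1;
  const int output_count = 2;

  int outer_size = 1;  // product of dims before the axis: 2
  for (int i = 0; i < axis; ++i) outer_size *= input_dims[i];
  int base_inner_size = 1;  // product of dims after the axis: 4
  for (int i = axis + 1; i < 4; ++i) base_inner_size *= input_dims[i];
  // Each output keeps input_dims[axis] / output_count slices of the axis.
  const int copy_size = (input_dims[axis] / output_count) * base_inner_size;

  float input[16];
  for (int i = 0; i < 16; ++i) input[i] = static_cast<float>(i + 1);
  float out[2][8];

  const float* input_ptr = input;
  for (int k = 0; k < outer_size; ++k) {
    for (int i = 0; i < output_count; ++i) {
      for (int j = 0; j < copy_size; ++j) {
        out[i][k * copy_size + j] = input_ptr[j];
      }
      input_ptr += copy_size;
    }
  }
  // out[0] == {1,2,3,4,9,10,11,12} and out[1] == {5,6,7,8,13,14,15,16},
  // i.e. golden1/golden2 of TwoSplitFourDimensionalAxisOne.
  printf("%g %g\n", out[0][4], out[1][4]);  // prints "9 13"
  return 0;
}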

tensorflow/lite/micro/kernels/split_test.cc

@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/debug_log.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
#include "tensorflow/lite/micro/testing/test_utils.h"
@@ -24,19 +25,15 @@ namespace tflite {
namespace testing {
void TestSplitTwoOutputsFloat(
std::initializer_list<int> input_dims_data,
std::initializer_list<float> input_data,
std::initializer_list<int> axis_dims_data,
std::initializer_list<int32_t> axis_data,
std::initializer_list<int> output1_dims_data,
std::initializer_list<float> expected_output1_data,
std::initializer_list<int> output2_dims_data,
std::initializer_list<float> expected_output2_data, float* output1_data,
float* output2_data) {
TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
TfLiteIntArray* axis_dims = IntArrayFromInitializer(axis_dims_data);
TfLiteIntArray* output1_dims = IntArrayFromInitializer(output1_dims_data);
TfLiteIntArray* output2_dims = IntArrayFromInitializer(output2_dims_data);
const int* input_dims_data, const float* input_data,
const int* axis_dims_data, const int32_t* axis_data,
const int* output1_dims_data, const float* expected_output1_data,
const int* output2_dims_data, const float* expected_output2_data,
float* output1_data, float* output2_data) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* axis_dims = IntArrayFromInts(axis_dims_data);
TfLiteIntArray* output1_dims = IntArrayFromInts(output1_dims_data);
TfLiteIntArray* output2_dims = IntArrayFromInts(output2_dims_data);
const int output1_dims_count = ElementCount(*output1_dims);
const int output2_dims_count = ElementCount(*output2_dims);
@@ -61,76 +58,42 @@ void TestSplitTwoOutputsFloat(
output2_data[i] = 23;
}
TfLiteContext context;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
tflite::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_SPLIT);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteSplitParams builtin_data = {
.num_splits = 2,
};
void* user_data = nullptr;
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
const TfLiteRegistration registration = tflite::ops::micro::Register_SPLIT();
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
outputs_array, nullptr, micro_test::reporter);
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
if (registration->free) {
registration->free(&context, user_data);
}
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
for (int i = 0; i < output1_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output1_data.begin()[i], output1_data[i],
1e-5f);
TF_LITE_MICRO_EXPECT_NEAR(expected_output1_data[i], output1_data[i], 1e-5f);
}
for (int i = 0; i < output2_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output2_data.begin()[i], output2_data[i],
1e-5f);
TF_LITE_MICRO_EXPECT_NEAR(expected_output2_data[i], output2_data[i], 1e-5f);
}
}
void TestSplitFourOutputsFloat(
std::initializer_list<int> input_dims_data,
std::initializer_list<float> input_data,
std::initializer_list<int> axis_dims_data,
std::initializer_list<int32_t> axis_data,
std::initializer_list<int> output1_dims_data,
std::initializer_list<float> expected_output1_data,
std::initializer_list<int> output2_dims_data,
std::initializer_list<float> expected_output2_data,
std::initializer_list<int> output3_dims_data,
std::initializer_list<float> expected_output3_data,
std::initializer_list<int> output4_dims_data,
std::initializer_list<float> expected_output4_data, float* output1_data,
float* output2_data, float* output3_data, float* output4_data) {
TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
TfLiteIntArray* axis_dims = IntArrayFromInitializer(axis_dims_data);
TfLiteIntArray* output1_dims = IntArrayFromInitializer(output1_dims_data);
TfLiteIntArray* output2_dims = IntArrayFromInitializer(output2_dims_data);
TfLiteIntArray* output3_dims = IntArrayFromInitializer(output3_dims_data);
TfLiteIntArray* output4_dims = IntArrayFromInitializer(output4_dims_data);
const int* input_dims_data, const float* input_data,
const int* axis_dims_data, const int32_t* axis_data,
const int* output1_dims_data, const float* expected_output1_data,
const int* output2_dims_data, const float* expected_output2_data,
const int* output3_dims_data, const float* expected_output3_data,
const int* output4_dims_data, const float* expected_output4_data,
float* output1_data, float* output2_data, float* output3_data,
float* output4_data) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* axis_dims = IntArrayFromInts(axis_dims_data);
TfLiteIntArray* output1_dims = IntArrayFromInts(output1_dims_data);
TfLiteIntArray* output2_dims = IntArrayFromInts(output2_dims_data);
TfLiteIntArray* output3_dims = IntArrayFromInts(output3_dims_data);
TfLiteIntArray* output4_dims = IntArrayFromInts(output4_dims_data);
const int output1_dims_count = ElementCount(*output1_dims);
const int output2_dims_count = ElementCount(*output2_dims);
const int output3_dims_count = ElementCount(*output3_dims);
@@ -164,77 +127,42 @@ void TestSplitFourOutputsFloat(
output4_data[i] = 23;
}
TfLiteContext context;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
tflite::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_SPLIT);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteSplitParams builtin_data = {
.num_splits = 4,
};
void* user_data = nullptr;
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {4, 2, 3, 4, 5};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
const TfLiteRegistration registration = tflite::ops::micro::Register_SPLIT();
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
outputs_array, nullptr, micro_test::reporter);
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
if (registration->free) {
registration->free(&context, user_data);
}
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
for (int i = 0; i < output1_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output1_data.begin()[i], output1_data[i],
1e-5f);
TF_LITE_MICRO_EXPECT_NEAR(expected_output1_data[i], output1_data[i], 1e-5f);
}
for (int i = 0; i < output2_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output2_data.begin()[i], output2_data[i],
1e-5f);
TF_LITE_MICRO_EXPECT_NEAR(expected_output2_data[i], output2_data[i], 1e-5f);
}
for (int i = 0; i < output3_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output3_data.begin()[i], output3_data[i],
1e-5f);
TF_LITE_MICRO_EXPECT_NEAR(expected_output3_data[i], output3_data[i], 1e-5f);
}
for (int i = 0; i < output4_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output4_data.begin()[i], output4_data[i],
1e-5f);
TF_LITE_MICRO_EXPECT_NEAR(expected_output4_data[i], output4_data[i], 1e-5f);
}
}
void TestSplitTwoOutputsQuantized(
std::initializer_list<int> input_dims_data,
std::initializer_list<uint8_t> input_data,
std::initializer_list<int> axis_dims_data,
std::initializer_list<int32_t> axis_data,
std::initializer_list<int> output1_dims_data,
std::initializer_list<uint8_t> expected_output1_data,
std::initializer_list<int> output2_dims_data,
std::initializer_list<uint8_t> expected_output2_data, uint8_t* output1_data,
uint8_t* output2_data) {
TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
TfLiteIntArray* axis_dims = IntArrayFromInitializer(axis_dims_data);
TfLiteIntArray* output1_dims = IntArrayFromInitializer(output1_dims_data);
TfLiteIntArray* output2_dims = IntArrayFromInitializer(output2_dims_data);
const int* input_dims_data, const uint8_t* input_data,
const int* axis_dims_data, const int32_t* axis_data,
const int* output1_dims_data, const uint8_t* expected_output1_data,
const int* output2_dims_data, const uint8_t* expected_output2_data,
uint8_t* output1_data, uint8_t* output2_data) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* axis_dims = IntArrayFromInts(axis_dims_data);
TfLiteIntArray* output1_dims = IntArrayFromInts(output1_dims_data);
TfLiteIntArray* output2_dims = IntArrayFromInts(output2_dims_data);
const int output1_dims_count = ElementCount(*output1_dims);
const int output2_dims_count = ElementCount(*output2_dims);
@@ -260,68 +188,37 @@ void TestSplitTwoOutputsQuantized(
output2_data[i] = 23;
}
TfLiteContext context;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
tflite::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_SPLIT);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteSplitParams builtin_data = {
.num_splits = 2,
};
void* user_data = nullptr;
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
const TfLiteRegistration registration = tflite::ops::micro::Register_SPLIT();
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
outputs_array, nullptr, micro_test::reporter);
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
if (registration->free) {
registration->free(&context, user_data);
}
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
for (int i = 0; i < output1_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_EQ(expected_output1_data.begin()[i], output1_data[i]);
TF_LITE_MICRO_EXPECT_EQ(expected_output1_data[i], output1_data[i]);
}
for (int i = 0; i < output2_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_EQ(expected_output2_data.begin()[i], output2_data[i]);
TF_LITE_MICRO_EXPECT_EQ(expected_output2_data[i], output2_data[i]);
}
}
void TestSplitTwoOutputsQuantized32(
std::initializer_list<int> input_dims_data,
std::initializer_list<int32_t> input_data,
std::initializer_list<int> axis_dims_data,
std::initializer_list<int32_t> axis_data,
std::initializer_list<int> output1_dims_data,
std::initializer_list<int32_t> expected_output1_data,
std::initializer_list<int> output2_dims_data,
std::initializer_list<int32_t> expected_output2_data, int32_t* output1_data,
int32_t* output2_data) {
TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
TfLiteIntArray* axis_dims = IntArrayFromInitializer(axis_dims_data);
TfLiteIntArray* output1_dims = IntArrayFromInitializer(output1_dims_data);
TfLiteIntArray* output2_dims = IntArrayFromInitializer(output2_dims_data);
const int* input_dims_data, const int32_t* input_data,
const int* axis_dims_data, const int32_t* axis_data,
const int* output1_dims_data, const int32_t* expected_output1_data,
const int* output2_dims_data, const int32_t* expected_output2_data,
int32_t* output1_data, int32_t* output2_data) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* axis_dims = IntArrayFromInts(axis_dims_data);
TfLiteIntArray* output1_dims = IntArrayFromInts(output1_dims_data);
TfLiteIntArray* output2_dims = IntArrayFromInts(output2_dims_data);
const int output1_dims_count = ElementCount(*output1_dims);
const int output2_dims_count = ElementCount(*output2_dims);
@@ -347,51 +244,24 @@ void TestSplitTwoOutputsQuantized32(
output2_data[i] = 23;
}
TfLiteContext context;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
tflite::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_SPLIT);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteSplitParams builtin_data = {
.num_splits = 2,
};
void* user_data = nullptr;
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
const TfLiteRegistration registration = tflite::ops::micro::Register_SPLIT();
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
outputs_array, nullptr, micro_test::reporter);
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
if (registration->free) {
registration->free(&context, user_data);
}
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
for (int i = 0; i < output1_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_EQ(expected_output1_data.begin()[i], output1_data[i]);
TF_LITE_MICRO_EXPECT_EQ(expected_output1_data[i], output1_data[i]);
}
for (int i = 0; i < output2_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_EQ(expected_output2_data.begin()[i], output2_data[i]);
TF_LITE_MICRO_EXPECT_EQ(expected_output2_data[i], output2_data[i]);
}
}
@@ -401,91 +271,119 @@ void TestSplitTwoOutputsQuantized32(
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisZero) {
const int input_shape[] = {4, 2, 2, 2, 2};
const float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {0};
const int output1_shape[] = {4, 1, 2, 2, 2};
const float golden1[] = {1, 2, 3, 4, 5, 6, 7, 8};
const int output2_shape[] = {4, 1, 2, 2, 2};
const float golden2[] = {9, 10, 11, 12, 13, 14, 15, 16};
constexpr int output1_dims_count = 8;
constexpr int output2_dims_count = 8;
float output1_data[output1_dims_count];
float output2_data[output2_dims_count];
tflite::testing::TestSplitTwoOutputsFloat(
{4, 2, 2, 2, 2}, // Input shape
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values
{1, 1}, // Axis shape
{0}, // Axis value
{4, 1, 2, 2, 2}, // Output1 shape
{1, 2, 3, 4, 5, 6, 7, 8}, // Output1 values
{4, 1, 2, 2, 2}, // Output2 shape
{9, 10, 11, 12, 13, 14, 15, 16}, // Output2 values
output1_data, output2_data);
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
output2_shape, golden2, output1_data, output2_data);
}
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisOne) {
const int input_shape[] = {4, 2, 2, 2, 2};
const float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {1};
const int output1_shape[] = {4, 2, 1, 2, 2};
const float golden1[] = {1, 2, 3, 4, 9, 10, 11, 12};
const int output2_shape[] = {4, 2, 1, 2, 2};
const float golden2[] = {5, 6, 7, 8, 13, 14, 15, 16};
constexpr int output1_dims_count = 8;
constexpr int output2_dims_count = 8;
float output1_data[output1_dims_count];
float output2_data[output2_dims_count];
tflite::testing::TestSplitTwoOutputsFloat(
{4, 2, 2, 2, 2}, // Input shape
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values
{1, 1}, // Axis shape
{1}, // Axis value
{4, 2, 1, 2, 2}, // Output1 shape
{1, 2, 3, 4, 9, 10, 11, 12}, // Output1 values
{4, 2, 1, 2, 2}, // Output2 shape
{5, 6, 7, 8, 13, 14, 15, 16}, // Output2 values
output1_data, output2_data);
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
output2_shape, golden2, output1_data, output2_data);
}
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisTwo) {
const int input_shape[] = {4, 2, 2, 2, 2};
const float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {2};
const int output1_shape[] = {4, 2, 2, 1, 2};
const float golden1[] = {1, 2, 5, 6, 9, 10, 13, 14};
const int output2_shape[] = {4, 2, 2, 1, 2};
const float golden2[] = {3, 4, 7, 8, 11, 12, 15, 16};
constexpr int output1_dims_count = 8;
constexpr int output2_dims_count = 8;
float output1_data[output1_dims_count];
float output2_data[output2_dims_count];
tflite::testing::TestSplitTwoOutputsFloat(
{4, 2, 2, 2, 2}, // Input shape
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values
{1, 1}, // Axis shape
{2}, // Axis value
{4, 2, 2, 1, 2}, // Output1 shape
{1, 2, 5, 6, 9, 10, 13, 14}, // Output1 values
{4, 2, 2, 1, 2}, // Output2 shape
{3, 4, 7, 8, 11, 12, 15, 16}, // Output2 values
output1_data, output2_data);
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
output2_shape, golden2, output1_data, output2_data);
}
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisThree) {
const int input_shape[] = {4, 2, 2, 2, 2};
const float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {3};
const int output1_shape[] = {4, 2, 2, 2, 1};
const float golden1[] = {1, 3, 5, 7, 9, 11, 13, 15};
const int output2_shape[] = {4, 2, 2, 2, 1};
const float golden2[] = {2, 4, 6, 8, 10, 12, 14, 16};
constexpr int output1_dims_count = 8;
constexpr int output2_dims_count = 8;
float output1_data[output1_dims_count];
float output2_data[output2_dims_count];
tflite::testing::TestSplitTwoOutputsFloat(
{4, 2, 2, 2, 2}, // Input shape
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values
{1, 1}, // Axis shape
{3}, // Axis value
{4, 2, 2, 2, 1}, // Output1 shape
{1, 3, 5, 7, 9, 11, 13, 15}, // Output1 values
{4, 2, 2, 2, 1}, // Output2 shape
{2, 4, 6, 8, 10, 12, 14, 16}, // Output2 values
output1_data, output2_data);
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
output2_shape, golden2, output1_data, output2_data);
}
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalNegativeAxis) {
const int input_shape[] = {4, 2, 2, 2, 2};
const float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {-4};
const int output1_shape[] = {4, 1, 2, 2, 2};
const float golden1[] = {1, 2, 3, 4, 5, 6, 7, 8};
const int output2_shape[] = {4, 1, 2, 2, 2};
const float golden2[] = {9, 10, 11, 12, 13, 14, 15, 16};
constexpr int output1_dims_count = 8;
constexpr int output2_dims_count = 8;
float output1_data[output1_dims_count];
float output2_data[output2_dims_count];
tflite::testing::TestSplitTwoOutputsFloat(
{4, 2, 2, 2, 2}, // Input shape
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values
{1, 1}, // Axis shape
{-4}, // Axis value
{4, 1, 2, 2, 2}, // Output1 shape
{1, 2, 3, 4, 5, 6, 7, 8}, // Output1 values
{4, 1, 2, 2, 2}, // Output2 shape
{9, 10, 11, 12, 13, 14, 15, 16}, // Output2 values
output1_data, output2_data);
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
output2_shape, golden2, output1_data, output2_data);
}
TF_LITE_MICRO_TEST(FourSplit) {
const int input_shape[] = {1, 4};
const float input_data[] = {1, 2, 3, 4};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {0};
const int output1_shape[] = {1, 1};
const float golden1[] = {1};
const int output2_shape[] = {1, 1};
const float golden2[] = {2};
const int output3_shape[] = {1, 1};
const float golden3[] = {3};
const int output4_shape[] = {1, 1};
const float golden4[] = {4};
constexpr int output1_dims_count = 1;
constexpr int output2_dims_count = 1;
constexpr int output3_dims_count = 1;
@@ -494,70 +392,69 @@ TF_LITE_MICRO_TEST(FourSplit) {
float output2_data[output2_dims_count];
float output3_data[output3_dims_count];
float output4_data[output4_dims_count];
tflite::testing::TestSplitFourOutputsFloat({1, 4}, // Input shape
{1, 2, 3, 4}, // Input values
{1, 1}, // Axis shape
{0}, // Axis value
{1, 1}, // Output1 shape
{1}, // Output1 values
{1, 1}, // Output2 shape
{2}, // Output2 values
{1, 1}, // Output3 shape
{3}, // Output3 values
{1, 1}, // Output4 shape
{4}, // Output4 values
output1_data, output2_data,
output3_data, output4_data);
tflite::testing::TestSplitFourOutputsFloat(
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
output2_shape, golden2, output3_shape, golden3, output4_shape, golden4,
output1_data, output2_data, output3_data, output4_data);
}
TF_LITE_MICRO_TEST(TwoSplitOneDimensional) {
const int input_shape[] = {1, 2};
const float input_data[] = {1, 2};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {0};
const int output1_shape[] = {1, 1};
const float golden1[] = {1};
const int output2_shape[] = {1, 1};
const float golden2[] = {2};
constexpr int output1_dims_count = 8;
constexpr int output2_dims_count = 8;
float output1_data[output1_dims_count];
float output2_data[output2_dims_count];
tflite::testing::TestSplitTwoOutputsFloat({1, 2}, // Input shape
{1, 2}, // Input values
{1, 1}, // Axis shape
{0}, // Axis value
{1, 1}, // Output1 shape
{1}, // Output1 values
{1, 1}, // Output2 shape
{2}, // Output2 values
output1_data, output2_data);
tflite::testing::TestSplitTwoOutputsFloat(
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
output2_shape, golden2, output1_data, output2_data);
}
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalQuantized) {
const int input_shape[] = {4, 2, 2, 2, 2};
const uint8_t input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {1};
const int output1_shape[] = {4, 2, 1, 2, 2};
const uint8_t golden1[] = {1, 2, 3, 4, 9, 10, 11, 12};
const int output2_shape[] = {4, 2, 1, 2, 2};
const uint8_t golden2[] = {5, 6, 7, 8, 13, 14, 15, 16};
constexpr int output1_dims_count = 8;
constexpr int output2_dims_count = 8;
uint8_t output1_data[output1_dims_count];
uint8_t output2_data[output2_dims_count];
tflite::testing::TestSplitTwoOutputsQuantized(
{4, 2, 2, 2, 2}, // Input shape
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values
{1, 1}, // Axis shape
{0}, // Axis value
{4, 1, 2, 2, 2}, // Output1 shape
{1, 2, 3, 4, 5, 6, 7, 8}, // Output1 values
{4, 1, 2, 2, 2}, // Output2 shape
{9, 10, 11, 12, 13, 14, 15, 16}, // Output2 values
output1_data, output2_data);
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
output2_shape, golden2, output1_data, output2_data);
}
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalQuantized32) {
const int input_shape[] = {4, 2, 2, 2, 2};
const int32_t input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {1};
const int output1_shape[] = {4, 2, 1, 2, 2};
const int32_t golden1[] = {1, 2, 3, 4, 9, 10, 11, 12};
const int output2_shape[] = {4, 2, 1, 2, 2};
const int32_t golden2[] = {5, 6, 7, 8, 13, 14, 15, 16};
constexpr int output1_dims_count = 8;
constexpr int output2_dims_count = 8;
int32_t output1_data[output1_dims_count];
int32_t output2_data[output2_dims_count];
tflite::testing::TestSplitTwoOutputsQuantized32(
{4, 2, 2, 2, 2}, // Input shape
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values
{1, 1}, // Axis shape
{0}, // Axis value
{4, 1, 2, 2, 2}, // Output1 shape
{1, 2, 3, 4, 5, 6, 7, 8}, // Output1 values
{4, 1, 2, 2, 2}, // Output2 shape
{9, 10, 11, 12, 13, 14, 15, 16}, // Output2 values
output1_data, output2_data);
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
output2_shape, golden2, output1_data, output2_data);
}
TF_LITE_MICRO_TESTS_END
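
Two conventions make the rewritten tests above work. First, every array handed to IntArrayFromInts is length-prefixed: element 0 is the entry count, so {4, 2, 2, 2, 2} is a rank-4 shape [2, 2, 2, 2] and {2, 0, 1} names two node inputs, tensors 0 and 1. Second, the hand-built TfLiteContext/TfLiteNode/registration dance collapses into KernelRunner, which is why the BUILD targets gain the :kernel_runner dep. A condensed sketch of the harness, with the tensors array setup omitted:

int inputs_array_data[] = {2, 0, 1};   // 2 inputs: axis tensor 0, data tensor 1
int outputs_array_data[] = {2, 2, 3};  // 2 outputs: tensors 2 and 3
TfLiteIntArray* inputs_array = tflite::testing::IntArrayFromInts(inputs_array_data);
TfLiteIntArray* outputs_array = tflite::testing::IntArrayFromInts(outputs_array_data);

// KernelRunner owns the context and node plumbing the old tests built by
// hand; InitAndPrepare drives registration init/prepare, Invoke drives invoke.
const TfLiteRegistration registration = tflite::ops::micro::Register_SPLIT();
tflite::micro::KernelRunner runner(registration, tensors, tensors_size,
                                   inputs_array, outputs_array,
                                   /*builtin_data=*/nullptr,
                                   micro_test::reporter);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());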

tensorflow/lite/micro/kernels/strided_slice.cc

@@ -15,23 +15,20 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
#include <cmath>
#include <cstring>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace micro {
namespace strided_slice {
enum KernelType {
kReference,
// TODO(soroosh): add kGenericOptimized
};
constexpr int kInputTensor = 0;
constexpr int kBeginTensor = 1;
constexpr int kEndTensor = 2;
@@ -120,58 +117,70 @@ TfLiteStatus CheckOutputSize(TfLiteContext* context,
return kTfLiteOk;
}
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(StridedSliceParams));
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
StridedSliceParams* op_params =
static_cast<StridedSliceParams*>(node->user_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
StridedSliceContext op_context(context, node);
TF_LITE_ENSURE_MSG(context, op_context.dims <= kMaxDim,
"input dim should not exceed 4");
auto params = BuildStridedSliceParams(&op_context);
memcpy(op_params, &params, sizeof(StridedSliceParams));
return CheckOutputSize(context, &op_context);
}
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
StridedSliceContext op_context(context, node);
auto op_params = BuildStridedSliceParams(&op_context);
TFLITE_DCHECK(node->user_data != nullptr);
const StridedSliceParams& op_params =
*(static_cast<const StridedSliceParams*>(node->user_data));
#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
kernel_type::StridedSlice(op_params, GetTensorShape(op_context.input), \
GetTensorData<data_type>(op_context.input), \
GetTensorShape(op_context.output), \
GetTensorData<data_type>(op_context.output))
switch (op_context.input->type) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
switch (output->type) {
case kTfLiteFloat32:
if (kernel_type == kReference) {
TF_LITE_STRIDED_SLICE(reference_ops, float);
}
reference_ops::StridedSlice(op_params,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
case kTfLiteUInt8:
if (kernel_type == kReference) {
TF_LITE_STRIDED_SLICE(reference_ops, uint8_t);
}
reference_ops::StridedSlice(
op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<uint8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<uint8_t>(output));
break;
case kTfLiteInt8:
if (kernel_type == kReference) {
TF_LITE_STRIDED_SLICE(reference_ops, int8_t);
}
reference_ops::StridedSlice(op_params,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(op_context.input->type),
op_context.input->type);
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
#undef TF_LITE_STRIDED_SLICE
return kTfLiteOk;
}
} // namespace strided_slice
TfLiteRegistration Register_STRIDED_SLICE() {
return {/*init=*/nullptr,
return {/*init=*/strided_slice::Init,
/*free=*/nullptr,
/*prepare=*/strided_slice::Prepare,
/*invoke=*/strided_slice::Eval<strided_slice::kReference>,
/*invoke=*/strided_slice::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
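
The registration above now wires in strided_slice::Init, which is the recurring state-caching lifecycle of this port: Init reserves per-node storage from the persistent arena, Prepare fills it once while full TfLiteTensors are still available, and Eval only reads it. A sketch with a hypothetical OpState struct standing in for StridedSliceParams or sub's OpData:

struct OpState {
  // Whatever the kernel needs at eval time, e.g. a StridedSliceParams copy.
  int cached_value;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  // Arena-allocated; lives as long as the interpreter and is never freed.
  return context->AllocatePersistentBuffer(context, sizeof(OpState));
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  OpState* state = static_cast<OpState*>(node->user_data);
  state->cached_value = 0;  // One-time work with full TfLiteTensors goes here.
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpState& state = *static_cast<const OpState*>(node->user_data);
  // Per-inference work uses TfLiteEvalTensors plus the cached state; nothing
  // is recomputed here.
  (void)state;
  return kTfLiteOk;
}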

tensorflow/lite/micro/kernels/strided_slice_test.cc (diff suppressed because it is too large)

tensorflow/lite/micro/kernels/sub.cc

@@ -21,8 +21,10 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace ops {
@@ -93,31 +95,59 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteSubParams* params,
return kTfLiteOk;
}
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);
auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_STATUS(
CalculateOpData(context, params, input1, input2, output, data));
return kTfLiteOk;
}
void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
const OpData* data, const TfLiteTensor* input1,
const TfLiteTensor* input2, TfLiteTensor* output) {
const OpData* data, const TfLiteEvalTensor* input1,
const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
tflite::ArithmeticParams op_params;
SetActivationParams(output_activation_min, output_activation_max, &op_params);
#define TF_LITE_SUB(opname) \
opname(op_params, GetTensorShape(input1), GetTensorData<float>(input1), \
GetTensorShape(input2), GetTensorData<float>(input2), \
GetTensorShape(output), GetTensorData<float>(output))
if (data->requires_broadcast) {
TF_LITE_SUB(tflite::reference_ops::BroadcastSubSlow);
tflite::reference_ops::BroadcastSubSlow(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<float>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<float>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
} else {
TF_LITE_SUB(tflite::reference_ops::SubWithActivation);
tflite::reference_ops::SubWithActivation(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<float>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<float>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
}
#undef TF_LITE_SUB
}
TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteSubParams* params, const OpData* data,
const TfLiteTensor* input1,
const TfLiteTensor* input2,
TfLiteTensor* output) {
const TfLiteEvalTensor* input1,
const TfLiteEvalTensor* input2,
TfLiteEvalTensor* output) {
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
tflite::ArithmeticParams op_params;
op_params.left_shift = data->left_shift;
@@ -133,25 +163,46 @@ TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
SetActivationParams(data->output_activation_min,
data->output_activation_max, &op_params);
bool need_broadcast = reference_ops::ProcessBroadcastShapes(
GetTensorShape(input1), GetTensorShape(input2), &op_params);
#define TF_LITE_SUB(opname, dtype) \
opname(op_params, GetTensorShape(input1), GetTensorData<dtype>(input1), \
GetTensorShape(input2), GetTensorData<dtype>(input2), \
GetTensorShape(output), GetTensorData<dtype>(output));
tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorShape(input2), &op_params);
if (output->type == kTfLiteInt8) {
if (need_broadcast) {
TF_LITE_SUB(tflite::reference_ops::BroadcastSubSlow, int8_t);
tflite::reference_ops::BroadcastSubSlow(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<int8_t>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<int8_t>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
} else {
TF_LITE_SUB(tflite::reference_ops::Sub, int8_t);
tflite::reference_ops::Sub(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<int8_t>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<int8_t>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
}
} else {
if (need_broadcast) {
TF_LITE_SUB(tflite::reference_ops::BroadcastSubSlow, uint8_t);
tflite::reference_ops::BroadcastSubSlow(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<uint8_t>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<uint8_t>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<uint8_t>(output));
} else {
TF_LITE_SUB(tflite::reference_ops::Sub, uint8_t);
tflite::reference_ops::Sub(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<uint8_t>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<uint8_t>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<uint8_t>(output));
}
}
#undef TF_LITE_SUB
}
return kTfLiteOk;
@@ -160,13 +211,15 @@ TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteEvalTensor* input1 =
tflite::micro::GetEvalInput(context, node, kInputTensor1);
const TfLiteEvalTensor* input2 =
tflite::micro::GetEvalInput(context, node, kInputTensor2);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
OpData data;
TF_LITE_ENSURE_STATUS(
CalculateOpData(context, params, input1, input2, output, &data));
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));
if (output->type == kTfLiteFloat32) {
EvalSub(context, node, params, &data, input1, input2, output);
@@ -185,9 +238,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace sub
TfLiteRegistration Register_SUB() {
return {/*init=*/nullptr,
return {/*init=*/sub::Init,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*prepare=*/sub::Prepare,
/*invoke=*/sub::Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
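
For orientation, the float path above bottoms out in reference_ops::SubWithActivation, whose per-element behavior is a difference clamped to the fused-activation range that Prepare derives via CalculateActivationRange. A standalone restatement of that semantic, not the reference_ops implementation itself:

#include <algorithm>
#include <cstdio>
#include <limits>

// out[i] = clamp(in1[i] - in2[i], act_min, act_max); the bounds come from
// the fused activation (e.g. act_min == 0 for kTfLiteActRelu).
float SubWithActivationScalar(float a, float b, float act_min, float act_max) {
  return std::min(std::max(a - b, act_min), act_max);
}

int main() {
  const float act_min = 0.0f;  // ReLU lower bound
  const float act_max = std::numeric_limits<float>::max();
  printf("%g\n", SubWithActivationScalar(1.0f, 3.0f, act_min, act_max));  // 0
  return 0;
}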

tensorflow/lite/micro/kernels/sub_test.cc

@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
#include "tensorflow/lite/micro/testing/test_utils.h"
@@ -66,47 +66,21 @@ void ValidateSubGoldens(TfLiteTensor* tensors, int tensors_size,
const T* golden, T* output, int output_size,
TfLiteFusedActivation activation,
float tolerance = 1e-5) {
TfLiteContext context;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
::tflite::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(::tflite::BuiltinOperator_SUB);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteSubParams builtin_data;
builtin_data.activation = activation;
const char* init_data = reinterpret_cast<const char*>(&builtin_data);
const size_t init_data_size = 0;
void* user_data = nullptr;
if (registration->init) {
user_data = registration->init(&context, init_data, init_data_size);
}
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
const TfLiteRegistration registration = tflite::ops::micro::Register_SUB();
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
outputs_array, &builtin_data,
micro_test::reporter);
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
if (registration->free) {
registration->free(&context, user_data);
}
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
for (int i = 0; i < output_size; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(golden[i], output[i], tolerance);