Port the split, strided_slice and sub kernels to the new TfLiteEvalTensor API along with removing initializer lists from their tests.

PiperOrigin-RevId: 323453073
Change-Id: I3ecbaa4716870618c2f052826f4bde178fe38366
This commit is contained in:
Nat Jeffries 2020-07-27 15:24:16 -07:00 committed by TensorFlower Gardener
parent a9b6a48489
commit 20cd718248
7 changed files with 1187 additions and 1352 deletions

View File

@ -317,6 +317,7 @@ tflite_micro_cc_test(
"sub_test.cc", "sub_test.cc",
], ],
deps = [ deps = [
":kernel_runner",
"//tensorflow/lite/c:common", "//tensorflow/lite/c:common",
"//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro/testing:micro_test", "//tensorflow/lite/micro/testing:micro_test",
@ -380,6 +381,7 @@ tflite_micro_cc_test(
"strided_slice_test.cc", "strided_slice_test.cc",
], ],
deps = [ deps = [
":kernel_runner",
"//tensorflow/lite/c:common", "//tensorflow/lite/c:common",
"//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro/testing:micro_test", "//tensorflow/lite/micro/testing:micro_test",
@ -420,6 +422,7 @@ tflite_micro_cc_test(
"split_test.cc", "split_test.cc",
], ],
deps = [ deps = [
":kernel_runner",
"//tensorflow/lite/c:common", "//tensorflow/lite/c:common",
"//tensorflow/lite/micro:debug_log", "//tensorflow/lite/micro:debug_log",
"//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:op_resolvers",

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite { namespace tflite {
namespace ops { namespace ops {
@ -25,10 +26,11 @@ namespace split {
template <typename T> template <typename T>
TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node, TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input, int axis_value) { const TfLiteEvalTensor* input, int axis_value) {
const int output_count = NumOutputs(node); const int output_count = NumOutputs(node);
const TfLiteIntArray* input_dims = input->dims; const TfLiteIntArray* input_dims = input->dims;
const TfLiteTensor* output0 = GetOutput(context, node, 0); const TfLiteEvalTensor* output0 =
tflite::micro::GetEvalOutput(context, node, 0);
const TfLiteIntArray* output_dims = output0->dims; const TfLiteIntArray* output_dims = output0->dims;
const int split_dimensions = input_dims->size; const int split_dimensions = input_dims->size;
@ -50,11 +52,11 @@ TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node,
base_inner_size *= input_dims->data[i]; base_inner_size *= input_dims->data[i];
} }
const T* input_ptr = GetTensorData<T>(input); const T* input_ptr = tflite::micro::GetTensorData<T>(input);
for (int k = 0; k < outer_size; ++k) { for (int k = 0; k < outer_size; ++k) {
for (int i = 0; i < output_count; ++i) { for (int i = 0; i < output_count; ++i) {
TfLiteTensor* t = GetOutput(context, node, i); TfLiteEvalTensor* t = tflite::micro::GetEvalOutput(context, node, i);
T* output_data = GetTensorData<T>(t); T* output_data = tflite::micro::GetTensorData<T>(t);
const int copy_size = output_dims->data[axis] * base_inner_size; const int copy_size = output_dims->data[axis] * base_inner_size;
T* output_ptr = output_data + k * copy_size; T* output_ptr = output_data + k * copy_size;
for (int j = 0; j < copy_size; ++j) output_ptr[j] = input_ptr[j]; for (int j = 0; j < copy_size; ++j) output_ptr[j] = input_ptr[j];
@ -65,23 +67,28 @@ TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk; return kTfLiteOk;
} }
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* axis = GetInput(context, node, 0); const TfLiteTensor* axis = GetInput(context, node, 0);
const TfLiteTensor* input = GetInput(context, node, 1);
// Dynamic output tensors are needed if axis tensor is not constant. // Dynamic output tensors are needed if axis tensor is not constant.
// But Micro doesn't support dynamic memory allocation, so we only support // But Micro doesn't support dynamic memory allocation, so we only support
// constant axis tensor for now. // constant axis tensor for now.
TF_LITE_ENSURE_MSG(context, IsConstantTensor(axis), TF_LITE_ENSURE_MSG(context, IsConstantTensor(axis),
"Non constant axis tensor not supported"); "Non constant axis tensor not supported");
return kTfLiteOk;
}
int axis_value = GetTensorData<int32_t>(axis)[0]; TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 0);
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 1);
int axis_value = tflite::micro::GetTensorData<int32_t>(axis)[0];
if (axis_value < 0) { if (axis_value < 0) {
axis_value += NumDimensions(input); axis_value += input->dims->size;
} }
TF_LITE_ENSURE(context, axis_value >= 0); TF_LITE_ENSURE(context, axis_value >= 0);
TF_LITE_ENSURE(context, axis_value < NumDimensions(input)); TF_LITE_ENSURE(context, axis_value < input->dims->size);
switch (input->type) { switch (input->type) {
case kTfLiteFloat32: { case kTfLiteFloat32: {
@ -114,7 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteRegistration Register_SPLIT() { TfLiteRegistration Register_SPLIT() {
return {/*init=*/nullptr, return {/*init=*/nullptr,
/*free=*/nullptr, /*free=*/nullptr,
/*prepare=*/nullptr, /*prepare=*/split::Prepare,
/*invoke=*/split::Eval, /*invoke=*/split::Eval,
/*profiling_string=*/nullptr, /*profiling_string=*/nullptr,
/*builtin_code=*/0, /*builtin_code=*/0,

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/debug_log.h" #include "tensorflow/lite/micro/debug_log.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/testing/micro_test.h" #include "tensorflow/lite/micro/testing/micro_test.h"
#include "tensorflow/lite/micro/testing/test_utils.h" #include "tensorflow/lite/micro/testing/test_utils.h"
@ -24,19 +25,15 @@ namespace tflite {
namespace testing { namespace testing {
void TestSplitTwoOutputsFloat( void TestSplitTwoOutputsFloat(
std::initializer_list<int> input_dims_data, const int* input_dims_data, const float* input_data,
std::initializer_list<float> input_data, const int* axis_dims_data, const int32_t* axis_data,
std::initializer_list<int> axis_dims_data, const int* output1_dims_data, const float* expected_output1_data,
std::initializer_list<int32_t> axis_data, const int* output2_dims_data, const float* expected_output2_data,
std::initializer_list<int> output1_dims_data, float* output1_data, float* output2_data) {
std::initializer_list<float> expected_output1_data, TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
std::initializer_list<int> output2_dims_data, TfLiteIntArray* axis_dims = IntArrayFromInts(axis_dims_data);
std::initializer_list<float> expected_output2_data, float* output1_data, TfLiteIntArray* output1_dims = IntArrayFromInts(output1_dims_data);
float* output2_data) { TfLiteIntArray* output2_dims = IntArrayFromInts(output2_dims_data);
TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
TfLiteIntArray* axis_dims = IntArrayFromInitializer(axis_dims_data);
TfLiteIntArray* output1_dims = IntArrayFromInitializer(output1_dims_data);
TfLiteIntArray* output2_dims = IntArrayFromInitializer(output2_dims_data);
const int output1_dims_count = ElementCount(*output1_dims); const int output1_dims_count = ElementCount(*output1_dims);
const int output2_dims_count = ElementCount(*output2_dims); const int output2_dims_count = ElementCount(*output2_dims);
@ -61,76 +58,42 @@ void TestSplitTwoOutputsFloat(
output2_data[i] = 23; output2_data[i] = 23;
} }
TfLiteContext context;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
tflite::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_SPLIT);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteSplitParams builtin_data = {
.num_splits = 2,
};
void* user_data = nullptr;
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
int inputs_array_data[] = {2, 0, 1}; int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 2, 3}; int outputs_array_data[] = {2, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; const TfLiteRegistration registration = tflite::ops::micro::Register_SPLIT();
node.inputs = inputs_array; micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
node.outputs = outputs_array; outputs_array, nullptr, micro_test::reporter);
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
if (registration->prepare) { TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
if (registration->free) {
registration->free(&context, user_data);
}
for (int i = 0; i < output1_dims_count; ++i) { for (int i = 0; i < output1_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output1_data.begin()[i], output1_data[i], TF_LITE_MICRO_EXPECT_NEAR(expected_output1_data[i], output1_data[i], 1e-5f);
1e-5f);
} }
for (int i = 0; i < output2_dims_count; ++i) { for (int i = 0; i < output2_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output2_data.begin()[i], output2_data[i], TF_LITE_MICRO_EXPECT_NEAR(expected_output2_data[i], output2_data[i], 1e-5f);
1e-5f);
} }
} }
void TestSplitFourOutputsFloat( void TestSplitFourOutputsFloat(
std::initializer_list<int> input_dims_data, const int* input_dims_data, const float* input_data,
std::initializer_list<float> input_data, const int* axis_dims_data, const int32_t* axis_data,
std::initializer_list<int> axis_dims_data, const int* output1_dims_data, const float* expected_output1_data,
std::initializer_list<int32_t> axis_data, const int* output2_dims_data, const float* expected_output2_data,
std::initializer_list<int> output1_dims_data, const int* output3_dims_data, const float* expected_output3_data,
std::initializer_list<float> expected_output1_data, const int* output4_dims_data, const float* expected_output4_data,
std::initializer_list<int> output2_dims_data, float* output1_data, float* output2_data, float* output3_data,
std::initializer_list<float> expected_output2_data, float* output4_data) {
std::initializer_list<int> output3_dims_data, TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
std::initializer_list<float> expected_output3_data, TfLiteIntArray* axis_dims = IntArrayFromInts(axis_dims_data);
std::initializer_list<int> output4_dims_data, TfLiteIntArray* output1_dims = IntArrayFromInts(output1_dims_data);
std::initializer_list<float> expected_output4_data, float* output1_data, TfLiteIntArray* output2_dims = IntArrayFromInts(output2_dims_data);
float* output2_data, float* output3_data, float* output4_data) { TfLiteIntArray* output3_dims = IntArrayFromInts(output3_dims_data);
TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data); TfLiteIntArray* output4_dims = IntArrayFromInts(output4_dims_data);
TfLiteIntArray* axis_dims = IntArrayFromInitializer(axis_dims_data);
TfLiteIntArray* output1_dims = IntArrayFromInitializer(output1_dims_data);
TfLiteIntArray* output2_dims = IntArrayFromInitializer(output2_dims_data);
TfLiteIntArray* output3_dims = IntArrayFromInitializer(output3_dims_data);
TfLiteIntArray* output4_dims = IntArrayFromInitializer(output4_dims_data);
const int output1_dims_count = ElementCount(*output1_dims); const int output1_dims_count = ElementCount(*output1_dims);
const int output2_dims_count = ElementCount(*output2_dims); const int output2_dims_count = ElementCount(*output2_dims);
const int output3_dims_count = ElementCount(*output3_dims); const int output3_dims_count = ElementCount(*output3_dims);
@ -164,77 +127,42 @@ void TestSplitFourOutputsFloat(
output4_data[i] = 23; output4_data[i] = 23;
} }
TfLiteContext context;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
tflite::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_SPLIT);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteSplitParams builtin_data = {
.num_splits = 4,
};
void* user_data = nullptr;
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
int inputs_array_data[] = {2, 0, 1}; int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {4, 2, 3, 4, 5}; int outputs_array_data[] = {4, 2, 3, 4, 5};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; const TfLiteRegistration registration = tflite::ops::micro::Register_SPLIT();
node.inputs = inputs_array; micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
node.outputs = outputs_array; outputs_array, nullptr, micro_test::reporter);
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
if (registration->prepare) { TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
if (registration->free) {
registration->free(&context, user_data);
}
for (int i = 0; i < output1_dims_count; ++i) { for (int i = 0; i < output1_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output1_data.begin()[i], output1_data[i], TF_LITE_MICRO_EXPECT_NEAR(expected_output1_data[i], output1_data[i], 1e-5f);
1e-5f);
} }
for (int i = 0; i < output2_dims_count; ++i) { for (int i = 0; i < output2_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output2_data.begin()[i], output2_data[i], TF_LITE_MICRO_EXPECT_NEAR(expected_output2_data[i], output2_data[i], 1e-5f);
1e-5f);
} }
for (int i = 0; i < output3_dims_count; ++i) { for (int i = 0; i < output3_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output3_data.begin()[i], output3_data[i], TF_LITE_MICRO_EXPECT_NEAR(expected_output3_data[i], output3_data[i], 1e-5f);
1e-5f);
} }
for (int i = 0; i < output4_dims_count; ++i) { for (int i = 0; i < output4_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output4_data.begin()[i], output4_data[i], TF_LITE_MICRO_EXPECT_NEAR(expected_output4_data[i], output4_data[i], 1e-5f);
1e-5f);
} }
} }
void TestSplitTwoOutputsQuantized( void TestSplitTwoOutputsQuantized(
std::initializer_list<int> input_dims_data, const int* input_dims_data, const uint8_t* input_data,
std::initializer_list<uint8_t> input_data, const int* axis_dims_data, const int32_t* axis_data,
std::initializer_list<int> axis_dims_data, const int* output1_dims_data, const uint8_t* expected_output1_data,
std::initializer_list<int32_t> axis_data, const int* output2_dims_data, const uint8_t* expected_output2_data,
std::initializer_list<int> output1_dims_data, uint8_t* output1_data, uint8_t* output2_data) {
std::initializer_list<uint8_t> expected_output1_data, TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
std::initializer_list<int> output2_dims_data, TfLiteIntArray* axis_dims = IntArrayFromInts(axis_dims_data);
std::initializer_list<uint8_t> expected_output2_data, uint8_t* output1_data, TfLiteIntArray* output1_dims = IntArrayFromInts(output1_dims_data);
uint8_t* output2_data) { TfLiteIntArray* output2_dims = IntArrayFromInts(output2_dims_data);
TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
TfLiteIntArray* axis_dims = IntArrayFromInitializer(axis_dims_data);
TfLiteIntArray* output1_dims = IntArrayFromInitializer(output1_dims_data);
TfLiteIntArray* output2_dims = IntArrayFromInitializer(output2_dims_data);
const int output1_dims_count = ElementCount(*output1_dims); const int output1_dims_count = ElementCount(*output1_dims);
const int output2_dims_count = ElementCount(*output2_dims); const int output2_dims_count = ElementCount(*output2_dims);
@ -260,68 +188,37 @@ void TestSplitTwoOutputsQuantized(
output2_data[i] = 23; output2_data[i] = 23;
} }
TfLiteContext context;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
tflite::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_SPLIT);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteSplitParams builtin_data = {
.num_splits = 2,
};
void* user_data = nullptr;
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
int inputs_array_data[] = {2, 0, 1}; int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 2, 3}; int outputs_array_data[] = {2, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; const TfLiteRegistration registration = tflite::ops::micro::Register_SPLIT();
node.inputs = inputs_array; micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
node.outputs = outputs_array; outputs_array, nullptr, micro_test::reporter);
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
if (registration->prepare) { TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
if (registration->free) {
registration->free(&context, user_data);
}
for (int i = 0; i < output1_dims_count; ++i) { for (int i = 0; i < output1_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_EQ(expected_output1_data.begin()[i], output1_data[i]); TF_LITE_MICRO_EXPECT_EQ(expected_output1_data[i], output1_data[i]);
} }
for (int i = 0; i < output2_dims_count; ++i) { for (int i = 0; i < output2_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_EQ(expected_output2_data.begin()[i], output2_data[i]); TF_LITE_MICRO_EXPECT_EQ(expected_output2_data[i], output2_data[i]);
} }
} }
void TestSplitTwoOutputsQuantized32( void TestSplitTwoOutputsQuantized32(
std::initializer_list<int> input_dims_data, const int* input_dims_data, const int32_t* input_data,
std::initializer_list<int32_t> input_data, const int* axis_dims_data, const int32_t* axis_data,
std::initializer_list<int> axis_dims_data, const int* output1_dims_data, const int32_t* expected_output1_data,
std::initializer_list<int32_t> axis_data, const int* output2_dims_data, const int32_t* expected_output2_data,
std::initializer_list<int> output1_dims_data, int32_t* output1_data, int32_t* output2_data) {
std::initializer_list<int32_t> expected_output1_data, TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
std::initializer_list<int> output2_dims_data, TfLiteIntArray* axis_dims = IntArrayFromInts(axis_dims_data);
std::initializer_list<int32_t> expected_output2_data, int32_t* output1_data, TfLiteIntArray* output1_dims = IntArrayFromInts(output1_dims_data);
int32_t* output2_data) { TfLiteIntArray* output2_dims = IntArrayFromInts(output2_dims_data);
TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
TfLiteIntArray* axis_dims = IntArrayFromInitializer(axis_dims_data);
TfLiteIntArray* output1_dims = IntArrayFromInitializer(output1_dims_data);
TfLiteIntArray* output2_dims = IntArrayFromInitializer(output2_dims_data);
const int output1_dims_count = ElementCount(*output1_dims); const int output1_dims_count = ElementCount(*output1_dims);
const int output2_dims_count = ElementCount(*output2_dims); const int output2_dims_count = ElementCount(*output2_dims);
@ -347,51 +244,24 @@ void TestSplitTwoOutputsQuantized32(
output2_data[i] = 23; output2_data[i] = 23;
} }
TfLiteContext context;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
tflite::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_SPLIT);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteSplitParams builtin_data = {
.num_splits = 2,
};
void* user_data = nullptr;
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
int inputs_array_data[] = {2, 0, 1}; int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 2, 3}; int outputs_array_data[] = {2, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; const TfLiteRegistration registration = tflite::ops::micro::Register_SPLIT();
node.inputs = inputs_array; micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
node.outputs = outputs_array; outputs_array, nullptr, micro_test::reporter);
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
if (registration->prepare) { TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
if (registration->free) {
registration->free(&context, user_data);
}
for (int i = 0; i < output1_dims_count; ++i) { for (int i = 0; i < output1_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_EQ(expected_output1_data.begin()[i], output1_data[i]); TF_LITE_MICRO_EXPECT_EQ(expected_output1_data[i], output1_data[i]);
} }
for (int i = 0; i < output2_dims_count; ++i) { for (int i = 0; i < output2_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_EQ(expected_output2_data.begin()[i], output2_data[i]); TF_LITE_MICRO_EXPECT_EQ(expected_output2_data[i], output2_data[i]);
} }
} }
@ -401,91 +271,119 @@ void TestSplitTwoOutputsQuantized32(
TF_LITE_MICRO_TESTS_BEGIN TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisZero) { TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisZero) {
const int input_shape[] = {4, 2, 2, 2, 2};
const float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {0};
const int output1_shape[] = {4, 1, 2, 2, 2};
const float golden1[] = {1, 2, 3, 4, 5, 6, 7, 8};
const int output2_shape[] = {4, 1, 2, 2, 2};
const float golden2[] = {9, 10, 11, 12, 13, 14, 15, 16};
constexpr int output1_dims_count = 8; constexpr int output1_dims_count = 8;
constexpr int output2_dims_count = 8; constexpr int output2_dims_count = 8;
float output1_data[output1_dims_count]; float output1_data[output1_dims_count];
float output2_data[output2_dims_count]; float output2_data[output2_dims_count];
tflite::testing::TestSplitTwoOutputsFloat( tflite::testing::TestSplitTwoOutputsFloat(
{4, 2, 2, 2, 2}, // Input shape input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values output2_shape, golden2, output1_data, output2_data);
{1, 1}, // Axis shape
{0}, // Axis value
{4, 1, 2, 2, 2}, // Output1 shape
{1, 2, 3, 4, 5, 6, 7, 8}, // Output1 values
{4, 1, 2, 2, 2}, // Output2 shape
{9, 10, 11, 12, 13, 14, 15, 16}, // Output2 values
output1_data, output2_data);
} }
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisOne) { TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisOne) {
const int input_shape[] = {4, 2, 2, 2, 2};
const float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {1};
const int output1_shape[] = {4, 2, 1, 2, 2};
const float golden1[] = {1, 2, 3, 4, 9, 10, 11, 12};
const int output2_shape[] = {4, 2, 1, 2, 2};
const float golden2[] = {5, 6, 7, 8, 13, 14, 15, 16};
constexpr int output1_dims_count = 8; constexpr int output1_dims_count = 8;
constexpr int output2_dims_count = 8; constexpr int output2_dims_count = 8;
float output1_data[output1_dims_count]; float output1_data[output1_dims_count];
float output2_data[output2_dims_count]; float output2_data[output2_dims_count];
tflite::testing::TestSplitTwoOutputsFloat( tflite::testing::TestSplitTwoOutputsFloat(
{4, 2, 2, 2, 2}, // Input shape input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values output2_shape, golden2, output1_data, output2_data);
{1, 1}, // Axis shape
{1}, // Axis value
{4, 2, 1, 2, 2}, // Output1 shape
{1, 2, 3, 4, 9, 10, 11, 12}, // Output1 values
{4, 2, 1, 2, 2}, // Output2 shape
{5, 6, 7, 8, 13, 14, 15, 16}, // Output2 values
output1_data, output2_data);
} }
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisTwo) { TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisTwo) {
const int input_shape[] = {4, 2, 2, 2, 2};
const float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {2};
const int output1_shape[] = {4, 2, 2, 1, 2};
const float golden1[] = {1, 2, 5, 6, 9, 10, 13, 14};
const int output2_shape[] = {4, 2, 2, 1, 2};
const float golden2[] = {3, 4, 7, 8, 11, 12, 15, 16};
constexpr int output1_dims_count = 8; constexpr int output1_dims_count = 8;
constexpr int output2_dims_count = 8; constexpr int output2_dims_count = 8;
float output1_data[output1_dims_count]; float output1_data[output1_dims_count];
float output2_data[output2_dims_count]; float output2_data[output2_dims_count];
tflite::testing::TestSplitTwoOutputsFloat( tflite::testing::TestSplitTwoOutputsFloat(
{4, 2, 2, 2, 2}, // Input shape input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values output2_shape, golden2, output1_data, output2_data);
{1, 1}, // Axis shape
{2}, // Axis value
{4, 2, 2, 1, 2}, // Output1 shape
{1, 2, 5, 6, 9, 10, 13, 14}, // Output1 values
{4, 2, 2, 1, 2}, // Output2 shape
{3, 4, 7, 8, 11, 12, 15, 16}, // Output2 values
output1_data, output2_data);
} }
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisThree) { TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisThree) {
const int input_shape[] = {4, 2, 2, 2, 2};
const float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {3};
const int output1_shape[] = {4, 2, 2, 2, 1};
const float golden1[] = {1, 3, 5, 7, 9, 11, 13, 15};
const int output2_shape[] = {4, 2, 2, 2, 1};
const float golden2[] = {2, 4, 6, 8, 10, 12, 14, 16};
constexpr int output1_dims_count = 8; constexpr int output1_dims_count = 8;
constexpr int output2_dims_count = 8; constexpr int output2_dims_count = 8;
float output1_data[output1_dims_count]; float output1_data[output1_dims_count];
float output2_data[output2_dims_count]; float output2_data[output2_dims_count];
tflite::testing::TestSplitTwoOutputsFloat( tflite::testing::TestSplitTwoOutputsFloat(
{4, 2, 2, 2, 2}, // Input shape input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values output2_shape, golden2, output1_data, output2_data);
{1, 1}, // Axis shape
{3}, // Axis value
{4, 2, 2, 2, 1}, // Output1 shape
{1, 3, 5, 7, 9, 11, 13, 15}, // Output1 values
{4, 2, 2, 2, 1}, // Output2 shape
{2, 4, 6, 8, 10, 12, 14, 16}, // Output2 values
output1_data, output2_data);
} }
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalNegativeAxis) { TF_LITE_MICRO_TEST(TwoSplitFourDimensionalNegativeAxis) {
const int input_shape[] = {4, 2, 2, 2, 2};
const float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {-4};
const int output1_shape[] = {4, 1, 2, 2, 2};
const float golden1[] = {1, 2, 3, 4, 5, 6, 7, 8};
const int output2_shape[] = {4, 1, 2, 2, 2};
const float golden2[] = {9, 10, 11, 12, 13, 14, 15, 16};
constexpr int output1_dims_count = 8; constexpr int output1_dims_count = 8;
constexpr int output2_dims_count = 8; constexpr int output2_dims_count = 8;
float output1_data[output1_dims_count]; float output1_data[output1_dims_count];
float output2_data[output2_dims_count]; float output2_data[output2_dims_count];
tflite::testing::TestSplitTwoOutputsFloat( tflite::testing::TestSplitTwoOutputsFloat(
{4, 2, 2, 2, 2}, // Input shape input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values output2_shape, golden2, output1_data, output2_data);
{1, 1}, // Axis shape
{-4}, // Axis value
{4, 1, 2, 2, 2}, // Output1 shape
{1, 2, 3, 4, 5, 6, 7, 8}, // Output1 values
{4, 1, 2, 2, 2}, // Output2 shape
{9, 10, 11, 12, 13, 14, 15, 16}, // Output2 values
output1_data, output2_data);
} }
TF_LITE_MICRO_TEST(FourSplit) { TF_LITE_MICRO_TEST(FourSplit) {
const int input_shape[] = {1, 4};
const float input_data[] = {1, 2, 3, 4};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {0};
const int output1_shape[] = {1, 1};
const float golden1[] = {1};
const int output2_shape[] = {1, 1};
const float golden2[] = {2};
const int output3_shape[] = {1, 1};
const float golden3[] = {3};
const int output4_shape[] = {1, 1};
const float golden4[] = {4};
constexpr int output1_dims_count = 1; constexpr int output1_dims_count = 1;
constexpr int output2_dims_count = 1; constexpr int output2_dims_count = 1;
constexpr int output3_dims_count = 1; constexpr int output3_dims_count = 1;
@ -494,70 +392,69 @@ TF_LITE_MICRO_TEST(FourSplit) {
float output2_data[output2_dims_count]; float output2_data[output2_dims_count];
float output3_data[output3_dims_count]; float output3_data[output3_dims_count];
float output4_data[output4_dims_count]; float output4_data[output4_dims_count];
tflite::testing::TestSplitFourOutputsFloat({1, 4}, // Input shape tflite::testing::TestSplitFourOutputsFloat(
{1, 2, 3, 4}, // Input values input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
{1, 1}, // Axis shape output2_shape, golden2, output3_shape, golden3, output4_shape, golden4,
{0}, // Axis value output1_data, output2_data, output3_data, output4_data);
{1, 1}, // Output1 shape
{1}, // Output1 values
{1, 1}, // Output2 shape
{2}, // Output2 values
{1, 1}, // Output3 shape
{3}, // Output3 values
{1, 1}, // Output4 shape
{4}, // Output4 values
output1_data, output2_data,
output3_data, output4_data);
} }
TF_LITE_MICRO_TEST(TwoSplitOneDimensional) { TF_LITE_MICRO_TEST(TwoSplitOneDimensional) {
const int input_shape[] = {1, 2};
const float input_data[] = {1, 2};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {0};
const int output1_shape[] = {1, 1};
const float golden1[] = {1};
const int output2_shape[] = {1, 1};
const float golden2[] = {2};
constexpr int output1_dims_count = 8; constexpr int output1_dims_count = 8;
constexpr int output2_dims_count = 8; constexpr int output2_dims_count = 8;
float output1_data[output1_dims_count]; float output1_data[output1_dims_count];
float output2_data[output2_dims_count]; float output2_data[output2_dims_count];
tflite::testing::TestSplitTwoOutputsFloat({1, 2}, // Input shape tflite::testing::TestSplitTwoOutputsFloat(
{1, 2}, // Input values input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
{1, 1}, // Axis shape output2_shape, golden2, output1_data, output2_data);
{0}, // Axis value
{1, 1}, // Output1 shape
{1}, // Output1 values
{1, 1}, // Output2 shape
{2}, // Output2 values
output1_data, output2_data);
} }
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalQuantized) { TF_LITE_MICRO_TEST(TwoSplitFourDimensionalQuantized) {
const int input_shape[] = {4, 2, 2, 2, 2};
const uint8_t input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {1};
const int output1_shape[] = {4, 2, 1, 2, 2};
const uint8_t golden1[] = {1, 2, 3, 4, 9, 10, 11, 12};
const int output2_shape[] = {4, 2, 1, 2, 2};
const uint8_t golden2[] = {5, 6, 7, 8, 13, 14, 15, 16};
constexpr int output1_dims_count = 8; constexpr int output1_dims_count = 8;
constexpr int output2_dims_count = 8; constexpr int output2_dims_count = 8;
uint8_t output1_data[output1_dims_count]; uint8_t output1_data[output1_dims_count];
uint8_t output2_data[output2_dims_count]; uint8_t output2_data[output2_dims_count];
tflite::testing::TestSplitTwoOutputsQuantized( tflite::testing::TestSplitTwoOutputsQuantized(
{4, 2, 2, 2, 2}, // Input shape input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values output2_shape, golden2, output1_data, output2_data);
{1, 1}, // Axis shape
{0}, // Axis value
{4, 1, 2, 2, 2}, // Output1 shape
{1, 2, 3, 4, 5, 6, 7, 8}, // Output1 values
{4, 1, 2, 2, 2}, // Output2 shape
{9, 10, 11, 12, 13, 14, 15, 16}, // Output2 values
output1_data, output2_data);
} }
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalQuantized32) { TF_LITE_MICRO_TEST(TwoSplitFourDimensionalQuantized32) {
const int input_shape[] = {4, 2, 2, 2, 2};
const int32_t input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
const int axis_shape[] = {1, 1};
const int32_t axis_data[] = {1};
const int output1_shape[] = {4, 2, 1, 2, 2};
const int32_t golden1[] = {1, 2, 3, 4, 9, 10, 11, 12};
const int output2_shape[] = {4, 2, 1, 2, 2};
const int32_t golden2[] = {5, 6, 7, 8, 13, 14, 15, 16};
constexpr int output1_dims_count = 8; constexpr int output1_dims_count = 8;
constexpr int output2_dims_count = 8; constexpr int output2_dims_count = 8;
int32_t output1_data[output1_dims_count]; int32_t output1_data[output1_dims_count];
int32_t output2_data[output2_dims_count]; int32_t output2_data[output2_dims_count];
tflite::testing::TestSplitTwoOutputsQuantized32( tflite::testing::TestSplitTwoOutputsQuantized32(
{4, 2, 2, 2, 2}, // Input shape input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values output2_shape, golden2, output1_data, output2_data);
{1, 1}, // Axis shape
{0}, // Axis value
{4, 1, 2, 2, 2}, // Output1 shape
{1, 2, 3, 4, 5, 6, 7, 8}, // Output1 values
{4, 1, 2, 2, 2}, // Output2 shape
{9, 10, 11, 12, 13, 14, 15, 16}, // Output2 values
output1_data, output2_data);
} }
TF_LITE_MICRO_TESTS_END TF_LITE_MICRO_TESTS_END

View File

@ -15,23 +15,20 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/strided_slice.h" #include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
#include <cmath> #include <cmath>
#include <cstring>
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h" #include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite { namespace tflite {
namespace ops { namespace ops {
namespace micro { namespace micro {
namespace strided_slice { namespace strided_slice {
enum KernelType {
kReference,
// TODO(soroosh): add kGenericOptimized
};
constexpr int kInputTensor = 0; constexpr int kInputTensor = 0;
constexpr int kBeginTensor = 1; constexpr int kBeginTensor = 1;
constexpr int kEndTensor = 2; constexpr int kEndTensor = 2;
@ -120,58 +117,70 @@ TfLiteStatus CheckOutputSize(TfLiteContext* context,
return kTfLiteOk; return kTfLiteOk;
} }
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(StridedSliceParams));
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
StridedSliceParams* op_params =
static_cast<StridedSliceParams*>(node->user_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
StridedSliceContext op_context(context, node); StridedSliceContext op_context(context, node);
TF_LITE_ENSURE_MSG(context, op_context.dims <= kMaxDim, TF_LITE_ENSURE_MSG(context, op_context.dims <= kMaxDim,
"input dim should not exceed 4"); "input dim should not exceed 4");
auto params = BuildStridedSliceParams(&op_context);
memcpy(op_params, &params, sizeof(StridedSliceParams));
return CheckOutputSize(context, &op_context); return CheckOutputSize(context, &op_context);
} }
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
StridedSliceContext op_context(context, node); TFLITE_DCHECK(node->user_data != nullptr);
auto op_params = BuildStridedSliceParams(&op_context); const StridedSliceParams& op_params =
*(static_cast<const StridedSliceParams*>(node->user_data));
#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ const TfLiteEvalTensor* input =
kernel_type::StridedSlice(op_params, GetTensorShape(op_context.input), \ tflite::micro::GetEvalInput(context, node, kInputTensor);
GetTensorData<data_type>(op_context.input), \ TfLiteEvalTensor* output =
GetTensorShape(op_context.output), \ tflite::micro::GetEvalOutput(context, node, kOutputTensor);
GetTensorData<data_type>(op_context.output)) switch (output->type) {
switch (op_context.input->type) {
case kTfLiteFloat32: case kTfLiteFloat32:
if (kernel_type == kReference) { reference_ops::StridedSlice(op_params,
TF_LITE_STRIDED_SLICE(reference_ops, float); tflite::micro::GetTensorShape(input),
} tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break; break;
case kTfLiteUInt8: case kTfLiteUInt8:
if (kernel_type == kReference) { reference_ops::StridedSlice(
TF_LITE_STRIDED_SLICE(reference_ops, uint8_t); op_params, tflite::micro::GetTensorShape(input),
} tflite::micro::GetTensorData<uint8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<uint8_t>(output));
break; break;
case kTfLiteInt8: case kTfLiteInt8:
if (kernel_type == kReference) { reference_ops::StridedSlice(op_params,
TF_LITE_STRIDED_SLICE(reference_ops, int8_t); tflite::micro::GetTensorShape(input),
} tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
break; break;
default: default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(op_context.input->type), TfLiteTypeGetName(input->type), input->type);
op_context.input->type);
return kTfLiteError; return kTfLiteError;
} }
#undef TF_LITE_STRIDED_SLICE
return kTfLiteOk; return kTfLiteOk;
} }
} // namespace strided_slice } // namespace strided_slice
TfLiteRegistration Register_STRIDED_SLICE() { TfLiteRegistration Register_STRIDED_SLICE() {
return {/*init=*/nullptr, return {/*init=*/strided_slice::Init,
/*free=*/nullptr, /*free=*/nullptr,
/*prepare=*/strided_slice::Prepare, /*prepare=*/strided_slice::Prepare,
/*invoke=*/strided_slice::Eval<strided_slice::kReference>, /*invoke=*/strided_slice::Eval,
/*profiling_string=*/nullptr, /*profiling_string=*/nullptr,
/*builtin_code=*/0, /*builtin_code=*/0,
/*custom_name=*/nullptr, /*custom_name=*/nullptr,

File diff suppressed because it is too large Load Diff

View File

@ -21,8 +21,10 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h" #include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h" #include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite { namespace tflite {
namespace ops { namespace ops {
@ -93,31 +95,59 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteSubParams* params,
return kTfLiteOk; return kTfLiteOk;
} }
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);
auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_STATUS(
CalculateOpData(context, params, input1, input2, output, data));
return kTfLiteOk;
}
void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params, void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
const OpData* data, const TfLiteTensor* input1, const OpData* data, const TfLiteEvalTensor* input1,
const TfLiteTensor* input2, TfLiteTensor* output) { const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
float output_activation_min, output_activation_max; float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min, CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max); &output_activation_max);
tflite::ArithmeticParams op_params; tflite::ArithmeticParams op_params;
SetActivationParams(output_activation_min, output_activation_max, &op_params); SetActivationParams(output_activation_min, output_activation_max, &op_params);
#define TF_LITE_SUB(opname) \
opname(op_params, GetTensorShape(input1), GetTensorData<float>(input1), \
GetTensorShape(input2), GetTensorData<float>(input2), \
GetTensorShape(output), GetTensorData<float>(output))
if (data->requires_broadcast) { if (data->requires_broadcast) {
TF_LITE_SUB(tflite::reference_ops::BroadcastSubSlow); tflite::reference_ops::BroadcastSubSlow(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<float>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<float>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
} else { } else {
TF_LITE_SUB(tflite::reference_ops::SubWithActivation); tflite::reference_ops::SubWithActivation(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<float>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<float>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
} }
#undef TF_LITE_SUB
} }
TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node, TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteSubParams* params, const OpData* data, TfLiteSubParams* params, const OpData* data,
const TfLiteTensor* input1, const TfLiteEvalTensor* input1,
const TfLiteTensor* input2, const TfLiteEvalTensor* input2,
TfLiteTensor* output) { TfLiteEvalTensor* output) {
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) { if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
tflite::ArithmeticParams op_params; tflite::ArithmeticParams op_params;
op_params.left_shift = data->left_shift; op_params.left_shift = data->left_shift;
@ -133,25 +163,46 @@ TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
SetActivationParams(data->output_activation_min, SetActivationParams(data->output_activation_min,
data->output_activation_max, &op_params); data->output_activation_max, &op_params);
bool need_broadcast = reference_ops::ProcessBroadcastShapes( bool need_broadcast = reference_ops::ProcessBroadcastShapes(
GetTensorShape(input1), GetTensorShape(input2), &op_params); tflite::micro::GetTensorShape(input1),
#define TF_LITE_SUB(opname, dtype) \ tflite::micro::GetTensorShape(input2), &op_params);
opname(op_params, GetTensorShape(input1), GetTensorData<dtype>(input1), \
GetTensorShape(input2), GetTensorData<dtype>(input2), \
GetTensorShape(output), GetTensorData<dtype>(output));
if (output->type == kTfLiteInt8) { if (output->type == kTfLiteInt8) {
if (need_broadcast) { if (need_broadcast) {
TF_LITE_SUB(tflite::reference_ops::BroadcastSubSlow, int8_t); tflite::reference_ops::BroadcastSubSlow(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<int8_t>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<int8_t>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
} else { } else {
TF_LITE_SUB(tflite::reference_ops::Sub, int8_t); tflite::reference_ops::Sub(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<int8_t>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<int8_t>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
} }
} else { } else {
if (need_broadcast) { if (need_broadcast) {
TF_LITE_SUB(tflite::reference_ops::BroadcastSubSlow, uint8_t); tflite::reference_ops::BroadcastSubSlow(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<uint8_t>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<uint8_t>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<uint8_t>(output));
} else { } else {
TF_LITE_SUB(tflite::reference_ops::Sub, uint8_t); tflite::reference_ops::Sub(
op_params, tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<uint8_t>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<uint8_t>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<uint8_t>(output));
} }
} }
#undef TF_LITE_SUB
} }
return kTfLiteOk; return kTfLiteOk;
@ -160,13 +211,15 @@ TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data); auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); const TfLiteEvalTensor* input1 =
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); tflite::micro::GetEvalInput(context, node, kInputTensor1);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor); const TfLiteEvalTensor* input2 =
tflite::micro::GetEvalInput(context, node, kInputTensor2);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
OpData data; TFLITE_DCHECK(node->user_data != nullptr);
TF_LITE_ENSURE_STATUS( const OpData& data = *(static_cast<const OpData*>(node->user_data));
CalculateOpData(context, params, input1, input2, output, &data));
if (output->type == kTfLiteFloat32) { if (output->type == kTfLiteFloat32) {
EvalSub(context, node, params, &data, input1, input2, output); EvalSub(context, node, params, &data, input1, input2, output);
@ -185,9 +238,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace sub } // namespace sub
TfLiteRegistration Register_SUB() { TfLiteRegistration Register_SUB() {
return {/*init=*/nullptr, return {/*init=*/sub::Init,
/*free=*/nullptr, /*free=*/nullptr,
/*prepare=*/nullptr, /*prepare=*/sub::Prepare,
/*invoke=*/sub::Eval, /*invoke=*/sub::Eval,
/*profiling_string=*/nullptr, /*profiling_string=*/nullptr,
/*builtin_code=*/0, /*builtin_code=*/0,

View File

@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/testing/micro_test.h" #include "tensorflow/lite/micro/testing/micro_test.h"
#include "tensorflow/lite/micro/testing/test_utils.h" #include "tensorflow/lite/micro/testing/test_utils.h"
@ -66,47 +66,21 @@ void ValidateSubGoldens(TfLiteTensor* tensors, int tensors_size,
const T* golden, T* output, int output_size, const T* golden, T* output, int output_size,
TfLiteFusedActivation activation, TfLiteFusedActivation activation,
float tolerance = 1e-5) { float tolerance = 1e-5) {
TfLiteContext context;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
::tflite::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(::tflite::BuiltinOperator_SUB);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteSubParams builtin_data; TfLiteSubParams builtin_data;
builtin_data.activation = activation; builtin_data.activation = activation;
const char* init_data = reinterpret_cast<const char*>(&builtin_data);
const size_t init_data_size = 0;
void* user_data = nullptr;
if (registration->init) {
user_data = registration->init(&context, init_data, init_data_size);
}
int inputs_array_data[] = {2, 0, 1}; int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2}; int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; const TfLiteRegistration registration = tflite::ops::micro::Register_SUB();
node.inputs = inputs_array; micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
node.outputs = outputs_array; outputs_array, &builtin_data,
node.user_data = user_data; micro_test::reporter);
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
if (registration->prepare) { TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
if (registration->free) {
registration->free(&context, user_data);
}
for (int i = 0; i < output_size; ++i) { for (int i = 0; i < output_size; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(golden[i], output[i], tolerance); TF_LITE_MICRO_EXPECT_NEAR(golden[i], output[i], tolerance);