Port the split, strided_slice and sub kernels to the new TfLiteEvalTensor API along with removing initializer lists from their tests.
PiperOrigin-RevId: 323453073 Change-Id: I3ecbaa4716870618c2f052826f4bde178fe38366
This commit is contained in:
parent
a9b6a48489
commit
20cd718248
tensorflow/lite/micro/kernels
@ -317,6 +317,7 @@ tflite_micro_cc_test(
|
|||||||
"sub_test.cc",
|
"sub_test.cc",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":kernel_runner",
|
||||||
"//tensorflow/lite/c:common",
|
"//tensorflow/lite/c:common",
|
||||||
"//tensorflow/lite/micro:op_resolvers",
|
"//tensorflow/lite/micro:op_resolvers",
|
||||||
"//tensorflow/lite/micro/testing:micro_test",
|
"//tensorflow/lite/micro/testing:micro_test",
|
||||||
@ -380,6 +381,7 @@ tflite_micro_cc_test(
|
|||||||
"strided_slice_test.cc",
|
"strided_slice_test.cc",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":kernel_runner",
|
||||||
"//tensorflow/lite/c:common",
|
"//tensorflow/lite/c:common",
|
||||||
"//tensorflow/lite/micro:op_resolvers",
|
"//tensorflow/lite/micro:op_resolvers",
|
||||||
"//tensorflow/lite/micro/testing:micro_test",
|
"//tensorflow/lite/micro/testing:micro_test",
|
||||||
@ -420,6 +422,7 @@ tflite_micro_cc_test(
|
|||||||
"split_test.cc",
|
"split_test.cc",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":kernel_runner",
|
||||||
"//tensorflow/lite/c:common",
|
"//tensorflow/lite/c:common",
|
||||||
"//tensorflow/lite/micro:debug_log",
|
"//tensorflow/lite/micro:debug_log",
|
||||||
"//tensorflow/lite/micro:op_resolvers",
|
"//tensorflow/lite/micro:op_resolvers",
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
|
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
@ -25,10 +26,11 @@ namespace split {
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node,
|
TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node,
|
||||||
const TfLiteTensor* input, int axis_value) {
|
const TfLiteEvalTensor* input, int axis_value) {
|
||||||
const int output_count = NumOutputs(node);
|
const int output_count = NumOutputs(node);
|
||||||
const TfLiteIntArray* input_dims = input->dims;
|
const TfLiteIntArray* input_dims = input->dims;
|
||||||
const TfLiteTensor* output0 = GetOutput(context, node, 0);
|
const TfLiteEvalTensor* output0 =
|
||||||
|
tflite::micro::GetEvalOutput(context, node, 0);
|
||||||
const TfLiteIntArray* output_dims = output0->dims;
|
const TfLiteIntArray* output_dims = output0->dims;
|
||||||
|
|
||||||
const int split_dimensions = input_dims->size;
|
const int split_dimensions = input_dims->size;
|
||||||
@ -50,11 +52,11 @@ TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node,
|
|||||||
base_inner_size *= input_dims->data[i];
|
base_inner_size *= input_dims->data[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
const T* input_ptr = GetTensorData<T>(input);
|
const T* input_ptr = tflite::micro::GetTensorData<T>(input);
|
||||||
for (int k = 0; k < outer_size; ++k) {
|
for (int k = 0; k < outer_size; ++k) {
|
||||||
for (int i = 0; i < output_count; ++i) {
|
for (int i = 0; i < output_count; ++i) {
|
||||||
TfLiteTensor* t = GetOutput(context, node, i);
|
TfLiteEvalTensor* t = tflite::micro::GetEvalOutput(context, node, i);
|
||||||
T* output_data = GetTensorData<T>(t);
|
T* output_data = tflite::micro::GetTensorData<T>(t);
|
||||||
const int copy_size = output_dims->data[axis] * base_inner_size;
|
const int copy_size = output_dims->data[axis] * base_inner_size;
|
||||||
T* output_ptr = output_data + k * copy_size;
|
T* output_ptr = output_data + k * copy_size;
|
||||||
for (int j = 0; j < copy_size; ++j) output_ptr[j] = input_ptr[j];
|
for (int j = 0; j < copy_size; ++j) output_ptr[j] = input_ptr[j];
|
||||||
@ -65,23 +67,28 @@ TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node,
|
|||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
const TfLiteTensor* axis = GetInput(context, node, 0);
|
const TfLiteTensor* axis = GetInput(context, node, 0);
|
||||||
const TfLiteTensor* input = GetInput(context, node, 1);
|
|
||||||
|
|
||||||
// Dynamic output tensors are needed if axis tensor is not constant.
|
// Dynamic output tensors are needed if axis tensor is not constant.
|
||||||
// But Micro doesn't support dynamic memory allocation, so we only support
|
// But Micro doesn't support dynamic memory allocation, so we only support
|
||||||
// constant axis tensor for now.
|
// constant axis tensor for now.
|
||||||
TF_LITE_ENSURE_MSG(context, IsConstantTensor(axis),
|
TF_LITE_ENSURE_MSG(context, IsConstantTensor(axis),
|
||||||
"Non constant axis tensor not supported");
|
"Non constant axis tensor not supported");
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
int axis_value = GetTensorData<int32_t>(axis)[0];
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
const TfLiteEvalTensor* axis = tflite::micro::GetEvalInput(context, node, 0);
|
||||||
|
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 1);
|
||||||
|
|
||||||
|
int axis_value = tflite::micro::GetTensorData<int32_t>(axis)[0];
|
||||||
if (axis_value < 0) {
|
if (axis_value < 0) {
|
||||||
axis_value += NumDimensions(input);
|
axis_value += input->dims->size;
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_LITE_ENSURE(context, axis_value >= 0);
|
TF_LITE_ENSURE(context, axis_value >= 0);
|
||||||
TF_LITE_ENSURE(context, axis_value < NumDimensions(input));
|
TF_LITE_ENSURE(context, axis_value < input->dims->size);
|
||||||
|
|
||||||
switch (input->type) {
|
switch (input->type) {
|
||||||
case kTfLiteFloat32: {
|
case kTfLiteFloat32: {
|
||||||
@ -114,7 +121,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TfLiteRegistration Register_SPLIT() {
|
TfLiteRegistration Register_SPLIT() {
|
||||||
return {/*init=*/nullptr,
|
return {/*init=*/nullptr,
|
||||||
/*free=*/nullptr,
|
/*free=*/nullptr,
|
||||||
/*prepare=*/nullptr,
|
/*prepare=*/split::Prepare,
|
||||||
/*invoke=*/split::Eval,
|
/*invoke=*/split::Eval,
|
||||||
/*profiling_string=*/nullptr,
|
/*profiling_string=*/nullptr,
|
||||||
/*builtin_code=*/0,
|
/*builtin_code=*/0,
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/micro/all_ops_resolver.h"
|
#include "tensorflow/lite/micro/all_ops_resolver.h"
|
||||||
#include "tensorflow/lite/micro/debug_log.h"
|
#include "tensorflow/lite/micro/debug_log.h"
|
||||||
|
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
|
||||||
#include "tensorflow/lite/micro/testing/micro_test.h"
|
#include "tensorflow/lite/micro/testing/micro_test.h"
|
||||||
#include "tensorflow/lite/micro/testing/test_utils.h"
|
#include "tensorflow/lite/micro/testing/test_utils.h"
|
||||||
|
|
||||||
@ -24,19 +25,15 @@ namespace tflite {
|
|||||||
namespace testing {
|
namespace testing {
|
||||||
|
|
||||||
void TestSplitTwoOutputsFloat(
|
void TestSplitTwoOutputsFloat(
|
||||||
std::initializer_list<int> input_dims_data,
|
const int* input_dims_data, const float* input_data,
|
||||||
std::initializer_list<float> input_data,
|
const int* axis_dims_data, const int32_t* axis_data,
|
||||||
std::initializer_list<int> axis_dims_data,
|
const int* output1_dims_data, const float* expected_output1_data,
|
||||||
std::initializer_list<int32_t> axis_data,
|
const int* output2_dims_data, const float* expected_output2_data,
|
||||||
std::initializer_list<int> output1_dims_data,
|
float* output1_data, float* output2_data) {
|
||||||
std::initializer_list<float> expected_output1_data,
|
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
|
||||||
std::initializer_list<int> output2_dims_data,
|
TfLiteIntArray* axis_dims = IntArrayFromInts(axis_dims_data);
|
||||||
std::initializer_list<float> expected_output2_data, float* output1_data,
|
TfLiteIntArray* output1_dims = IntArrayFromInts(output1_dims_data);
|
||||||
float* output2_data) {
|
TfLiteIntArray* output2_dims = IntArrayFromInts(output2_dims_data);
|
||||||
TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
|
|
||||||
TfLiteIntArray* axis_dims = IntArrayFromInitializer(axis_dims_data);
|
|
||||||
TfLiteIntArray* output1_dims = IntArrayFromInitializer(output1_dims_data);
|
|
||||||
TfLiteIntArray* output2_dims = IntArrayFromInitializer(output2_dims_data);
|
|
||||||
const int output1_dims_count = ElementCount(*output1_dims);
|
const int output1_dims_count = ElementCount(*output1_dims);
|
||||||
const int output2_dims_count = ElementCount(*output2_dims);
|
const int output2_dims_count = ElementCount(*output2_dims);
|
||||||
|
|
||||||
@ -61,76 +58,42 @@ void TestSplitTwoOutputsFloat(
|
|||||||
output2_data[i] = 23;
|
output2_data[i] = 23;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteContext context;
|
|
||||||
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
|
|
||||||
tflite::AllOpsResolver resolver;
|
|
||||||
const TfLiteRegistration* registration =
|
|
||||||
resolver.FindOp(tflite::BuiltinOperator_SPLIT);
|
|
||||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
|
|
||||||
|
|
||||||
TfLiteSplitParams builtin_data = {
|
|
||||||
.num_splits = 2,
|
|
||||||
};
|
|
||||||
|
|
||||||
void* user_data = nullptr;
|
|
||||||
if (registration->init) {
|
|
||||||
user_data = registration->init(&context, nullptr, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
int inputs_array_data[] = {2, 0, 1};
|
int inputs_array_data[] = {2, 0, 1};
|
||||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||||
int outputs_array_data[] = {2, 2, 3};
|
int outputs_array_data[] = {2, 2, 3};
|
||||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||||
|
|
||||||
TfLiteNode node;
|
const TfLiteRegistration registration = tflite::ops::micro::Register_SPLIT();
|
||||||
node.inputs = inputs_array;
|
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
|
||||||
node.outputs = outputs_array;
|
outputs_array, nullptr, micro_test::reporter);
|
||||||
node.user_data = user_data;
|
|
||||||
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
|
|
||||||
node.custom_initial_data = nullptr;
|
|
||||||
node.custom_initial_data_size = 0;
|
|
||||||
|
|
||||||
if (registration->prepare) {
|
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
|
||||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
|
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
|
||||||
}
|
|
||||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
|
||||||
|
|
||||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
|
||||||
if (registration->free) {
|
|
||||||
registration->free(&context, user_data);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < output1_dims_count; ++i) {
|
for (int i = 0; i < output1_dims_count; ++i) {
|
||||||
TF_LITE_MICRO_EXPECT_NEAR(expected_output1_data.begin()[i], output1_data[i],
|
TF_LITE_MICRO_EXPECT_NEAR(expected_output1_data[i], output1_data[i], 1e-5f);
|
||||||
1e-5f);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < output2_dims_count; ++i) {
|
for (int i = 0; i < output2_dims_count; ++i) {
|
||||||
TF_LITE_MICRO_EXPECT_NEAR(expected_output2_data.begin()[i], output2_data[i],
|
TF_LITE_MICRO_EXPECT_NEAR(expected_output2_data[i], output2_data[i], 1e-5f);
|
||||||
1e-5f);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestSplitFourOutputsFloat(
|
void TestSplitFourOutputsFloat(
|
||||||
std::initializer_list<int> input_dims_data,
|
const int* input_dims_data, const float* input_data,
|
||||||
std::initializer_list<float> input_data,
|
const int* axis_dims_data, const int32_t* axis_data,
|
||||||
std::initializer_list<int> axis_dims_data,
|
const int* output1_dims_data, const float* expected_output1_data,
|
||||||
std::initializer_list<int32_t> axis_data,
|
const int* output2_dims_data, const float* expected_output2_data,
|
||||||
std::initializer_list<int> output1_dims_data,
|
const int* output3_dims_data, const float* expected_output3_data,
|
||||||
std::initializer_list<float> expected_output1_data,
|
const int* output4_dims_data, const float* expected_output4_data,
|
||||||
std::initializer_list<int> output2_dims_data,
|
float* output1_data, float* output2_data, float* output3_data,
|
||||||
std::initializer_list<float> expected_output2_data,
|
float* output4_data) {
|
||||||
std::initializer_list<int> output3_dims_data,
|
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
|
||||||
std::initializer_list<float> expected_output3_data,
|
TfLiteIntArray* axis_dims = IntArrayFromInts(axis_dims_data);
|
||||||
std::initializer_list<int> output4_dims_data,
|
TfLiteIntArray* output1_dims = IntArrayFromInts(output1_dims_data);
|
||||||
std::initializer_list<float> expected_output4_data, float* output1_data,
|
TfLiteIntArray* output2_dims = IntArrayFromInts(output2_dims_data);
|
||||||
float* output2_data, float* output3_data, float* output4_data) {
|
TfLiteIntArray* output3_dims = IntArrayFromInts(output3_dims_data);
|
||||||
TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
|
TfLiteIntArray* output4_dims = IntArrayFromInts(output4_dims_data);
|
||||||
TfLiteIntArray* axis_dims = IntArrayFromInitializer(axis_dims_data);
|
|
||||||
TfLiteIntArray* output1_dims = IntArrayFromInitializer(output1_dims_data);
|
|
||||||
TfLiteIntArray* output2_dims = IntArrayFromInitializer(output2_dims_data);
|
|
||||||
TfLiteIntArray* output3_dims = IntArrayFromInitializer(output3_dims_data);
|
|
||||||
TfLiteIntArray* output4_dims = IntArrayFromInitializer(output4_dims_data);
|
|
||||||
const int output1_dims_count = ElementCount(*output1_dims);
|
const int output1_dims_count = ElementCount(*output1_dims);
|
||||||
const int output2_dims_count = ElementCount(*output2_dims);
|
const int output2_dims_count = ElementCount(*output2_dims);
|
||||||
const int output3_dims_count = ElementCount(*output3_dims);
|
const int output3_dims_count = ElementCount(*output3_dims);
|
||||||
@ -164,77 +127,42 @@ void TestSplitFourOutputsFloat(
|
|||||||
output4_data[i] = 23;
|
output4_data[i] = 23;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteContext context;
|
|
||||||
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
|
|
||||||
tflite::AllOpsResolver resolver;
|
|
||||||
const TfLiteRegistration* registration =
|
|
||||||
resolver.FindOp(tflite::BuiltinOperator_SPLIT);
|
|
||||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
|
|
||||||
|
|
||||||
TfLiteSplitParams builtin_data = {
|
|
||||||
.num_splits = 4,
|
|
||||||
};
|
|
||||||
|
|
||||||
void* user_data = nullptr;
|
|
||||||
if (registration->init) {
|
|
||||||
user_data = registration->init(&context, nullptr, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
int inputs_array_data[] = {2, 0, 1};
|
int inputs_array_data[] = {2, 0, 1};
|
||||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||||
int outputs_array_data[] = {4, 2, 3, 4, 5};
|
int outputs_array_data[] = {4, 2, 3, 4, 5};
|
||||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||||
|
|
||||||
TfLiteNode node;
|
const TfLiteRegistration registration = tflite::ops::micro::Register_SPLIT();
|
||||||
node.inputs = inputs_array;
|
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
|
||||||
node.outputs = outputs_array;
|
outputs_array, nullptr, micro_test::reporter);
|
||||||
node.user_data = user_data;
|
|
||||||
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
|
|
||||||
node.custom_initial_data = nullptr;
|
|
||||||
node.custom_initial_data_size = 0;
|
|
||||||
|
|
||||||
if (registration->prepare) {
|
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
|
||||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
|
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
|
||||||
}
|
|
||||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
|
||||||
|
|
||||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
|
||||||
if (registration->free) {
|
|
||||||
registration->free(&context, user_data);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < output1_dims_count; ++i) {
|
for (int i = 0; i < output1_dims_count; ++i) {
|
||||||
TF_LITE_MICRO_EXPECT_NEAR(expected_output1_data.begin()[i], output1_data[i],
|
TF_LITE_MICRO_EXPECT_NEAR(expected_output1_data[i], output1_data[i], 1e-5f);
|
||||||
1e-5f);
|
|
||||||
}
|
}
|
||||||
for (int i = 0; i < output2_dims_count; ++i) {
|
for (int i = 0; i < output2_dims_count; ++i) {
|
||||||
TF_LITE_MICRO_EXPECT_NEAR(expected_output2_data.begin()[i], output2_data[i],
|
TF_LITE_MICRO_EXPECT_NEAR(expected_output2_data[i], output2_data[i], 1e-5f);
|
||||||
1e-5f);
|
|
||||||
}
|
}
|
||||||
for (int i = 0; i < output3_dims_count; ++i) {
|
for (int i = 0; i < output3_dims_count; ++i) {
|
||||||
TF_LITE_MICRO_EXPECT_NEAR(expected_output3_data.begin()[i], output3_data[i],
|
TF_LITE_MICRO_EXPECT_NEAR(expected_output3_data[i], output3_data[i], 1e-5f);
|
||||||
1e-5f);
|
|
||||||
}
|
}
|
||||||
for (int i = 0; i < output4_dims_count; ++i) {
|
for (int i = 0; i < output4_dims_count; ++i) {
|
||||||
TF_LITE_MICRO_EXPECT_NEAR(expected_output4_data.begin()[i], output4_data[i],
|
TF_LITE_MICRO_EXPECT_NEAR(expected_output4_data[i], output4_data[i], 1e-5f);
|
||||||
1e-5f);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestSplitTwoOutputsQuantized(
|
void TestSplitTwoOutputsQuantized(
|
||||||
std::initializer_list<int> input_dims_data,
|
const int* input_dims_data, const uint8_t* input_data,
|
||||||
std::initializer_list<uint8_t> input_data,
|
const int* axis_dims_data, const int32_t* axis_data,
|
||||||
std::initializer_list<int> axis_dims_data,
|
const int* output1_dims_data, const uint8_t* expected_output1_data,
|
||||||
std::initializer_list<int32_t> axis_data,
|
const int* output2_dims_data, const uint8_t* expected_output2_data,
|
||||||
std::initializer_list<int> output1_dims_data,
|
uint8_t* output1_data, uint8_t* output2_data) {
|
||||||
std::initializer_list<uint8_t> expected_output1_data,
|
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
|
||||||
std::initializer_list<int> output2_dims_data,
|
TfLiteIntArray* axis_dims = IntArrayFromInts(axis_dims_data);
|
||||||
std::initializer_list<uint8_t> expected_output2_data, uint8_t* output1_data,
|
TfLiteIntArray* output1_dims = IntArrayFromInts(output1_dims_data);
|
||||||
uint8_t* output2_data) {
|
TfLiteIntArray* output2_dims = IntArrayFromInts(output2_dims_data);
|
||||||
TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
|
|
||||||
TfLiteIntArray* axis_dims = IntArrayFromInitializer(axis_dims_data);
|
|
||||||
TfLiteIntArray* output1_dims = IntArrayFromInitializer(output1_dims_data);
|
|
||||||
TfLiteIntArray* output2_dims = IntArrayFromInitializer(output2_dims_data);
|
|
||||||
const int output1_dims_count = ElementCount(*output1_dims);
|
const int output1_dims_count = ElementCount(*output1_dims);
|
||||||
const int output2_dims_count = ElementCount(*output2_dims);
|
const int output2_dims_count = ElementCount(*output2_dims);
|
||||||
|
|
||||||
@ -260,68 +188,37 @@ void TestSplitTwoOutputsQuantized(
|
|||||||
output2_data[i] = 23;
|
output2_data[i] = 23;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteContext context;
|
|
||||||
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
|
|
||||||
tflite::AllOpsResolver resolver;
|
|
||||||
const TfLiteRegistration* registration =
|
|
||||||
resolver.FindOp(tflite::BuiltinOperator_SPLIT);
|
|
||||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
|
|
||||||
|
|
||||||
TfLiteSplitParams builtin_data = {
|
|
||||||
.num_splits = 2,
|
|
||||||
};
|
|
||||||
|
|
||||||
void* user_data = nullptr;
|
|
||||||
if (registration->init) {
|
|
||||||
user_data = registration->init(&context, nullptr, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
int inputs_array_data[] = {2, 0, 1};
|
int inputs_array_data[] = {2, 0, 1};
|
||||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||||
int outputs_array_data[] = {2, 2, 3};
|
int outputs_array_data[] = {2, 2, 3};
|
||||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||||
|
|
||||||
TfLiteNode node;
|
const TfLiteRegistration registration = tflite::ops::micro::Register_SPLIT();
|
||||||
node.inputs = inputs_array;
|
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
|
||||||
node.outputs = outputs_array;
|
outputs_array, nullptr, micro_test::reporter);
|
||||||
node.user_data = user_data;
|
|
||||||
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
|
|
||||||
node.custom_initial_data = nullptr;
|
|
||||||
node.custom_initial_data_size = 0;
|
|
||||||
|
|
||||||
if (registration->prepare) {
|
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
|
||||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
|
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
|
||||||
}
|
|
||||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
|
||||||
|
|
||||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
|
||||||
if (registration->free) {
|
|
||||||
registration->free(&context, user_data);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < output1_dims_count; ++i) {
|
for (int i = 0; i < output1_dims_count; ++i) {
|
||||||
TF_LITE_MICRO_EXPECT_EQ(expected_output1_data.begin()[i], output1_data[i]);
|
TF_LITE_MICRO_EXPECT_EQ(expected_output1_data[i], output1_data[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < output2_dims_count; ++i) {
|
for (int i = 0; i < output2_dims_count; ++i) {
|
||||||
TF_LITE_MICRO_EXPECT_EQ(expected_output2_data.begin()[i], output2_data[i]);
|
TF_LITE_MICRO_EXPECT_EQ(expected_output2_data[i], output2_data[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestSplitTwoOutputsQuantized32(
|
void TestSplitTwoOutputsQuantized32(
|
||||||
std::initializer_list<int> input_dims_data,
|
const int* input_dims_data, const int32_t* input_data,
|
||||||
std::initializer_list<int32_t> input_data,
|
const int* axis_dims_data, const int32_t* axis_data,
|
||||||
std::initializer_list<int> axis_dims_data,
|
const int* output1_dims_data, const int32_t* expected_output1_data,
|
||||||
std::initializer_list<int32_t> axis_data,
|
const int* output2_dims_data, const int32_t* expected_output2_data,
|
||||||
std::initializer_list<int> output1_dims_data,
|
int32_t* output1_data, int32_t* output2_data) {
|
||||||
std::initializer_list<int32_t> expected_output1_data,
|
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
|
||||||
std::initializer_list<int> output2_dims_data,
|
TfLiteIntArray* axis_dims = IntArrayFromInts(axis_dims_data);
|
||||||
std::initializer_list<int32_t> expected_output2_data, int32_t* output1_data,
|
TfLiteIntArray* output1_dims = IntArrayFromInts(output1_dims_data);
|
||||||
int32_t* output2_data) {
|
TfLiteIntArray* output2_dims = IntArrayFromInts(output2_dims_data);
|
||||||
TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
|
|
||||||
TfLiteIntArray* axis_dims = IntArrayFromInitializer(axis_dims_data);
|
|
||||||
TfLiteIntArray* output1_dims = IntArrayFromInitializer(output1_dims_data);
|
|
||||||
TfLiteIntArray* output2_dims = IntArrayFromInitializer(output2_dims_data);
|
|
||||||
const int output1_dims_count = ElementCount(*output1_dims);
|
const int output1_dims_count = ElementCount(*output1_dims);
|
||||||
const int output2_dims_count = ElementCount(*output2_dims);
|
const int output2_dims_count = ElementCount(*output2_dims);
|
||||||
|
|
||||||
@ -347,51 +244,24 @@ void TestSplitTwoOutputsQuantized32(
|
|||||||
output2_data[i] = 23;
|
output2_data[i] = 23;
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteContext context;
|
|
||||||
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
|
|
||||||
tflite::AllOpsResolver resolver;
|
|
||||||
const TfLiteRegistration* registration =
|
|
||||||
resolver.FindOp(tflite::BuiltinOperator_SPLIT);
|
|
||||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
|
|
||||||
|
|
||||||
TfLiteSplitParams builtin_data = {
|
|
||||||
.num_splits = 2,
|
|
||||||
};
|
|
||||||
|
|
||||||
void* user_data = nullptr;
|
|
||||||
if (registration->init) {
|
|
||||||
user_data = registration->init(&context, nullptr, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
int inputs_array_data[] = {2, 0, 1};
|
int inputs_array_data[] = {2, 0, 1};
|
||||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||||
int outputs_array_data[] = {2, 2, 3};
|
int outputs_array_data[] = {2, 2, 3};
|
||||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||||
|
|
||||||
TfLiteNode node;
|
const TfLiteRegistration registration = tflite::ops::micro::Register_SPLIT();
|
||||||
node.inputs = inputs_array;
|
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
|
||||||
node.outputs = outputs_array;
|
outputs_array, nullptr, micro_test::reporter);
|
||||||
node.user_data = user_data;
|
|
||||||
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
|
|
||||||
node.custom_initial_data = nullptr;
|
|
||||||
node.custom_initial_data_size = 0;
|
|
||||||
|
|
||||||
if (registration->prepare) {
|
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
|
||||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
|
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
|
||||||
}
|
|
||||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
|
||||||
|
|
||||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
|
||||||
if (registration->free) {
|
|
||||||
registration->free(&context, user_data);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < output1_dims_count; ++i) {
|
for (int i = 0; i < output1_dims_count; ++i) {
|
||||||
TF_LITE_MICRO_EXPECT_EQ(expected_output1_data.begin()[i], output1_data[i]);
|
TF_LITE_MICRO_EXPECT_EQ(expected_output1_data[i], output1_data[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < output2_dims_count; ++i) {
|
for (int i = 0; i < output2_dims_count; ++i) {
|
||||||
TF_LITE_MICRO_EXPECT_EQ(expected_output2_data.begin()[i], output2_data[i]);
|
TF_LITE_MICRO_EXPECT_EQ(expected_output2_data[i], output2_data[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -401,91 +271,119 @@ void TestSplitTwoOutputsQuantized32(
|
|||||||
TF_LITE_MICRO_TESTS_BEGIN
|
TF_LITE_MICRO_TESTS_BEGIN
|
||||||
|
|
||||||
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisZero) {
|
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisZero) {
|
||||||
|
const int input_shape[] = {4, 2, 2, 2, 2};
|
||||||
|
const float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
|
||||||
|
9, 10, 11, 12, 13, 14, 15, 16};
|
||||||
|
const int axis_shape[] = {1, 1};
|
||||||
|
const int32_t axis_data[] = {0};
|
||||||
|
const int output1_shape[] = {4, 1, 2, 2, 2};
|
||||||
|
const float golden1[] = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||||
|
const int output2_shape[] = {4, 1, 2, 2, 2};
|
||||||
|
const float golden2[] = {9, 10, 11, 12, 13, 14, 15, 16};
|
||||||
|
|
||||||
constexpr int output1_dims_count = 8;
|
constexpr int output1_dims_count = 8;
|
||||||
constexpr int output2_dims_count = 8;
|
constexpr int output2_dims_count = 8;
|
||||||
float output1_data[output1_dims_count];
|
float output1_data[output1_dims_count];
|
||||||
float output2_data[output2_dims_count];
|
float output2_data[output2_dims_count];
|
||||||
tflite::testing::TestSplitTwoOutputsFloat(
|
tflite::testing::TestSplitTwoOutputsFloat(
|
||||||
{4, 2, 2, 2, 2}, // Input shape
|
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values
|
output2_shape, golden2, output1_data, output2_data);
|
||||||
{1, 1}, // Axis shape
|
|
||||||
{0}, // Axis value
|
|
||||||
{4, 1, 2, 2, 2}, // Output1 shape
|
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8}, // Output1 values
|
|
||||||
{4, 1, 2, 2, 2}, // Output2 shape
|
|
||||||
{9, 10, 11, 12, 13, 14, 15, 16}, // Output2 values
|
|
||||||
output1_data, output2_data);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisOne) {
|
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisOne) {
|
||||||
|
const int input_shape[] = {4, 2, 2, 2, 2};
|
||||||
|
const float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
|
||||||
|
9, 10, 11, 12, 13, 14, 15, 16};
|
||||||
|
const int axis_shape[] = {1, 1};
|
||||||
|
const int32_t axis_data[] = {1};
|
||||||
|
const int output1_shape[] = {4, 2, 1, 2, 2};
|
||||||
|
const float golden1[] = {1, 2, 3, 4, 9, 10, 11, 12};
|
||||||
|
const int output2_shape[] = {4, 2, 1, 2, 2};
|
||||||
|
const float golden2[] = {5, 6, 7, 8, 13, 14, 15, 16};
|
||||||
|
|
||||||
constexpr int output1_dims_count = 8;
|
constexpr int output1_dims_count = 8;
|
||||||
constexpr int output2_dims_count = 8;
|
constexpr int output2_dims_count = 8;
|
||||||
float output1_data[output1_dims_count];
|
float output1_data[output1_dims_count];
|
||||||
float output2_data[output2_dims_count];
|
float output2_data[output2_dims_count];
|
||||||
tflite::testing::TestSplitTwoOutputsFloat(
|
tflite::testing::TestSplitTwoOutputsFloat(
|
||||||
{4, 2, 2, 2, 2}, // Input shape
|
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values
|
output2_shape, golden2, output1_data, output2_data);
|
||||||
{1, 1}, // Axis shape
|
|
||||||
{1}, // Axis value
|
|
||||||
{4, 2, 1, 2, 2}, // Output1 shape
|
|
||||||
{1, 2, 3, 4, 9, 10, 11, 12}, // Output1 values
|
|
||||||
{4, 2, 1, 2, 2}, // Output2 shape
|
|
||||||
{5, 6, 7, 8, 13, 14, 15, 16}, // Output2 values
|
|
||||||
output1_data, output2_data);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisTwo) {
|
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisTwo) {
|
||||||
|
const int input_shape[] = {4, 2, 2, 2, 2};
|
||||||
|
const float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
|
||||||
|
9, 10, 11, 12, 13, 14, 15, 16};
|
||||||
|
const int axis_shape[] = {1, 1};
|
||||||
|
const int32_t axis_data[] = {2};
|
||||||
|
const int output1_shape[] = {4, 2, 2, 1, 2};
|
||||||
|
const float golden1[] = {1, 2, 5, 6, 9, 10, 13, 14};
|
||||||
|
const int output2_shape[] = {4, 2, 2, 1, 2};
|
||||||
|
const float golden2[] = {3, 4, 7, 8, 11, 12, 15, 16};
|
||||||
|
|
||||||
constexpr int output1_dims_count = 8;
|
constexpr int output1_dims_count = 8;
|
||||||
constexpr int output2_dims_count = 8;
|
constexpr int output2_dims_count = 8;
|
||||||
float output1_data[output1_dims_count];
|
float output1_data[output1_dims_count];
|
||||||
float output2_data[output2_dims_count];
|
float output2_data[output2_dims_count];
|
||||||
tflite::testing::TestSplitTwoOutputsFloat(
|
tflite::testing::TestSplitTwoOutputsFloat(
|
||||||
{4, 2, 2, 2, 2}, // Input shape
|
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values
|
output2_shape, golden2, output1_data, output2_data);
|
||||||
{1, 1}, // Axis shape
|
|
||||||
{2}, // Axis value
|
|
||||||
{4, 2, 2, 1, 2}, // Output1 shape
|
|
||||||
{1, 2, 5, 6, 9, 10, 13, 14}, // Output1 values
|
|
||||||
{4, 2, 2, 1, 2}, // Output2 shape
|
|
||||||
{3, 4, 7, 8, 11, 12, 15, 16}, // Output2 values
|
|
||||||
output1_data, output2_data);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisThree) {
|
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalAxisThree) {
|
||||||
|
const int input_shape[] = {4, 2, 2, 2, 2};
|
||||||
|
const float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
|
||||||
|
9, 10, 11, 12, 13, 14, 15, 16};
|
||||||
|
const int axis_shape[] = {1, 1};
|
||||||
|
const int32_t axis_data[] = {3};
|
||||||
|
const int output1_shape[] = {4, 2, 2, 2, 1};
|
||||||
|
const float golden1[] = {1, 3, 5, 7, 9, 11, 13, 15};
|
||||||
|
const int output2_shape[] = {4, 2, 2, 2, 1};
|
||||||
|
const float golden2[] = {2, 4, 6, 8, 10, 12, 14, 16};
|
||||||
|
|
||||||
constexpr int output1_dims_count = 8;
|
constexpr int output1_dims_count = 8;
|
||||||
constexpr int output2_dims_count = 8;
|
constexpr int output2_dims_count = 8;
|
||||||
float output1_data[output1_dims_count];
|
float output1_data[output1_dims_count];
|
||||||
float output2_data[output2_dims_count];
|
float output2_data[output2_dims_count];
|
||||||
tflite::testing::TestSplitTwoOutputsFloat(
|
tflite::testing::TestSplitTwoOutputsFloat(
|
||||||
{4, 2, 2, 2, 2}, // Input shape
|
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values
|
output2_shape, golden2, output1_data, output2_data);
|
||||||
{1, 1}, // Axis shape
|
|
||||||
{3}, // Axis value
|
|
||||||
{4, 2, 2, 2, 1}, // Output1 shape
|
|
||||||
{1, 3, 5, 7, 9, 11, 13, 15}, // Output1 values
|
|
||||||
{4, 2, 2, 2, 1}, // Output2 shape
|
|
||||||
{2, 4, 6, 8, 10, 12, 14, 16}, // Output2 values
|
|
||||||
output1_data, output2_data);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalNegativeAxis) {
|
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalNegativeAxis) {
|
||||||
|
const int input_shape[] = {4, 2, 2, 2, 2};
|
||||||
|
const float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
|
||||||
|
9, 10, 11, 12, 13, 14, 15, 16};
|
||||||
|
const int axis_shape[] = {1, 1};
|
||||||
|
const int32_t axis_data[] = {-4};
|
||||||
|
const int output1_shape[] = {4, 1, 2, 2, 2};
|
||||||
|
const float golden1[] = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||||
|
const int output2_shape[] = {4, 1, 2, 2, 2};
|
||||||
|
const float golden2[] = {9, 10, 11, 12, 13, 14, 15, 16};
|
||||||
|
|
||||||
constexpr int output1_dims_count = 8;
|
constexpr int output1_dims_count = 8;
|
||||||
constexpr int output2_dims_count = 8;
|
constexpr int output2_dims_count = 8;
|
||||||
float output1_data[output1_dims_count];
|
float output1_data[output1_dims_count];
|
||||||
float output2_data[output2_dims_count];
|
float output2_data[output2_dims_count];
|
||||||
tflite::testing::TestSplitTwoOutputsFloat(
|
tflite::testing::TestSplitTwoOutputsFloat(
|
||||||
{4, 2, 2, 2, 2}, // Input shape
|
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values
|
output2_shape, golden2, output1_data, output2_data);
|
||||||
{1, 1}, // Axis shape
|
|
||||||
{-4}, // Axis value
|
|
||||||
{4, 1, 2, 2, 2}, // Output1 shape
|
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8}, // Output1 values
|
|
||||||
{4, 1, 2, 2, 2}, // Output2 shape
|
|
||||||
{9, 10, 11, 12, 13, 14, 15, 16}, // Output2 values
|
|
||||||
output1_data, output2_data);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_LITE_MICRO_TEST(FourSplit) {
|
TF_LITE_MICRO_TEST(FourSplit) {
|
||||||
|
const int input_shape[] = {1, 4};
|
||||||
|
const float input_data[] = {1, 2, 3, 4};
|
||||||
|
const int axis_shape[] = {1, 1};
|
||||||
|
const int32_t axis_data[] = {0};
|
||||||
|
const int output1_shape[] = {1, 1};
|
||||||
|
const float golden1[] = {1};
|
||||||
|
const int output2_shape[] = {1, 1};
|
||||||
|
const float golden2[] = {2};
|
||||||
|
const int output3_shape[] = {1, 1};
|
||||||
|
const float golden3[] = {3};
|
||||||
|
const int output4_shape[] = {1, 1};
|
||||||
|
const float golden4[] = {4};
|
||||||
|
|
||||||
constexpr int output1_dims_count = 1;
|
constexpr int output1_dims_count = 1;
|
||||||
constexpr int output2_dims_count = 1;
|
constexpr int output2_dims_count = 1;
|
||||||
constexpr int output3_dims_count = 1;
|
constexpr int output3_dims_count = 1;
|
||||||
@ -494,70 +392,69 @@ TF_LITE_MICRO_TEST(FourSplit) {
|
|||||||
float output2_data[output2_dims_count];
|
float output2_data[output2_dims_count];
|
||||||
float output3_data[output3_dims_count];
|
float output3_data[output3_dims_count];
|
||||||
float output4_data[output4_dims_count];
|
float output4_data[output4_dims_count];
|
||||||
tflite::testing::TestSplitFourOutputsFloat({1, 4}, // Input shape
|
tflite::testing::TestSplitFourOutputsFloat(
|
||||||
{1, 2, 3, 4}, // Input values
|
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
|
||||||
{1, 1}, // Axis shape
|
output2_shape, golden2, output3_shape, golden3, output4_shape, golden4,
|
||||||
{0}, // Axis value
|
output1_data, output2_data, output3_data, output4_data);
|
||||||
{1, 1}, // Output1 shape
|
|
||||||
{1}, // Output1 values
|
|
||||||
{1, 1}, // Output2 shape
|
|
||||||
{2}, // Output2 values
|
|
||||||
{1, 1}, // Output3 shape
|
|
||||||
{3}, // Output3 values
|
|
||||||
{1, 1}, // Output4 shape
|
|
||||||
{4}, // Output4 values
|
|
||||||
output1_data, output2_data,
|
|
||||||
output3_data, output4_data);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_LITE_MICRO_TEST(TwoSplitOneDimensional) {
|
TF_LITE_MICRO_TEST(TwoSplitOneDimensional) {
|
||||||
|
const int input_shape[] = {1, 2};
|
||||||
|
const float input_data[] = {1, 2};
|
||||||
|
const int axis_shape[] = {1, 1};
|
||||||
|
const int32_t axis_data[] = {0};
|
||||||
|
const int output1_shape[] = {1, 1};
|
||||||
|
const float golden1[] = {1};
|
||||||
|
const int output2_shape[] = {1, 1};
|
||||||
|
const float golden2[] = {2};
|
||||||
|
|
||||||
constexpr int output1_dims_count = 8;
|
constexpr int output1_dims_count = 8;
|
||||||
constexpr int output2_dims_count = 8;
|
constexpr int output2_dims_count = 8;
|
||||||
float output1_data[output1_dims_count];
|
float output1_data[output1_dims_count];
|
||||||
float output2_data[output2_dims_count];
|
float output2_data[output2_dims_count];
|
||||||
tflite::testing::TestSplitTwoOutputsFloat({1, 2}, // Input shape
|
tflite::testing::TestSplitTwoOutputsFloat(
|
||||||
{1, 2}, // Input values
|
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
|
||||||
{1, 1}, // Axis shape
|
output2_shape, golden2, output1_data, output2_data);
|
||||||
{0}, // Axis value
|
|
||||||
{1, 1}, // Output1 shape
|
|
||||||
{1}, // Output1 values
|
|
||||||
{1, 1}, // Output2 shape
|
|
||||||
{2}, // Output2 values
|
|
||||||
output1_data, output2_data);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalQuantized) {
|
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalQuantized) {
|
||||||
|
const int input_shape[] = {4, 2, 2, 2, 2};
|
||||||
|
const uint8_t input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
|
||||||
|
9, 10, 11, 12, 13, 14, 15, 16};
|
||||||
|
const int axis_shape[] = {1, 1};
|
||||||
|
const int32_t axis_data[] = {1};
|
||||||
|
const int output1_shape[] = {4, 2, 1, 2, 2};
|
||||||
|
const uint8_t golden1[] = {1, 2, 3, 4, 9, 10, 11, 12};
|
||||||
|
const int output2_shape[] = {4, 2, 1, 2, 2};
|
||||||
|
const uint8_t golden2[] = {5, 6, 7, 8, 13, 14, 15, 16};
|
||||||
|
|
||||||
constexpr int output1_dims_count = 8;
|
constexpr int output1_dims_count = 8;
|
||||||
constexpr int output2_dims_count = 8;
|
constexpr int output2_dims_count = 8;
|
||||||
uint8_t output1_data[output1_dims_count];
|
uint8_t output1_data[output1_dims_count];
|
||||||
uint8_t output2_data[output2_dims_count];
|
uint8_t output2_data[output2_dims_count];
|
||||||
tflite::testing::TestSplitTwoOutputsQuantized(
|
tflite::testing::TestSplitTwoOutputsQuantized(
|
||||||
{4, 2, 2, 2, 2}, // Input shape
|
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values
|
output2_shape, golden2, output1_data, output2_data);
|
||||||
{1, 1}, // Axis shape
|
|
||||||
{0}, // Axis value
|
|
||||||
{4, 1, 2, 2, 2}, // Output1 shape
|
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8}, // Output1 values
|
|
||||||
{4, 1, 2, 2, 2}, // Output2 shape
|
|
||||||
{9, 10, 11, 12, 13, 14, 15, 16}, // Output2 values
|
|
||||||
output1_data, output2_data);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalQuantized32) {
|
TF_LITE_MICRO_TEST(TwoSplitFourDimensionalQuantized32) {
|
||||||
|
const int input_shape[] = {4, 2, 2, 2, 2};
|
||||||
|
const int32_t input_data[] = {1, 2, 3, 4, 5, 6, 7, 8,
|
||||||
|
9, 10, 11, 12, 13, 14, 15, 16};
|
||||||
|
const int axis_shape[] = {1, 1};
|
||||||
|
const int32_t axis_data[] = {1};
|
||||||
|
const int output1_shape[] = {4, 2, 1, 2, 2};
|
||||||
|
const int32_t golden1[] = {1, 2, 3, 4, 9, 10, 11, 12};
|
||||||
|
const int output2_shape[] = {4, 2, 1, 2, 2};
|
||||||
|
const int32_t golden2[] = {5, 6, 7, 8, 13, 14, 15, 16};
|
||||||
|
|
||||||
constexpr int output1_dims_count = 8;
|
constexpr int output1_dims_count = 8;
|
||||||
constexpr int output2_dims_count = 8;
|
constexpr int output2_dims_count = 8;
|
||||||
int32_t output1_data[output1_dims_count];
|
int32_t output1_data[output1_dims_count];
|
||||||
int32_t output2_data[output2_dims_count];
|
int32_t output2_data[output2_dims_count];
|
||||||
tflite::testing::TestSplitTwoOutputsQuantized32(
|
tflite::testing::TestSplitTwoOutputsQuantized32(
|
||||||
{4, 2, 2, 2, 2}, // Input shape
|
input_shape, input_data, axis_shape, axis_data, output1_shape, golden1,
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // Input values
|
output2_shape, golden2, output1_data, output2_data);
|
||||||
{1, 1}, // Axis shape
|
|
||||||
{0}, // Axis value
|
|
||||||
{4, 1, 2, 2, 2}, // Output1 shape
|
|
||||||
{1, 2, 3, 4, 5, 6, 7, 8}, // Output1 values
|
|
||||||
{4, 1, 2, 2, 2}, // Output2 shape
|
|
||||||
{9, 10, 11, 12, 13, 14, 15, 16}, // Output2 values
|
|
||||||
output1_data, output2_data);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_LITE_MICRO_TESTS_END
|
TF_LITE_MICRO_TESTS_END
|
||||||
|
@ -15,23 +15,20 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
|
#include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
#include "tensorflow/lite/kernels/op_macros.h"
|
#include "tensorflow/lite/kernels/op_macros.h"
|
||||||
|
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace micro {
|
namespace micro {
|
||||||
namespace strided_slice {
|
namespace strided_slice {
|
||||||
|
|
||||||
enum KernelType {
|
|
||||||
kReference,
|
|
||||||
// TODO(soroosh): add kGenericOptimized
|
|
||||||
};
|
|
||||||
|
|
||||||
constexpr int kInputTensor = 0;
|
constexpr int kInputTensor = 0;
|
||||||
constexpr int kBeginTensor = 1;
|
constexpr int kBeginTensor = 1;
|
||||||
constexpr int kEndTensor = 2;
|
constexpr int kEndTensor = 2;
|
||||||
@ -120,58 +117,70 @@ TfLiteStatus CheckOutputSize(TfLiteContext* context,
|
|||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
|
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||||
|
return context->AllocatePersistentBuffer(context, sizeof(StridedSliceParams));
|
||||||
|
}
|
||||||
|
|
||||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
TFLITE_DCHECK(node->user_data != nullptr);
|
||||||
|
StridedSliceParams* op_params =
|
||||||
|
static_cast<StridedSliceParams*>(node->user_data);
|
||||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
|
TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
|
||||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||||
StridedSliceContext op_context(context, node);
|
StridedSliceContext op_context(context, node);
|
||||||
TF_LITE_ENSURE_MSG(context, op_context.dims <= kMaxDim,
|
TF_LITE_ENSURE_MSG(context, op_context.dims <= kMaxDim,
|
||||||
"input dim should not exceed 4");
|
"input dim should not exceed 4");
|
||||||
|
auto params = BuildStridedSliceParams(&op_context);
|
||||||
|
memcpy(op_params, ¶ms, sizeof(StridedSliceParams));
|
||||||
return CheckOutputSize(context, &op_context);
|
return CheckOutputSize(context, &op_context);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <KernelType kernel_type>
|
|
||||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
StridedSliceContext op_context(context, node);
|
TFLITE_DCHECK(node->user_data != nullptr);
|
||||||
auto op_params = BuildStridedSliceParams(&op_context);
|
const StridedSliceParams& op_params =
|
||||||
|
*(static_cast<const StridedSliceParams*>(node->user_data));
|
||||||
|
|
||||||
#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
|
const TfLiteEvalTensor* input =
|
||||||
kernel_type::StridedSlice(op_params, GetTensorShape(op_context.input), \
|
tflite::micro::GetEvalInput(context, node, kInputTensor);
|
||||||
GetTensorData<data_type>(op_context.input), \
|
TfLiteEvalTensor* output =
|
||||||
GetTensorShape(op_context.output), \
|
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||||
GetTensorData<data_type>(op_context.output))
|
switch (output->type) {
|
||||||
|
|
||||||
switch (op_context.input->type) {
|
|
||||||
case kTfLiteFloat32:
|
case kTfLiteFloat32:
|
||||||
if (kernel_type == kReference) {
|
reference_ops::StridedSlice(op_params,
|
||||||
TF_LITE_STRIDED_SLICE(reference_ops, float);
|
tflite::micro::GetTensorShape(input),
|
||||||
}
|
tflite::micro::GetTensorData<float>(input),
|
||||||
|
tflite::micro::GetTensorShape(output),
|
||||||
|
tflite::micro::GetTensorData<float>(output));
|
||||||
break;
|
break;
|
||||||
case kTfLiteUInt8:
|
case kTfLiteUInt8:
|
||||||
if (kernel_type == kReference) {
|
reference_ops::StridedSlice(
|
||||||
TF_LITE_STRIDED_SLICE(reference_ops, uint8_t);
|
op_params, tflite::micro::GetTensorShape(input),
|
||||||
}
|
tflite::micro::GetTensorData<uint8_t>(input),
|
||||||
|
tflite::micro::GetTensorShape(output),
|
||||||
|
tflite::micro::GetTensorData<uint8_t>(output));
|
||||||
break;
|
break;
|
||||||
case kTfLiteInt8:
|
case kTfLiteInt8:
|
||||||
if (kernel_type == kReference) {
|
reference_ops::StridedSlice(op_params,
|
||||||
TF_LITE_STRIDED_SLICE(reference_ops, int8_t);
|
tflite::micro::GetTensorShape(input),
|
||||||
}
|
tflite::micro::GetTensorData<int8_t>(input),
|
||||||
|
tflite::micro::GetTensorShape(output),
|
||||||
|
tflite::micro::GetTensorData<int8_t>(output));
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||||
TfLiteTypeGetName(op_context.input->type),
|
TfLiteTypeGetName(input->type), input->type);
|
||||||
op_context.input->type);
|
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
#undef TF_LITE_STRIDED_SLICE
|
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
} // namespace strided_slice
|
} // namespace strided_slice
|
||||||
|
|
||||||
TfLiteRegistration Register_STRIDED_SLICE() {
|
TfLiteRegistration Register_STRIDED_SLICE() {
|
||||||
return {/*init=*/nullptr,
|
return {/*init=*/strided_slice::Init,
|
||||||
/*free=*/nullptr,
|
/*free=*/nullptr,
|
||||||
/*prepare=*/strided_slice::Prepare,
|
/*prepare=*/strided_slice::Prepare,
|
||||||
/*invoke=*/strided_slice::Eval<strided_slice::kReference>,
|
/*invoke=*/strided_slice::Eval,
|
||||||
/*profiling_string=*/nullptr,
|
/*profiling_string=*/nullptr,
|
||||||
/*builtin_code=*/0,
|
/*builtin_code=*/0,
|
||||||
/*custom_name=*/nullptr,
|
/*custom_name=*/nullptr,
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -21,8 +21,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||||
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
|
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
|
||||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/types.h"
|
||||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
#include "tensorflow/lite/kernels/op_macros.h"
|
#include "tensorflow/lite/kernels/op_macros.h"
|
||||||
|
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
@ -93,31 +95,59 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteSubParams* params,
|
|||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
|
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
|
||||||
|
return context->AllocatePersistentBuffer(context, sizeof(OpData));
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
TFLITE_DCHECK(node->user_data != nullptr);
|
||||||
|
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||||
|
|
||||||
|
OpData* data = static_cast<OpData*>(node->user_data);
|
||||||
|
auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);
|
||||||
|
|
||||||
|
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
|
||||||
|
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||||
|
|
||||||
|
TF_LITE_ENSURE_STATUS(
|
||||||
|
CalculateOpData(context, params, input1, input2, output, data));
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
|
void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
|
||||||
const OpData* data, const TfLiteTensor* input1,
|
const OpData* data, const TfLiteEvalTensor* input1,
|
||||||
const TfLiteTensor* input2, TfLiteTensor* output) {
|
const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
|
||||||
float output_activation_min, output_activation_max;
|
float output_activation_min, output_activation_max;
|
||||||
CalculateActivationRange(params->activation, &output_activation_min,
|
CalculateActivationRange(params->activation, &output_activation_min,
|
||||||
&output_activation_max);
|
&output_activation_max);
|
||||||
tflite::ArithmeticParams op_params;
|
tflite::ArithmeticParams op_params;
|
||||||
SetActivationParams(output_activation_min, output_activation_max, &op_params);
|
SetActivationParams(output_activation_min, output_activation_max, &op_params);
|
||||||
#define TF_LITE_SUB(opname) \
|
|
||||||
opname(op_params, GetTensorShape(input1), GetTensorData<float>(input1), \
|
|
||||||
GetTensorShape(input2), GetTensorData<float>(input2), \
|
|
||||||
GetTensorShape(output), GetTensorData<float>(output))
|
|
||||||
if (data->requires_broadcast) {
|
if (data->requires_broadcast) {
|
||||||
TF_LITE_SUB(tflite::reference_ops::BroadcastSubSlow);
|
tflite::reference_ops::BroadcastSubSlow(
|
||||||
|
op_params, tflite::micro::GetTensorShape(input1),
|
||||||
|
tflite::micro::GetTensorData<float>(input1),
|
||||||
|
tflite::micro::GetTensorShape(input2),
|
||||||
|
tflite::micro::GetTensorData<float>(input2),
|
||||||
|
tflite::micro::GetTensorShape(output),
|
||||||
|
tflite::micro::GetTensorData<float>(output));
|
||||||
} else {
|
} else {
|
||||||
TF_LITE_SUB(tflite::reference_ops::SubWithActivation);
|
tflite::reference_ops::SubWithActivation(
|
||||||
|
op_params, tflite::micro::GetTensorShape(input1),
|
||||||
|
tflite::micro::GetTensorData<float>(input1),
|
||||||
|
tflite::micro::GetTensorShape(input2),
|
||||||
|
tflite::micro::GetTensorData<float>(input2),
|
||||||
|
tflite::micro::GetTensorShape(output),
|
||||||
|
tflite::micro::GetTensorData<float>(output));
|
||||||
}
|
}
|
||||||
#undef TF_LITE_SUB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
|
TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||||
TfLiteSubParams* params, const OpData* data,
|
TfLiteSubParams* params, const OpData* data,
|
||||||
const TfLiteTensor* input1,
|
const TfLiteEvalTensor* input1,
|
||||||
const TfLiteTensor* input2,
|
const TfLiteEvalTensor* input2,
|
||||||
TfLiteTensor* output) {
|
TfLiteEvalTensor* output) {
|
||||||
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
|
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
|
||||||
tflite::ArithmeticParams op_params;
|
tflite::ArithmeticParams op_params;
|
||||||
op_params.left_shift = data->left_shift;
|
op_params.left_shift = data->left_shift;
|
||||||
@ -133,25 +163,46 @@ TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
|
|||||||
SetActivationParams(data->output_activation_min,
|
SetActivationParams(data->output_activation_min,
|
||||||
data->output_activation_max, &op_params);
|
data->output_activation_max, &op_params);
|
||||||
bool need_broadcast = reference_ops::ProcessBroadcastShapes(
|
bool need_broadcast = reference_ops::ProcessBroadcastShapes(
|
||||||
GetTensorShape(input1), GetTensorShape(input2), &op_params);
|
tflite::micro::GetTensorShape(input1),
|
||||||
#define TF_LITE_SUB(opname, dtype) \
|
tflite::micro::GetTensorShape(input2), &op_params);
|
||||||
opname(op_params, GetTensorShape(input1), GetTensorData<dtype>(input1), \
|
|
||||||
GetTensorShape(input2), GetTensorData<dtype>(input2), \
|
|
||||||
GetTensorShape(output), GetTensorData<dtype>(output));
|
|
||||||
if (output->type == kTfLiteInt8) {
|
if (output->type == kTfLiteInt8) {
|
||||||
if (need_broadcast) {
|
if (need_broadcast) {
|
||||||
TF_LITE_SUB(tflite::reference_ops::BroadcastSubSlow, int8_t);
|
tflite::reference_ops::BroadcastSubSlow(
|
||||||
|
op_params, tflite::micro::GetTensorShape(input1),
|
||||||
|
tflite::micro::GetTensorData<int8_t>(input1),
|
||||||
|
tflite::micro::GetTensorShape(input2),
|
||||||
|
tflite::micro::GetTensorData<int8_t>(input2),
|
||||||
|
tflite::micro::GetTensorShape(output),
|
||||||
|
tflite::micro::GetTensorData<int8_t>(output));
|
||||||
} else {
|
} else {
|
||||||
TF_LITE_SUB(tflite::reference_ops::Sub, int8_t);
|
tflite::reference_ops::Sub(
|
||||||
|
op_params, tflite::micro::GetTensorShape(input1),
|
||||||
|
tflite::micro::GetTensorData<int8_t>(input1),
|
||||||
|
tflite::micro::GetTensorShape(input2),
|
||||||
|
tflite::micro::GetTensorData<int8_t>(input2),
|
||||||
|
tflite::micro::GetTensorShape(output),
|
||||||
|
tflite::micro::GetTensorData<int8_t>(output));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (need_broadcast) {
|
if (need_broadcast) {
|
||||||
TF_LITE_SUB(tflite::reference_ops::BroadcastSubSlow, uint8_t);
|
tflite::reference_ops::BroadcastSubSlow(
|
||||||
|
op_params, tflite::micro::GetTensorShape(input1),
|
||||||
|
tflite::micro::GetTensorData<uint8_t>(input1),
|
||||||
|
tflite::micro::GetTensorShape(input2),
|
||||||
|
tflite::micro::GetTensorData<uint8_t>(input2),
|
||||||
|
tflite::micro::GetTensorShape(output),
|
||||||
|
tflite::micro::GetTensorData<uint8_t>(output));
|
||||||
} else {
|
} else {
|
||||||
TF_LITE_SUB(tflite::reference_ops::Sub, uint8_t);
|
tflite::reference_ops::Sub(
|
||||||
|
op_params, tflite::micro::GetTensorShape(input1),
|
||||||
|
tflite::micro::GetTensorData<uint8_t>(input1),
|
||||||
|
tflite::micro::GetTensorShape(input2),
|
||||||
|
tflite::micro::GetTensorData<uint8_t>(input2),
|
||||||
|
tflite::micro::GetTensorShape(output),
|
||||||
|
tflite::micro::GetTensorData<uint8_t>(output));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#undef TF_LITE_SUB
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
@ -160,13 +211,15 @@ TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
|
|||||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);
|
auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);
|
||||||
|
|
||||||
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
|
const TfLiteEvalTensor* input1 =
|
||||||
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
|
tflite::micro::GetEvalInput(context, node, kInputTensor1);
|
||||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
const TfLiteEvalTensor* input2 =
|
||||||
|
tflite::micro::GetEvalInput(context, node, kInputTensor2);
|
||||||
|
TfLiteEvalTensor* output =
|
||||||
|
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||||
|
|
||||||
OpData data;
|
TFLITE_DCHECK(node->user_data != nullptr);
|
||||||
TF_LITE_ENSURE_STATUS(
|
const OpData& data = *(static_cast<const OpData*>(node->user_data));
|
||||||
CalculateOpData(context, params, input1, input2, output, &data));
|
|
||||||
|
|
||||||
if (output->type == kTfLiteFloat32) {
|
if (output->type == kTfLiteFloat32) {
|
||||||
EvalSub(context, node, params, &data, input1, input2, output);
|
EvalSub(context, node, params, &data, input1, input2, output);
|
||||||
@ -185,9 +238,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
} // namespace sub
|
} // namespace sub
|
||||||
|
|
||||||
TfLiteRegistration Register_SUB() {
|
TfLiteRegistration Register_SUB() {
|
||||||
return {/*init=*/nullptr,
|
return {/*init=*/sub::Init,
|
||||||
/*free=*/nullptr,
|
/*free=*/nullptr,
|
||||||
/*prepare=*/nullptr,
|
/*prepare=*/sub::Prepare,
|
||||||
/*invoke=*/sub::Eval,
|
/*invoke=*/sub::Eval,
|
||||||
/*profiling_string=*/nullptr,
|
/*profiling_string=*/nullptr,
|
||||||
/*builtin_code=*/0,
|
/*builtin_code=*/0,
|
||||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/micro/all_ops_resolver.h"
|
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
|
||||||
#include "tensorflow/lite/micro/testing/micro_test.h"
|
#include "tensorflow/lite/micro/testing/micro_test.h"
|
||||||
#include "tensorflow/lite/micro/testing/test_utils.h"
|
#include "tensorflow/lite/micro/testing/test_utils.h"
|
||||||
|
|
||||||
@ -66,47 +66,21 @@ void ValidateSubGoldens(TfLiteTensor* tensors, int tensors_size,
|
|||||||
const T* golden, T* output, int output_size,
|
const T* golden, T* output, int output_size,
|
||||||
TfLiteFusedActivation activation,
|
TfLiteFusedActivation activation,
|
||||||
float tolerance = 1e-5) {
|
float tolerance = 1e-5) {
|
||||||
TfLiteContext context;
|
|
||||||
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
|
|
||||||
|
|
||||||
::tflite::AllOpsResolver resolver;
|
|
||||||
const TfLiteRegistration* registration =
|
|
||||||
resolver.FindOp(::tflite::BuiltinOperator_SUB);
|
|
||||||
|
|
||||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
|
|
||||||
|
|
||||||
TfLiteSubParams builtin_data;
|
TfLiteSubParams builtin_data;
|
||||||
builtin_data.activation = activation;
|
builtin_data.activation = activation;
|
||||||
|
|
||||||
const char* init_data = reinterpret_cast<const char*>(&builtin_data);
|
|
||||||
const size_t init_data_size = 0;
|
|
||||||
void* user_data = nullptr;
|
|
||||||
if (registration->init) {
|
|
||||||
user_data = registration->init(&context, init_data, init_data_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
int inputs_array_data[] = {2, 0, 1};
|
int inputs_array_data[] = {2, 0, 1};
|
||||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||||
int outputs_array_data[] = {1, 2};
|
int outputs_array_data[] = {1, 2};
|
||||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||||
|
|
||||||
TfLiteNode node;
|
const TfLiteRegistration registration = tflite::ops::micro::Register_SUB();
|
||||||
node.inputs = inputs_array;
|
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
|
||||||
node.outputs = outputs_array;
|
outputs_array, &builtin_data,
|
||||||
node.user_data = user_data;
|
micro_test::reporter);
|
||||||
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
|
|
||||||
node.custom_initial_data = nullptr;
|
|
||||||
node.custom_initial_data_size = 0;
|
|
||||||
|
|
||||||
if (registration->prepare) {
|
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
|
||||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
|
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
|
||||||
}
|
|
||||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
|
||||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
|
||||||
|
|
||||||
if (registration->free) {
|
|
||||||
registration->free(&context, user_data);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < output_size; ++i) {
|
for (int i = 0; i < output_size; ++i) {
|
||||||
TF_LITE_MICRO_EXPECT_NEAR(golden[i], output[i], tolerance);
|
TF_LITE_MICRO_EXPECT_NEAR(golden[i], output[i], tolerance);
|
||||||
|
Loading…
Reference in New Issue
Block a user