Added SHAPE op and tests for micro
This commit is contained in:
parent
46b6537110
commit
0da7fd25e1
tensorflow/lite
@ -1514,6 +1514,20 @@ TfLiteStatus ParseRsqrt(const Operator*, ErrorReporter*, BuiltinDataAllocator*,
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// We have this parse function instead of directly returning kTfLiteOk from the
|
||||
// switch-case in ParseOpData because this function is used as part of the
|
||||
// selective registration for the OpResolver implementation in micro.
|
||||
TfLiteStatus ParseShape(const Operator*, ErrorReporter* error_reporter,
|
||||
|
||||
SafeBuiltinDataAllocator safe_allocator(allocator);
|
||||
std::unique_ptr<TfLiteShapeParams,
|
||||
SafeBuiltinDataAllocator::BuiltinDataDeleter>
|
||||
params = safe_allocator.Allocate<TfLiteShapeParams>();
|
||||
TF_LITE_ENSURE(error_reporter, params != nullptr);
|
||||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
} // namespace tflite
|
||||
|
||||
// We have this parse function instead of directly returning kTfLiteOk from the
|
||||
// switch-case in ParseOpData because this function is used as part of the
|
||||
// selective registration for the OpResolver implementation in micro.
|
||||
|
@ -216,6 +216,9 @@ TfLiteStatus ParseRound(const Operator* op, ErrorReporter* error_reporter,
|
||||
TfLiteStatus ParseRsqrt(const Operator* op, ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseShape(const Operator* op, ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseSin(const Operator* op, ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data);
|
||||
|
||||
|
@ -70,6 +70,7 @@ AllOpsResolver::AllOpsResolver() {
|
||||
AddResizeNearestNeighbor();
|
||||
AddRound();
|
||||
AddRsqrt();
|
||||
AddShape();
|
||||
AddSin();
|
||||
AddSoftmax();
|
||||
AddSplit();
|
||||
|
@ -53,6 +53,7 @@ cc_library(
|
||||
"reshape.cc",
|
||||
"resize_nearest_neighbor.cc",
|
||||
"round.cc",
|
||||
"shape.cc",
|
||||
"split.cc",
|
||||
"split_v.cc",
|
||||
"strided_slice.cc",
|
||||
|
@ -74,6 +74,7 @@ TfLiteRegistration Register_RESHAPE();
|
||||
TfLiteRegistration Register_RESIZE_NEAREST_NEIGHBOR();
|
||||
TfLiteRegistration Register_ROUND();
|
||||
TfLiteRegistration Register_RSQRT();
|
||||
TfLiteRegistration Register_SHAPE();
|
||||
TfLiteRegistration Register_SIN();
|
||||
TfLiteRegistration Register_SOFTMAX();
|
||||
TfLiteRegistration Register_SPLIT();
|
||||
|
80
tensorflow/lite/micro/kernels/shape.cc
Executable file
80
tensorflow/lite/micro/kernels/shape.cc
Executable file
@ -0,0 +1,80 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/memory_helpers.h"
|
||||
#include "tensorflow/lite/micro/micro_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace micro {
|
||||
namespace shape {
|
||||
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
void ExtractShape(const TfLiteEvalTensor* input, int32_t* output_data) {
|
||||
int numInputDims = input->dims->size;
|
||||
|
||||
for (int i = 0; i < numInputDims; ++i) {
|
||||
output_data[i] = input->dims->data[i];
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteEvalTensor* input =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensor);
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
|
||||
switch (output->type) {
|
||||
case kTfLiteInt32:
|
||||
ExtractShape(input, tflite::micro::GetTensorData<int32_t>(output));
|
||||
break;
|
||||
default:
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace shape
|
||||
|
||||
TfLiteRegistration Register_SHAPE() {
|
||||
return {/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/shape::Prepare,
|
||||
/*invoke=*/shape::Eval,
|
||||
/*profiling_string=*/nullptr,
|
||||
/*builtin_code=*/0,
|
||||
/*custom_name=*/nullptr,
|
||||
/*version=*/0};
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
139
tensorflow/lite/micro/kernels/shape_test.cc
Executable file
139
tensorflow/lite/micro/kernels/shape_test.cc
Executable file
@ -0,0 +1,139 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/micro/all_ops_resolver.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
|
||||
#include "tensorflow/lite/micro/test_helpers.h"
|
||||
#include "tensorflow/lite/micro/testing/micro_test.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
namespace {
|
||||
|
||||
void ValidateShape(TfLiteTensor* tensors, const int tensor_count,
|
||||
int* output_data, const int* expected_output,
|
||||
int output_dims_count) {
|
||||
TfLiteShapeParams builtin_data;
|
||||
int inputs_array_data[] = {1, 0};
|
||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||
int outputs_array_data[] = {1, 1};
|
||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||
|
||||
const TfLiteRegistration registration = tflite::ops::micro::Register_SHAPE();
|
||||
micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array,
|
||||
outputs_array, nullptr, micro_test::reporter);
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
|
||||
|
||||
for (int i = 0; i < output_dims_count; ++i) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(expected_output[i], output_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void TestShape(const int* input_dims_data, const float* input_data,
|
||||
const int* output_dims_data, const int* expected_output_data,
|
||||
int* output_data) {
|
||||
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
|
||||
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
|
||||
const int output_dims_count = ElementCount(*output_dims);
|
||||
|
||||
constexpr int inputs_size = 1;
|
||||
constexpr int outputs_size = 1;
|
||||
constexpr int tensors_size = inputs_size + outputs_size;
|
||||
TfLiteTensor tensors[tensors_size] = {
|
||||
CreateFloatTensor(input_data, input_dims),
|
||||
CreateInt32Tensor(output_data, output_dims, true),
|
||||
};
|
||||
|
||||
ValidateShape(tensors, tensors_size, output_data, expected_output_data,
|
||||
output_dims_count);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
||||
TF_LITE_MICRO_TESTS_BEGIN
|
||||
|
||||
TF_LITE_MICRO_TEST(TestShape0) {
|
||||
int input_shape[] = {1, 5};
|
||||
float input_values[] = {1, 3, 1, 3, 5};
|
||||
int output_dims[] = {1, 1}; // this is actually input_shapes shape
|
||||
int expected_output_data[] = {5};
|
||||
int output_data[1];
|
||||
|
||||
tflite::testing::TestShape(input_shape, input_values, output_dims,
|
||||
expected_output_data, output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(TestShape1) {
|
||||
int input_shape[] = {2, 4, 3};
|
||||
float input_values[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
|
||||
int output_dims[] = {2, 1, 1};
|
||||
int expected_output_data[] = {4, 3};
|
||||
int output_data[2];
|
||||
|
||||
tflite::testing::TestShape(input_shape, input_values, output_dims,
|
||||
expected_output_data, output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(TestShape2) {
|
||||
int input_shape[] = {2, 12, 1};
|
||||
float input_values[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
|
||||
int output_dims[] = {2, 1, 1};
|
||||
int expected_output_data[] = {12, 1};
|
||||
int output_data[2];
|
||||
|
||||
tflite::testing::TestShape(input_shape, input_values, output_dims,
|
||||
expected_output_data, output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(TestShape3) {
|
||||
int input_shape[] = {2, 2, 6};
|
||||
float input_values[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
|
||||
int output_dims[] = {2, 1, 1};
|
||||
int expected_output_data[] = {2, 6};
|
||||
int output_data[2];
|
||||
|
||||
tflite::testing::TestShape(input_shape, input_values, output_dims,
|
||||
expected_output_data, output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(TestShape4) {
|
||||
int input_shape[] = {2, 2, 2, 3};
|
||||
float input_values[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
|
||||
int output_dims[] = {3, 1, 1, 1};
|
||||
int expected_output_data[] = {2, 2, 3};
|
||||
int output_data[3];
|
||||
|
||||
tflite::testing::TestShape(input_shape, input_values, output_dims,
|
||||
expected_output_data, output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(TestShape5) {
|
||||
int input_shape[] = {1, 1};
|
||||
float input_values[] = {1};
|
||||
int output_dims[] = {1, 1};
|
||||
int expected_output_data[] = {1};
|
||||
int output_data[1];
|
||||
|
||||
tflite::testing::TestShape(input_shape, input_values, output_dims,
|
||||
expected_output_data, output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TESTS_END
|
@ -345,6 +345,11 @@ class MicroMutableOpResolver : public MicroOpResolver {
|
||||
tflite::ops::micro::Register_RSQRT(), ParseRsqrt);
|
||||
}
|
||||
|
||||
TfLiteStatus AddShape() {
|
||||
return AddBuiltin(BuiltinOperator_SHAPE,
|
||||
tflite::ops::micro::Register_SHAPE(), ParseShape);
|
||||
}
|
||||
|
||||
TfLiteStatus AddSin() {
|
||||
return AddBuiltin(BuiltinOperator_SIN, tflite::ops::micro::Register_SIN(),
|
||||
ParseSin);
|
||||
|
Loading…
Reference in New Issue
Block a user