Add arg_min and arg_max op for micro
PiperOrigin-RevId: 261230626
This commit is contained in:
parent
c52b412821
commit
8a142d90de
@ -14,6 +14,7 @@ package(
|
||||
cc_library(
|
||||
name = "micro_ops",
|
||||
srcs = [
|
||||
"arg_min_max.cc",
|
||||
"conv.cc",
|
||||
"depthwise_conv.cc",
|
||||
"elementwise.cc",
|
||||
@ -28,6 +29,7 @@ cc_library(
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/experimental/micro/kernels:micro_utils",
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
"//tensorflow/lite/kernels:op_macros",
|
||||
"//tensorflow/lite/kernels:padding",
|
||||
@ -56,6 +58,7 @@ cc_library(
|
||||
cc_library(
|
||||
name = "portable_optimized_micro_ops",
|
||||
srcs = [
|
||||
"arg_min_max.cc",
|
||||
"conv.cc",
|
||||
"elementwise.cc",
|
||||
"floor.cc",
|
||||
@ -70,6 +73,7 @@ cc_library(
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/experimental/micro/kernels:micro_utils",
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
"//tensorflow/lite/kernels:op_macros",
|
||||
"//tensorflow/lite/kernels:padding",
|
||||
@ -209,3 +213,22 @@ tflite_micro_cc_test(
|
||||
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
||||
],
|
||||
)
|
||||
|
||||
tflite_micro_cc_test(
|
||||
name = "arg_min_max_test",
|
||||
srcs = [
|
||||
"arg_min_max_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":all_ops_resolver",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||
"//tensorflow/lite/experimental/micro/kernels:micro_utils",
|
||||
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "micro_utils",
|
||||
hdrs = ["micro_utils.h"],
|
||||
)
|
||||
|
@ -25,6 +25,8 @@ TfLiteRegistration* Register_MAX_POOL_2D();
|
||||
TfLiteRegistration* Register_ABS();
|
||||
TfLiteRegistration* Register_PRELU();
|
||||
TfLiteRegistration* Register_FLOOR();
|
||||
TfLiteRegistration* Register_ARG_MAX();
|
||||
TfLiteRegistration* Register_ARG_MIN();
|
||||
|
||||
AllOpsResolver::AllOpsResolver() {
|
||||
AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D());
|
||||
@ -38,6 +40,8 @@ AllOpsResolver::AllOpsResolver() {
|
||||
AddBuiltin(BuiltinOperator_ABS, Register_ABS());
|
||||
AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
|
||||
AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR());
|
||||
AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
|
||||
AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN());
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
117
tensorflow/lite/experimental/micro/kernels/arg_min_max.cc
Normal file
117
tensorflow/lite/experimental/micro/kernels/arg_min_max.cc
Normal file
@ -0,0 +1,117 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/experimental/micro/kernels/micro_utils.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace micro {
|
||||
namespace arg_min_max {
|
||||
|
||||
constexpr int kInputTensor = 0;
|
||||
constexpr int kAxis = 1;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
inline void ArgMinMaxHelper(const RuntimeShape& input1_shape,
|
||||
const T1* input1_data, const T3* input2_data,
|
||||
const RuntimeShape& output_shape, T2* output_data,
|
||||
bool is_arg_max) {
|
||||
if (is_arg_max) {
|
||||
reference_ops::ArgMinMax(input1_shape, input1_data, input2_data,
|
||||
output_shape, output_data, micro::Greater());
|
||||
} else {
|
||||
reference_ops::ArgMinMax(input1_shape, input1_data, input2_data,
|
||||
output_shape, output_data, micro::Less());
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* axis = GetInput(context, node, kAxis);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
|
||||
ArgMinMaxHelper(GetTensorShape(input), GetTensorData<data_type>(input), \
|
||||
GetTensorData<axis_type>(axis), GetTensorShape(output), \
|
||||
GetTensorData<output_type>(output), is_arg_max)
|
||||
if (axis->type == kTfLiteInt32) {
|
||||
if (output->type == kTfLiteInt32) {
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32:
|
||||
TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t);
|
||||
break;
|
||||
case kTfLiteUInt8:
|
||||
TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
|
||||
break;
|
||||
default:
|
||||
context->ReportError(context,
|
||||
"Only float32, uint8 are "
|
||||
"supported currently, got %s.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
} else {
|
||||
context->ReportError(context,
|
||||
"Only int32 are supported currently, got %s.",
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
} else {
|
||||
context->ReportError(context, "Only int32 are supported currently, got %s.",
|
||||
TfLiteTypeGetName(axis->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
#undef TF_LITE_ARG_MIN_MAX
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return Eval(context, node, false);
|
||||
}
|
||||
|
||||
TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return Eval(context, node, true);
|
||||
}
|
||||
|
||||
} // namespace arg_min_max
|
||||
|
||||
TfLiteRegistration* Register_ARG_MAX() {
|
||||
static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
|
||||
arg_min_max::ArgMaxEval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_ARG_MIN() {
|
||||
static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
|
||||
arg_min_max::ArgMinEval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
388
tensorflow/lite/experimental/micro/kernels/arg_min_max_test.cc
Normal file
388
tensorflow/lite/experimental/micro/kernels/arg_min_max_test.cc
Normal file
@ -0,0 +1,388 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
|
||||
#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
|
||||
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
||||
#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
namespace {
|
||||
|
||||
#define TFLMICRO_CREATE_TENSOR(type_name, type_t, tftype, field) \
|
||||
inline TfLiteTensor Create##type_name##Tensor( \
|
||||
const type_t* data, TfLiteIntArray* dims, const char* name) { \
|
||||
TfLiteTensor result; \
|
||||
result.type = tftype; \
|
||||
result.data.field = const_cast<type_t*>(data); \
|
||||
result.dims = dims; \
|
||||
result.params = {}; \
|
||||
result.allocation_type = kTfLiteMemNone; \
|
||||
result.bytes = ElementCount(*dims) * sizeof(type_t); \
|
||||
result.allocation = nullptr; \
|
||||
result.name = name; \
|
||||
return result; \
|
||||
} \
|
||||
inline TfLiteTensor Create##type_name##Tensor( \
|
||||
std::initializer_list<type_t> data, TfLiteIntArray* dims, \
|
||||
const char* name) { \
|
||||
return Create##type_name##Tensor(data.begin(), dims, name); \
|
||||
}
|
||||
|
||||
TFLMICRO_CREATE_TENSOR(Int32, int32_t, kTfLiteInt32, i32)
|
||||
TFLMICRO_CREATE_TENSOR(Int64, int64_t, kTfLiteInt64, i64)
|
||||
|
||||
#undef TFLMICRO_CREATE_TENSOR
|
||||
|
||||
// If expected output is empty, the test is expected to fail.
|
||||
void TestArgMinMax(TfLiteTensor* input_tensor, TfLiteTensor* axis_tensor,
|
||||
TfLiteTensor* output_tensor,
|
||||
std::initializer_list<int> expected_output_data,
|
||||
bool using_min = false) {
|
||||
const int output_dims_count = ElementCount(*output_tensor->dims);
|
||||
constexpr int inputs_size = 2;
|
||||
constexpr int outputs_size = 1;
|
||||
constexpr int tensors_size = inputs_size + outputs_size;
|
||||
TfLiteTensor tensors[tensors_size] = {
|
||||
*input_tensor,
|
||||
*axis_tensor,
|
||||
*output_tensor,
|
||||
};
|
||||
TfLiteContext context;
|
||||
PopulateContext(tensors, tensors_size, &context);
|
||||
::tflite::ops::micro::AllOpsResolver resolver;
|
||||
const TfLiteRegistration* registration;
|
||||
if (using_min) {
|
||||
registration = resolver.FindOp(tflite::BuiltinOperator_ARG_MIN, 1);
|
||||
} else {
|
||||
registration = resolver.FindOp(tflite::BuiltinOperator_ARG_MAX, 1);
|
||||
}
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
|
||||
|
||||
size_t init_data_size = 0;
|
||||
void* user_data = nullptr;
|
||||
if (registration->init) {
|
||||
user_data = registration->init(&context, nullptr, init_data_size);
|
||||
}
|
||||
int inputs_array_data[] = {2, 0, 1};
|
||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||
int outputs_array_data[] = {1, 2};
|
||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
|
||||
TfLiteNode node;
|
||||
node.inputs = inputs_array;
|
||||
node.outputs = outputs_array;
|
||||
node.temporaries = temporaries_array;
|
||||
node.user_data = user_data;
|
||||
node.builtin_data = nullptr;
|
||||
node.custom_initial_data = nullptr;
|
||||
node.custom_initial_data_size = 0;
|
||||
node.delegate = nullptr;
|
||||
if (registration->prepare) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
|
||||
}
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
||||
if (!expected_output_data.size()) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
|
||||
registration->invoke(&context, &node));
|
||||
return;
|
||||
}
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
||||
if (registration->free) {
|
||||
registration->free(&context, user_data);
|
||||
}
|
||||
for (int i = 0; i < output_dims_count; ++i) {
|
||||
TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i],
|
||||
output_tensor->data.i32[i], 1e-5f);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
||||
TF_LITE_MICRO_TESTS_BEGIN
|
||||
|
||||
TF_LITE_MICRO_TEST(GetMaxArgFloat) {
|
||||
int32_t output_data[1];
|
||||
TfLiteIntArray* input_dims =
|
||||
tflite::testing::IntArrayFromInitializer({4, 1, 1, 1, 4});
|
||||
auto input_tensor = tflite::testing::CreateFloatTensor(
|
||||
{0.1, 0.9, 0.7, 0.3}, input_dims, "input_tensor");
|
||||
auto axis_tensor = tflite::testing::CreateInt32Tensor(
|
||||
{3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"axis_tensor");
|
||||
auto output_tensor = tflite::testing::CreateInt32Tensor(
|
||||
output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"output_tensor");
|
||||
tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
|
||||
{1});
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(GetMaxArgUInt8) {
|
||||
using tflite::testing::F2Q;
|
||||
int32_t output_data[1];
|
||||
float input_min = 0;
|
||||
float input_max = 15.9375;
|
||||
TfLiteIntArray* input_dims =
|
||||
tflite::testing::IntArrayFromInitializer({4, 1, 1, 1, 4});
|
||||
auto input_data = {
|
||||
F2Q(1., input_min, input_max), F2Q(9., input_min, input_max),
|
||||
F2Q(7., input_min, input_max), F2Q(3., input_min, input_max)};
|
||||
auto input_tensor = tflite::testing::CreateQuantizedTensor(
|
||||
input_data, input_dims, "input_tensor", input_min, input_max);
|
||||
auto axis_tensor = tflite::testing::CreateInt32Tensor(
|
||||
{3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"axis_tensor");
|
||||
auto output_tensor = tflite::testing::CreateInt32Tensor(
|
||||
output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"output_tensor");
|
||||
tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
|
||||
{1});
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(GetMaxArgInt32) {
|
||||
using tflite::testing::F2Q32;
|
||||
int32_t output_data[1];
|
||||
float input_min = 0;
|
||||
float input_max = 31.9375;
|
||||
TfLiteIntArray* input_dims =
|
||||
tflite::testing::IntArrayFromInitializer({4, 1, 1, 1, 4});
|
||||
auto input_data = {
|
||||
F2Q32(1, input_min, input_max), F2Q32(9, input_min, input_max),
|
||||
F2Q32(7, input_min, input_max), F2Q32(3, input_min, input_max)};
|
||||
auto input_tensor = tflite::testing::CreateQuantized32Tensor(
|
||||
input_data, input_dims, "input_tensor", input_min, input_max);
|
||||
auto axis_tensor = tflite::testing::CreateInt32Tensor(
|
||||
{3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"axis_tensor");
|
||||
auto output_tensor = tflite::testing::CreateInt32Tensor(
|
||||
output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"output_tensor");
|
||||
tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
|
||||
{}); // Expects {1} if supported.
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(GetMaxArgMulDimensions) {
|
||||
using tflite::testing::F2Q;
|
||||
int32_t output_data[2];
|
||||
float input_min = 0;
|
||||
float input_max = 15.9375;
|
||||
TfLiteIntArray* input_dims =
|
||||
tflite::testing::IntArrayFromInitializer({4, 1, 1, 2, 4});
|
||||
auto input_data = {
|
||||
F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
|
||||
F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
|
||||
F2Q(1, input_min, input_max), F2Q(9, input_min, input_max),
|
||||
F2Q(7, input_min, input_max), F2Q(3, input_min, input_max)};
|
||||
auto input_tensor = tflite::testing::CreateQuantizedTensor(
|
||||
input_data, input_dims, "input_tensor", input_min, input_max);
|
||||
auto axis_tensor = tflite::testing::CreateInt32Tensor(
|
||||
{3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"axis_tensor");
|
||||
auto output_tensor = tflite::testing::CreateInt32Tensor(
|
||||
output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 2}),
|
||||
"output_tensor");
|
||||
tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
|
||||
{3, 1});
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(GetMaxArgNegativeAxis) {
|
||||
using tflite::testing::F2Q;
|
||||
int32_t output_data[4];
|
||||
float input_min = 0;
|
||||
float input_max = 15.9375;
|
||||
TfLiteIntArray* input_dims =
|
||||
tflite::testing::IntArrayFromInitializer({4, 1, 1, 2, 4});
|
||||
auto input_data = {
|
||||
F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
|
||||
F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
|
||||
F2Q(1, input_min, input_max), F2Q(9, input_min, input_max),
|
||||
F2Q(7, input_min, input_max), F2Q(3, input_min, input_max)};
|
||||
auto input_tensor = tflite::testing::CreateQuantizedTensor(
|
||||
input_data, input_dims, "input_tensor", input_min, input_max);
|
||||
auto axis_tensor = tflite::testing::CreateInt32Tensor(
|
||||
{-2}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"axis_tensor");
|
||||
auto output_tensor = tflite::testing::CreateInt32Tensor(
|
||||
output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 4}),
|
||||
"output_tensor");
|
||||
tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
|
||||
{0, 1, 0, 0});
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(GetMaxArgOutput64) {
|
||||
using tflite::testing::F2Q;
|
||||
int64_t output_data[2];
|
||||
float input_min = 0;
|
||||
float input_max = 15.9375;
|
||||
TfLiteIntArray* input_dims =
|
||||
tflite::testing::IntArrayFromInitializer({4, 1, 1, 2, 4});
|
||||
auto input_data = {
|
||||
F2Q(10, input_min, input_max), F2Q(2, input_min, input_max),
|
||||
F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
|
||||
F2Q(1, input_min, input_max), F2Q(9, input_min, input_max),
|
||||
F2Q(7, input_min, input_max), F2Q(3, input_min, input_max)};
|
||||
auto input_tensor = tflite::testing::CreateQuantizedTensor(
|
||||
input_data, input_dims, "input_tensor", input_min, input_max);
|
||||
auto axis_tensor = tflite::testing::CreateInt32Tensor(
|
||||
{3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"axis_tensor");
|
||||
auto output_tensor = tflite::testing::CreateInt64Tensor(
|
||||
output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 2}),
|
||||
"output_tensor");
|
||||
tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
|
||||
{}); // Expects {0, 1} if supported.
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(GetMaxArgAxis64) {
|
||||
using tflite::testing::F2Q;
|
||||
int32_t output_data[2];
|
||||
float input_min = 0;
|
||||
float input_max = 15.9375;
|
||||
TfLiteIntArray* input_dims =
|
||||
tflite::testing::IntArrayFromInitializer({4, 1, 1, 2, 4});
|
||||
auto input_data = {
|
||||
F2Q(10, input_min, input_max), F2Q(2, input_min, input_max),
|
||||
F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
|
||||
F2Q(1, input_min, input_max), F2Q(9, input_min, input_max),
|
||||
F2Q(7, input_min, input_max), F2Q(3, input_min, input_max)};
|
||||
auto input_tensor = tflite::testing::CreateQuantizedTensor(
|
||||
input_data, input_dims, "input_tensor", input_min, input_max);
|
||||
auto axis_tensor = tflite::testing::CreateInt64Tensor(
|
||||
{3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"axis_tensor");
|
||||
auto output_tensor = tflite::testing::CreateInt32Tensor(
|
||||
output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 2}),
|
||||
"output_tensor");
|
||||
tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
|
||||
{}); // Expects {0, 1} if supported.
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(GetMinArgFloat) {
|
||||
int32_t output_data[1];
|
||||
TfLiteIntArray* input_dims =
|
||||
tflite::testing::IntArrayFromInitializer({4, 1, 1, 1, 4});
|
||||
auto input_tensor = tflite::testing::CreateFloatTensor(
|
||||
{0.1, 0.9, 0.7, 0.3}, input_dims, "input_tensor");
|
||||
auto axis_tensor = tflite::testing::CreateInt32Tensor(
|
||||
{3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"axis_tensor");
|
||||
auto output_tensor = tflite::testing::CreateInt32Tensor(
|
||||
output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"output_tensor");
|
||||
tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
|
||||
{0}, true);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(GetMinArgUInt8) {
|
||||
using tflite::testing::F2Q;
|
||||
float input_min = 0;
|
||||
float input_max = 15.9375;
|
||||
int32_t output_data[1];
|
||||
TfLiteIntArray* input_dims =
|
||||
tflite::testing::IntArrayFromInitializer({4, 1, 1, 1, 4});
|
||||
// Getting weird error when defining input_data directly in
|
||||
// CreateQuantizedTensor. So I have to define it ahead.
|
||||
auto input_data = {
|
||||
F2Q(1.0, input_min, input_max), F2Q(9.0, input_min, input_max),
|
||||
F2Q(7.0, input_min, input_max), F2Q(3.0, input_min, input_max)};
|
||||
auto input_tensor = tflite::testing::CreateQuantizedTensor(
|
||||
input_data, input_dims, "input_tensor", input_min, input_max);
|
||||
auto axis_tensor = tflite::testing::CreateInt32Tensor(
|
||||
{3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"axis_tensor");
|
||||
auto output_tensor = tflite::testing::CreateInt32Tensor(
|
||||
output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"output_tensor");
|
||||
tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
|
||||
{0}, true);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(GetMinArgMulDimensions) {
|
||||
using tflite::testing::F2Q;
|
||||
float input_min = 0;
|
||||
float input_max = 15.9375;
|
||||
int32_t output_data[1];
|
||||
TfLiteIntArray* input_dims =
|
||||
tflite::testing::IntArrayFromInitializer({4, 1, 1, 2, 4});
|
||||
auto input_data = {
|
||||
F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
|
||||
F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
|
||||
F2Q(1, input_min, input_max), F2Q(9, input_min, input_max),
|
||||
F2Q(7, input_min, input_max), F2Q(3, input_min, input_max)};
|
||||
auto input_tensor = tflite::testing::CreateQuantizedTensor(
|
||||
input_data, input_dims, "input_tensor", input_min, input_max);
|
||||
auto axis_tensor = tflite::testing::CreateInt32Tensor(
|
||||
{3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"axis_tensor");
|
||||
auto output_tensor = tflite::testing::CreateInt32Tensor(
|
||||
output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 2}),
|
||||
"output_tensor");
|
||||
tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
|
||||
{0, 0}, true);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(GetMinArgOutput64) {
|
||||
using tflite::testing::F2Q;
|
||||
float input_min = 0;
|
||||
float input_max = 15.9375;
|
||||
int64_t output_data[1];
|
||||
TfLiteIntArray* input_dims =
|
||||
tflite::testing::IntArrayFromInitializer({4, 1, 1, 2, 4});
|
||||
auto input_data = {
|
||||
F2Q(10, input_min, input_max), F2Q(2, input_min, input_max),
|
||||
F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
|
||||
F2Q(1, input_min, input_max), F2Q(9, input_min, input_max),
|
||||
F2Q(7, input_min, input_max), F2Q(3, input_min, input_max)};
|
||||
auto input_tensor = tflite::testing::CreateQuantizedTensor(
|
||||
input_data, input_dims, "input_tensor", input_min, input_max);
|
||||
auto axis_tensor = tflite::testing::CreateInt32Tensor(
|
||||
{3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"axis_tensor");
|
||||
auto output_tensor = tflite::testing::CreateInt64Tensor(
|
||||
output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 2}),
|
||||
"output_tensor");
|
||||
tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
|
||||
{}, true); // Expects {1, 0} if supported.
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(GetMinArgAxis64) {
|
||||
using tflite::testing::F2Q;
|
||||
float input_min = 0;
|
||||
float input_max = 15.9375;
|
||||
int32_t output_data[1];
|
||||
TfLiteIntArray* input_dims =
|
||||
tflite::testing::IntArrayFromInitializer({4, 1, 1, 2, 4});
|
||||
auto input_data = {
|
||||
F2Q(10, input_min, input_max), F2Q(2, input_min, input_max),
|
||||
F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
|
||||
F2Q(1, input_min, input_max), F2Q(9, input_min, input_max),
|
||||
F2Q(7, input_min, input_max), F2Q(3, input_min, input_max)};
|
||||
auto input_tensor = tflite::testing::CreateQuantizedTensor(
|
||||
input_data, input_dims, "input_tensor", input_min, input_max);
|
||||
auto axis_tensor = tflite::testing::CreateInt64Tensor(
|
||||
{3}, tflite::testing::IntArrayFromInitializer({3, 1, 1, 1}),
|
||||
"axis_tensor");
|
||||
auto output_tensor = tflite::testing::CreateInt32Tensor(
|
||||
output_data, tflite::testing::IntArrayFromInitializer({3, 1, 1, 2}),
|
||||
"output_tensor");
|
||||
tflite::testing::TestArgMinMax(&input_tensor, &axis_tensor, &output_tensor,
|
||||
{}, true); // Expects {1, 0} if supported
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TESTS_END
|
37
tensorflow/lite/experimental/micro/kernels/micro_utils.h
Normal file
37
tensorflow/lite/experimental/micro/kernels/micro_utils.h
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_MICRO_UTILS_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_MICRO_UTILS_H_
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace micro {
|
||||
|
||||
// Same as gtl::Greater but defined here to reduce dependencies and
|
||||
// binary size for micro environment.
|
||||
struct Greater {
|
||||
template <typename T>
|
||||
bool operator()(const T& x, const T& y) const {
|
||||
return x > y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Less {
|
||||
template <typename T>
|
||||
bool operator()(const T& x, const T& y) const {
|
||||
return x < y;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_KERNELS_MICRO_UTILS_H_
|
@ -115,6 +115,7 @@ tensorflow/lite/kernels/internal/reference/fully_connected.h \
|
||||
tensorflow/lite/kernels/internal/reference/pooling.h \
|
||||
tensorflow/lite/kernels/internal/reference/prelu.h \
|
||||
tensorflow/lite/kernels/internal/reference/softmax.h \
|
||||
tensorflow/lite/kernels/internal/reference/arg_min_max.h \
|
||||
tensorflow/lite/kernels/internal/round.h \
|
||||
tensorflow/lite/kernels/internal/tensor_ctypes.h \
|
||||
tensorflow/lite/kernels/internal/types.h \
|
||||
|
@ -347,6 +347,7 @@ cc_library(
|
||||
name = "reference_base",
|
||||
srcs = [],
|
||||
hdrs = [
|
||||
"reference/arg_min_max.h",
|
||||
"reference/conv.h",
|
||||
"reference/depthwiseconv_float.h",
|
||||
"reference/depthwiseconv_uint8.h",
|
||||
@ -401,6 +402,7 @@ cc_library(
|
||||
name = "legacy_reference_base",
|
||||
srcs = [],
|
||||
hdrs = [
|
||||
"reference/arg_min_max.h",
|
||||
"reference/conv.h",
|
||||
"reference/depthwiseconv_float.h",
|
||||
"reference/depthwiseconv_uint8.h",
|
||||
|
68
tensorflow/lite/kernels/internal/reference/arg_min_max.h
Normal file
68
tensorflow/lite/kernels/internal/reference/arg_min_max.h
Normal file
@ -0,0 +1,68 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace reference_ops {
|
||||
|
||||
template <typename T1, typename T2, typename T3, typename Cmp>
|
||||
void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
|
||||
const T3* input2_data, const RuntimeShape& output_shape,
|
||||
T2* output_data, const Cmp& cmp) {
|
||||
TFLITE_DCHECK_GT(input1_shape.DimensionsCount(), 0);
|
||||
TFLITE_DCHECK_EQ(input1_shape.DimensionsCount() - 1,
|
||||
output_shape.DimensionsCount());
|
||||
int axis = input2_data[0];
|
||||
if (axis < 0) {
|
||||
axis += input1_shape.DimensionsCount();
|
||||
}
|
||||
const int axis_size = input1_shape.Dims(axis);
|
||||
|
||||
int outer_size = 1;
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i));
|
||||
outer_size *= input1_shape.Dims(i);
|
||||
}
|
||||
|
||||
int inner_size = 1;
|
||||
const int dims_count = input1_shape.DimensionsCount();
|
||||
for (int i = axis + 1; i < dims_count; ++i) {
|
||||
TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i - 1));
|
||||
inner_size *= input1_shape.Dims(i);
|
||||
}
|
||||
for (int outer = 0; outer < outer_size; ++outer) {
|
||||
for (int inner = 0; inner < inner_size; ++inner) {
|
||||
auto min_max_value = input1_data[outer * axis_size * inner_size + inner];
|
||||
T2 min_max_index = 0;
|
||||
for (int i = 1; i < axis_size; ++i) {
|
||||
const auto& curr_value =
|
||||
input1_data[(outer * axis_size + i) * inner_size + inner];
|
||||
if (cmp(curr_value, min_max_value)) {
|
||||
min_max_value = curr_value;
|
||||
min_max_index = static_cast<T2>(i);
|
||||
}
|
||||
}
|
||||
output_data[outer * inner_size + inner] = min_max_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace reference_ops
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/conv.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/floor.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
|
||||
@ -3572,52 +3573,6 @@ void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1, typename T2, typename T3, typename Cmp>
|
||||
void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
|
||||
const T3* input2_data, const RuntimeShape& output_shape,
|
||||
T2* output_data, const Cmp& cmp) {
|
||||
gemmlowp::ScopedProfilingLabel label("ArgMinMax");
|
||||
TFLITE_DCHECK_GT(input1_shape.DimensionsCount(), 0);
|
||||
TFLITE_DCHECK_EQ(input1_shape.DimensionsCount() - 1,
|
||||
output_shape.DimensionsCount());
|
||||
|
||||
int axis = input2_data[0];
|
||||
if (axis < 0) {
|
||||
axis += input1_shape.DimensionsCount();
|
||||
}
|
||||
|
||||
const int axis_size = input1_shape.Dims(axis);
|
||||
|
||||
int outer_size = 1;
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i));
|
||||
outer_size *= input1_shape.Dims(i);
|
||||
}
|
||||
|
||||
int inner_size = 1;
|
||||
const int dims_count = input1_shape.DimensionsCount();
|
||||
for (int i = axis + 1; i < dims_count; ++i) {
|
||||
TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i - 1));
|
||||
inner_size *= input1_shape.Dims(i);
|
||||
}
|
||||
|
||||
for (int outer = 0; outer < outer_size; ++outer) {
|
||||
for (int inner = 0; inner < inner_size; ++inner) {
|
||||
auto min_max_value = input1_data[outer * axis_size * inner_size + inner];
|
||||
int min_max_index = 0;
|
||||
for (int i = 1; i < axis_size; ++i) {
|
||||
const auto& curr_value =
|
||||
input1_data[(outer * axis_size + i) * inner_size + inner];
|
||||
if (cmp(curr_value, min_max_value)) {
|
||||
min_max_value = curr_value;
|
||||
min_max_index = i;
|
||||
}
|
||||
}
|
||||
output_data[outer * inner_size + inner] = min_max_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
|
||||
const T3* input2_data, const RuntimeShape& output_shape,
|
||||
|
Loading…
x
Reference in New Issue
Block a user