Add LOGICAL_AND, LOGICAL_OR operator to TFL micro.
PiperOrigin-RevId: 261856585
This commit is contained in:
parent
cdfd0503f9
commit
e08fa9ed2b
@ -20,6 +20,7 @@ cc_library(
|
||||
"elementwise.cc",
|
||||
"floor.cc",
|
||||
"fully_connected.cc",
|
||||
"logical.cc",
|
||||
"maximum_minimum.cc",
|
||||
"pooling.cc",
|
||||
"prelu.cc",
|
||||
@ -64,6 +65,7 @@ cc_library(
|
||||
"elementwise.cc",
|
||||
"floor.cc",
|
||||
"fully_connected.cc",
|
||||
"logical.cc",
|
||||
"maximum_minimum.cc",
|
||||
"pooling.cc",
|
||||
"portable_optimized/depthwise_conv.cc",
|
||||
@ -216,6 +218,19 @@ tflite_micro_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tflite_micro_cc_test(
|
||||
name = "logical_test",
|
||||
srcs = [
|
||||
"logical_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":all_ops_resolver",
|
||||
"//tensorflow/lite/c:c_api_internal",
|
||||
"//tensorflow/lite/experimental/micro:micro_framework",
|
||||
"//tensorflow/lite/experimental/micro/testing:micro_test",
|
||||
],
|
||||
)
|
||||
|
||||
tflite_micro_cc_test(
|
||||
name = "maximum_minimum_test",
|
||||
srcs = [
|
||||
|
@ -29,6 +29,8 @@ TfLiteRegistration* Register_MAXIMUM();
|
||||
TfLiteRegistration* Register_MINIMUM();
|
||||
TfLiteRegistration* Register_ARG_MAX();
|
||||
TfLiteRegistration* Register_ARG_MIN();
|
||||
TfLiteRegistration* Register_LOGICAL_OR();
|
||||
TfLiteRegistration* Register_LOGICAL_AND();
|
||||
AllOpsResolver::AllOpsResolver() {
|
||||
AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D());
|
||||
AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(),
|
||||
@ -45,6 +47,8 @@ AllOpsResolver::AllOpsResolver() {
|
||||
AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
|
||||
AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
|
||||
AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN());
|
||||
AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR());
|
||||
AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
|
87
tensorflow/lite/experimental/micro/kernels/logical.cc
Normal file
87
tensorflow/lite/experimental/micro/kernels/logical.cc
Normal file
@ -0,0 +1,87 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace micro {
|
||||
namespace logical {
|
||||
namespace {
|
||||
|
||||
// Input/output tensor index.
|
||||
constexpr int kInputTensor1 = 0;
|
||||
constexpr int kInputTensor2 = 1;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
|
||||
bool (*func)(bool, bool)) {
|
||||
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
|
||||
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
if (HaveSameShapes(input1, input2)) {
|
||||
reference_ops::BinaryFunction<bool, bool, bool>(
|
||||
GetTensorShape(input1), GetTensorData<bool>(input1),
|
||||
GetTensorShape(input2), GetTensorData<bool>(input2),
|
||||
GetTensorShape(output), GetTensorData<bool>(output), func);
|
||||
} else {
|
||||
reference_ops::BroadcastBinaryFunction4DSlow<bool, bool, bool>(
|
||||
GetTensorShape(input1), GetTensorData<bool>(input1),
|
||||
GetTensorShape(input2), GetTensorData<bool>(input2),
|
||||
GetTensorShape(output), GetTensorData<bool>(output), func);
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
bool LogicalOr(bool x, bool y) { return x || y; }
|
||||
|
||||
TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return LogicalImpl(context, node, LogicalOr);
|
||||
}
|
||||
|
||||
bool LogicalAnd(bool x, bool y) { return x && y; }
|
||||
|
||||
TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return LogicalImpl(context, node, LogicalAnd);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace logical
|
||||
|
||||
TfLiteRegistration* Register_LOGICAL_OR() {
|
||||
// Init, Free, Prepare, Eval are satisfying the Interface required by
|
||||
// TfLiteRegistration.
|
||||
static TfLiteRegistration r = {/* init */ nullptr, /* free */ nullptr,
|
||||
/* prepare */ nullptr, logical::LogicalOrEval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_LOGICAL_AND() {
|
||||
// Init, Free, Prepare, Eval are satisfying the Interface required by
|
||||
// TfLiteRegistration.
|
||||
static TfLiteRegistration r = {/* init */ nullptr, /* free */ nullptr,
|
||||
/* prepare */ nullptr,
|
||||
logical::LogicalAndEval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace micro
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
147
tensorflow/lite/experimental/micro/kernels/logical_test.cc
Normal file
147
tensorflow/lite/experimental/micro/kernels/logical_test.cc
Normal file
@ -0,0 +1,147 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
|
||||
#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
|
||||
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
||||
#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
namespace {
|
||||
|
||||
inline TfLiteTensor CreateBoolTensor(const bool* data, TfLiteIntArray* dims,
|
||||
const char* name) {
|
||||
TfLiteTensor result;
|
||||
result.type = kTfLiteBool;
|
||||
result.data.b = const_cast<bool*>(data);
|
||||
result.dims = dims;
|
||||
result.params = {};
|
||||
result.allocation_type = kTfLiteMemNone;
|
||||
result.bytes = ElementCount(*dims) * sizeof(bool);
|
||||
result.allocation = nullptr;
|
||||
result.name = name;
|
||||
return result;
|
||||
}
|
||||
|
||||
inline TfLiteTensor CreateBoolTensor(std::initializer_list<bool> data,
|
||||
TfLiteIntArray* dims, const char* name) {
|
||||
return CreateBoolTensor(data.begin(), dims, name);
|
||||
}
|
||||
|
||||
void TestLogicalOp(tflite::BuiltinOperator op,
|
||||
std::initializer_list<int> input1_dims_data,
|
||||
std::initializer_list<bool> input1_data,
|
||||
std::initializer_list<int> input2_dims_data,
|
||||
std::initializer_list<bool> input2_data,
|
||||
std::initializer_list<int> output_dims_data,
|
||||
std::initializer_list<bool> expected_output_data,
|
||||
bool* output_data) {
|
||||
TfLiteIntArray* input1_dims = IntArrayFromInitializer(input1_dims_data);
|
||||
TfLiteIntArray* input2_dims = IntArrayFromInitializer(input2_dims_data);
|
||||
TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
|
||||
const int output_dims_count = ElementCount(*output_dims);
|
||||
|
||||
constexpr int inputs_size = 2;
|
||||
constexpr int outputs_size = 1;
|
||||
constexpr int tensors_size = inputs_size + outputs_size;
|
||||
TfLiteTensor tensors[tensors_size] = {
|
||||
CreateBoolTensor(input1_data, input1_dims, "input1_tensor"),
|
||||
CreateBoolTensor(input2_data, input2_dims, "input2_tensor"),
|
||||
CreateBoolTensor(output_data, output_dims, "output_tensor"),
|
||||
};
|
||||
|
||||
TfLiteContext context;
|
||||
PopulateContext(tensors, tensors_size, &context);
|
||||
|
||||
::tflite::ops::micro::AllOpsResolver resolver;
|
||||
const TfLiteRegistration* registration = resolver.FindOp(op, 1);
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
|
||||
|
||||
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
|
||||
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
|
||||
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
|
||||
|
||||
TfLiteNode node;
|
||||
node.inputs = inputs_array;
|
||||
node.outputs = outputs_array;
|
||||
node.temporaries = temporaries_array;
|
||||
node.user_data = nullptr;
|
||||
node.builtin_data = nullptr;
|
||||
node.custom_initial_data = nullptr;
|
||||
node.custom_initial_data_size = 0;
|
||||
node.delegate = nullptr;
|
||||
|
||||
if (registration->prepare) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(output_dims_count, 4);
|
||||
for (int i = 0; i < output_dims_count; ++i) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
||||
TF_LITE_MICRO_TESTS_BEGIN
|
||||
|
||||
TF_LITE_MICRO_TEST(LogicalOr) {
|
||||
bool output_data[4];
|
||||
tflite::testing::TestLogicalOp(
|
||||
tflite::BuiltinOperator_LOGICAL_OR, // operator
|
||||
{4, 1, 1, 1, 4}, {true, false, false, true}, // input1
|
||||
{4, 1, 1, 1, 4}, {true, false, true, false}, // input2
|
||||
{4, 1, 1, 1, 4}, {true, false, true, true}, // expected output
|
||||
output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(BroadcastLogicalOr) {
|
||||
bool output_data[4];
|
||||
tflite::testing::TestLogicalOp(
|
||||
tflite::BuiltinOperator_LOGICAL_OR, // operator
|
||||
{4, 1, 1, 1, 4}, {true, false, false, true}, // input1
|
||||
{4, 1, 1, 1, 1}, {false}, // input2
|
||||
{4, 1, 1, 1, 4}, {true, false, false, true}, // expected output
|
||||
output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(LogicalAnd) {
|
||||
bool output_data[4];
|
||||
tflite::testing::TestLogicalOp(
|
||||
tflite::BuiltinOperator_LOGICAL_AND, // operator
|
||||
{4, 1, 1, 1, 4}, {true, false, false, true}, // input1
|
||||
{4, 1, 1, 1, 4}, {true, false, true, false}, // input2
|
||||
{4, 1, 1, 1, 4}, {true, false, false, false}, // expected output
|
||||
output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(BroadcastLogicalAnd) {
|
||||
bool output_data[4];
|
||||
tflite::testing::TestLogicalOp(
|
||||
tflite::BuiltinOperator_LOGICAL_AND, // operator
|
||||
{4, 1, 1, 1, 4}, {true, false, false, true}, // input1
|
||||
{4, 1, 1, 1, 1}, {true}, // input2
|
||||
{4, 1, 1, 1, 4}, {true, false, false, true}, // expected output
|
||||
output_data);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TESTS_END
|
@ -107,6 +107,7 @@ tensorflow/lite/kernels/padding.h \
|
||||
tensorflow/lite/kernels/internal/common.h \
|
||||
tensorflow/lite/kernels/internal/compatibility.h \
|
||||
tensorflow/lite/kernels/internal/optimized/neon_check.h \
|
||||
tensorflow/lite/kernels/internal/reference/binary_function.h \
|
||||
tensorflow/lite/kernels/internal/reference/conv.h \
|
||||
tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h \
|
||||
tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h \
|
||||
|
@ -17,7 +17,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
|
Loading…
Reference in New Issue
Block a user