Port the logical kernel to the TfLiteEvalTensor API.
PiperOrigin-RevId: 323042504 Change-Id: I0b7814b588c9ca8b26dbc2eba28035a4226475e3
This commit is contained in:
parent
3098c7a84d
commit
23cb51dfb6
@ -265,6 +265,7 @@ tflite_micro_cc_test(
|
||||
"logical_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":kernel_runner",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/micro:op_resolvers",
|
||||
"//tensorflow/lite/micro/testing:micro_test",
|
||||
|
@ -15,8 +15,8 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
@ -31,20 +31,29 @@ constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
|
||||
bool (*func)(bool, bool)) {
|
||||
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
|
||||
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
const TfLiteEvalTensor* input1 =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensor1);
|
||||
const TfLiteEvalTensor* input2 =
|
||||
tflite::micro::GetEvalInput(context, node, kInputTensor2);
|
||||
TfLiteEvalTensor* output =
|
||||
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
|
||||
|
||||
if (HaveSameShapes(input1, input2)) {
|
||||
if (tflite::micro::HaveSameShapes(input1, input2)) {
|
||||
reference_ops::BinaryFunction<bool, bool, bool>(
|
||||
GetTensorShape(input1), GetTensorData<bool>(input1),
|
||||
GetTensorShape(input2), GetTensorData<bool>(input2),
|
||||
GetTensorShape(output), GetTensorData<bool>(output), func);
|
||||
tflite::micro::GetTensorShape(input1),
|
||||
tflite::micro::GetTensorData<bool>(input1),
|
||||
tflite::micro::GetTensorShape(input2),
|
||||
tflite::micro::GetTensorData<bool>(input2),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<bool>(output), func);
|
||||
} else {
|
||||
reference_ops::BroadcastBinaryFunction4DSlow<bool, bool, bool>(
|
||||
GetTensorShape(input1), GetTensorData<bool>(input1),
|
||||
GetTensorShape(input2), GetTensorData<bool>(input2),
|
||||
GetTensorShape(output), GetTensorData<bool>(output), func);
|
||||
tflite::micro::GetTensorShape(input1),
|
||||
tflite::micro::GetTensorData<bool>(input1),
|
||||
tflite::micro::GetTensorShape(input2),
|
||||
tflite::micro::GetTensorData<bool>(input2),
|
||||
tflite::micro::GetTensorShape(output),
|
||||
tflite::micro::GetTensorData<bool>(output), func);
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
|
@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/micro/all_ops_resolver.h"
|
||||
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
|
||||
#include "tensorflow/lite/micro/testing/micro_test.h"
|
||||
#include "tensorflow/lite/micro/testing/test_utils.h"
|
||||
|
||||
@ -22,9 +24,10 @@ namespace tflite {
|
||||
namespace testing {
|
||||
namespace {
|
||||
|
||||
void TestLogicalOp(tflite::BuiltinOperator op, const int* input1_dims_data,
|
||||
const bool* input1_data, const int* input2_dims_data,
|
||||
const bool* input2_data, const int* output_dims_data,
|
||||
void TestLogicalOp(const TfLiteRegistration& registration,
|
||||
const int* input1_dims_data, const bool* input1_data,
|
||||
const int* input2_dims_data, const bool* input2_data,
|
||||
const int* output_dims_data,
|
||||
const bool* expected_output_data, bool* output_data) {
|
||||
TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data);
|
||||
TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data);
|
||||
@ -40,32 +43,17 @@ void TestLogicalOp(tflite::BuiltinOperator op, const int* input1_dims_data,
|
||||
CreateBoolTensor(output_data, output_dims),
|
||||
};
|
||||
|
||||
TfLiteContext context;
|
||||
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
|
||||
|
||||
::tflite::AllOpsResolver resolver;
|
||||
const TfLiteRegistration* registration = resolver.FindOp(op);
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
|
||||
|
||||
int inputs_array_data[] = {2, 0, 1};
|
||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||
int outputs_array_data[] = {1, 2};
|
||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||
|
||||
TfLiteNode node;
|
||||
node.inputs = inputs_array;
|
||||
node.outputs = outputs_array;
|
||||
node.user_data = nullptr;
|
||||
node.builtin_data = nullptr;
|
||||
node.custom_initial_data = nullptr;
|
||||
node.custom_initial_data_size = 0;
|
||||
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
|
||||
outputs_array,
|
||||
/*builtin_data=*/nullptr, micro_test::reporter);
|
||||
|
||||
if (registration->prepare) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
|
||||
|
||||
TF_LITE_MICRO_EXPECT_EQ(output_dims_count, 4);
|
||||
for (int i = 0; i < output_dims_count; ++i) {
|
||||
@ -85,8 +73,8 @@ TF_LITE_MICRO_TEST(LogicalOr) {
|
||||
const bool input2[] = {true, false, true, false};
|
||||
const bool golden[] = {true, false, true, true};
|
||||
bool output_data[4];
|
||||
tflite::testing::TestLogicalOp(tflite::BuiltinOperator_LOGICAL_OR, shape,
|
||||
input1, shape, input2, shape, golden,
|
||||
tflite::testing::TestLogicalOp(tflite::ops::micro::Register_LOGICAL_OR(),
|
||||
shape, input1, shape, input2, shape, golden,
|
||||
output_data);
|
||||
}
|
||||
|
||||
@ -97,7 +85,7 @@ TF_LITE_MICRO_TEST(BroadcastLogicalOr) {
|
||||
const bool input2[] = {false};
|
||||
const bool golden[] = {true, false, false, true};
|
||||
bool output_data[4];
|
||||
tflite::testing::TestLogicalOp(tflite::BuiltinOperator_LOGICAL_OR,
|
||||
tflite::testing::TestLogicalOp(tflite::ops::micro::Register_LOGICAL_OR(),
|
||||
input1_shape, input1, input2_shape, input2,
|
||||
input1_shape, golden, output_data);
|
||||
}
|
||||
@ -108,8 +96,8 @@ TF_LITE_MICRO_TEST(LogicalAnd) {
|
||||
const bool input2[] = {true, false, true, false};
|
||||
const bool golden[] = {true, false, false, false};
|
||||
bool output_data[4];
|
||||
tflite::testing::TestLogicalOp(tflite::BuiltinOperator_LOGICAL_AND, shape,
|
||||
input1, shape, input2, shape, golden,
|
||||
tflite::testing::TestLogicalOp(tflite::ops::micro::Register_LOGICAL_AND(),
|
||||
shape, input1, shape, input2, shape, golden,
|
||||
output_data);
|
||||
}
|
||||
|
||||
@ -120,7 +108,7 @@ TF_LITE_MICRO_TEST(BroadcastLogicalAnd) {
|
||||
const bool input2[] = {true};
|
||||
const bool golden[] = {true, false, false, true};
|
||||
bool output_data[4];
|
||||
tflite::testing::TestLogicalOp(tflite::BuiltinOperator_LOGICAL_AND,
|
||||
tflite::testing::TestLogicalOp(tflite::ops::micro::Register_LOGICAL_AND(),
|
||||
input1_shape, input1, input2_shape, input2,
|
||||
input1_shape, golden, output_data);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user