Port the logical kernel to the TfLiteEvalTensor API.

PiperOrigin-RevId: 323042504
Change-Id: I0b7814b588c9ca8b26dbc2eba28035a4226475e3
This commit is contained in:
Nick Kreeger 2020-07-24 12:01:24 -07:00 committed by TensorFlower Gardener
parent 3098c7a84d
commit 23cb51dfb6
3 changed files with 38 additions and 40 deletions

View File

@ -265,6 +265,7 @@ tflite_micro_cc_test(
"logical_test.cc",
],
deps = [
":kernel_runner",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro/testing:micro_test",

View File

@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace ops {
@ -31,20 +31,29 @@ constexpr int kOutputTensor = 0;
TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
bool (*func)(bool, bool)) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteEvalTensor* input1 =
tflite::micro::GetEvalInput(context, node, kInputTensor1);
const TfLiteEvalTensor* input2 =
tflite::micro::GetEvalInput(context, node, kInputTensor2);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
if (HaveSameShapes(input1, input2)) {
if (tflite::micro::HaveSameShapes(input1, input2)) {
reference_ops::BinaryFunction<bool, bool, bool>(
GetTensorShape(input1), GetTensorData<bool>(input1),
GetTensorShape(input2), GetTensorData<bool>(input2),
GetTensorShape(output), GetTensorData<bool>(output), func);
tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<bool>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<bool>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<bool>(output), func);
} else {
reference_ops::BroadcastBinaryFunction4DSlow<bool, bool, bool>(
GetTensorShape(input1), GetTensorData<bool>(input1),
GetTensorShape(input2), GetTensorData<bool>(input2),
GetTensorShape(output), GetTensorData<bool>(output), func);
tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<bool>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<bool>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<bool>(output), func);
}
return kTfLiteOk;

View File

@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
#include "tensorflow/lite/micro/testing/test_utils.h"
@ -22,9 +24,10 @@ namespace tflite {
namespace testing {
namespace {
void TestLogicalOp(tflite::BuiltinOperator op, const int* input1_dims_data,
const bool* input1_data, const int* input2_dims_data,
const bool* input2_data, const int* output_dims_data,
void TestLogicalOp(const TfLiteRegistration& registration,
const int* input1_dims_data, const bool* input1_data,
const int* input2_dims_data, const bool* input2_data,
const int* output_dims_data,
const bool* expected_output_data, bool* output_data) {
TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data);
TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data);
@ -40,32 +43,17 @@ void TestLogicalOp(tflite::BuiltinOperator op, const int* input1_dims_data,
CreateBoolTensor(output_data, output_dims),
};
TfLiteContext context;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
::tflite::AllOpsResolver resolver;
const TfLiteRegistration* registration = resolver.FindOp(op);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
outputs_array,
/*builtin_data=*/nullptr, micro_test::reporter);
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
TF_LITE_MICRO_EXPECT_EQ(output_dims_count, 4);
for (int i = 0; i < output_dims_count; ++i) {
@ -85,8 +73,8 @@ TF_LITE_MICRO_TEST(LogicalOr) {
const bool input2[] = {true, false, true, false};
const bool golden[] = {true, false, true, true};
bool output_data[4];
tflite::testing::TestLogicalOp(tflite::BuiltinOperator_LOGICAL_OR, shape,
input1, shape, input2, shape, golden,
tflite::testing::TestLogicalOp(tflite::ops::micro::Register_LOGICAL_OR(),
shape, input1, shape, input2, shape, golden,
output_data);
}
@ -97,7 +85,7 @@ TF_LITE_MICRO_TEST(BroadcastLogicalOr) {
const bool input2[] = {false};
const bool golden[] = {true, false, false, true};
bool output_data[4];
tflite::testing::TestLogicalOp(tflite::BuiltinOperator_LOGICAL_OR,
tflite::testing::TestLogicalOp(tflite::ops::micro::Register_LOGICAL_OR(),
input1_shape, input1, input2_shape, input2,
input1_shape, golden, output_data);
}
@ -108,8 +96,8 @@ TF_LITE_MICRO_TEST(LogicalAnd) {
const bool input2[] = {true, false, true, false};
const bool golden[] = {true, false, false, false};
bool output_data[4];
tflite::testing::TestLogicalOp(tflite::BuiltinOperator_LOGICAL_AND, shape,
input1, shape, input2, shape, golden,
tflite::testing::TestLogicalOp(tflite::ops::micro::Register_LOGICAL_AND(),
shape, input1, shape, input2, shape, golden,
output_data);
}
@ -120,7 +108,7 @@ TF_LITE_MICRO_TEST(BroadcastLogicalAnd) {
const bool input2[] = {true};
const bool golden[] = {true, false, false, true};
bool output_data[4];
tflite::testing::TestLogicalOp(tflite::BuiltinOperator_LOGICAL_AND,
tflite::testing::TestLogicalOp(tflite::ops::micro::Register_LOGICAL_AND(),
input1_shape, input1, input2_shape, input2,
input1_shape, golden, output_data);
}