diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 7b9ec5dd8bb..35813aad620 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -265,6 +265,7 @@ tflite_micro_cc_test( "logical_test.cc", ], deps = [ + ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro/testing:micro_test", diff --git a/tensorflow/lite/micro/kernels/logical.cc b/tensorflow/lite/micro/kernels/logical.cc index cbb818193ac..f4033ba8856 100644 --- a/tensorflow/lite/micro/kernels/logical.cc +++ b/tensorflow/lite/micro/kernels/logical.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/reference/binary_function.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { namespace ops { @@ -31,20 +31,29 @@ constexpr int kOutputTensor = 0; TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node, bool (*func)(bool, bool)) { - const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); - const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteEvalTensor* input1 = + tflite::micro::GetEvalInput(context, node, kInputTensor1); + const TfLiteEvalTensor* input2 = + tflite::micro::GetEvalInput(context, node, kInputTensor2); + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kOutputTensor); - if (HaveSameShapes(input1, input2)) { + if (tflite::micro::HaveSameShapes(input1, input2)) { reference_ops::BinaryFunction( - GetTensorShape(input1), GetTensorData(input1), - GetTensorShape(input2), GetTensorData(input2), - GetTensorShape(output), GetTensorData(output), func); + tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorData(input1), + tflite::micro::GetTensorShape(input2), + tflite::micro::GetTensorData(input2), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output), func); } else { reference_ops::BroadcastBinaryFunction4DSlow( - GetTensorShape(input1), GetTensorData(input1), - GetTensorShape(input2), GetTensorData(input2), - GetTensorShape(output), GetTensorData(output), func); + tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorData(input1), + tflite::micro::GetTensorShape(input2), + tflite::micro::GetTensorData(input2), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output), func); } return kTfLiteOk; diff --git a/tensorflow/lite/micro/kernels/logical_test.cc b/tensorflow/lite/micro/kernels/logical_test.cc index 89a7a0ae74a..d5355c830b6 100644 --- a/tensorflow/lite/micro/kernels/logical_test.cc +++ b/tensorflow/lite/micro/kernels/logical_test.cc @@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" +#include "tensorflow/lite/micro/kernels/kernel_runner.h" #include "tensorflow/lite/micro/testing/micro_test.h" #include "tensorflow/lite/micro/testing/test_utils.h" @@ -22,9 +24,10 @@ namespace tflite { namespace testing { namespace { -void TestLogicalOp(tflite::BuiltinOperator op, const int* input1_dims_data, - const bool* input1_data, const int* input2_dims_data, - const bool* input2_data, const int* output_dims_data, +void TestLogicalOp(const TfLiteRegistration& registration, + const int* input1_dims_data, const bool* input1_data, + const int* input2_dims_data, const bool* input2_data, + const int* output_dims_data, const bool* expected_output_data, bool* output_data) { TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data); TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data); @@ -40,32 +43,17 @@ void TestLogicalOp(tflite::BuiltinOperator op, const int* input1_dims_data, CreateBoolTensor(output_data, output_dims), }; - TfLiteContext context; - PopulateContext(tensors, tensors_size, micro_test::reporter, &context); - - ::tflite::AllOpsResolver resolver; - const TfLiteRegistration* registration = resolver.FindOp(op); - TF_LITE_MICRO_EXPECT_NE(nullptr, registration); - int inputs_array_data[] = {2, 0, 1}; TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); int outputs_array_data[] = {1, 2}; TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); - TfLiteNode node; - node.inputs = inputs_array; - node.outputs = outputs_array; - node.user_data = nullptr; - node.builtin_data = nullptr; - node.custom_initial_data = nullptr; - node.custom_initial_data_size = 0; + micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, + outputs_array, + /*builtin_data=*/nullptr, micro_test::reporter); - if (registration->prepare) { - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); - } - - TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); TF_LITE_MICRO_EXPECT_EQ(output_dims_count, 4); for (int i = 0; i < output_dims_count; ++i) { @@ -85,8 +73,8 @@ TF_LITE_MICRO_TEST(LogicalOr) { const bool input2[] = {true, false, true, false}; const bool golden[] = {true, false, true, true}; bool output_data[4]; - tflite::testing::TestLogicalOp(tflite::BuiltinOperator_LOGICAL_OR, shape, - input1, shape, input2, shape, golden, + tflite::testing::TestLogicalOp(tflite::ops::micro::Register_LOGICAL_OR(), + shape, input1, shape, input2, shape, golden, output_data); } @@ -97,7 +85,7 @@ TF_LITE_MICRO_TEST(BroadcastLogicalOr) { const bool input2[] = {false}; const bool golden[] = {true, false, false, true}; bool output_data[4]; - tflite::testing::TestLogicalOp(tflite::BuiltinOperator_LOGICAL_OR, + tflite::testing::TestLogicalOp(tflite::ops::micro::Register_LOGICAL_OR(), input1_shape, input1, input2_shape, input2, input1_shape, golden, output_data); } @@ -108,8 +96,8 @@ TF_LITE_MICRO_TEST(LogicalAnd) { const bool input2[] = {true, false, true, false}; const bool golden[] = {true, false, false, false}; bool output_data[4]; - tflite::testing::TestLogicalOp(tflite::BuiltinOperator_LOGICAL_AND, shape, - input1, shape, input2, shape, golden, + tflite::testing::TestLogicalOp(tflite::ops::micro::Register_LOGICAL_AND(), + shape, input1, shape, input2, shape, golden, output_data); } @@ -120,7 +108,7 @@ TF_LITE_MICRO_TEST(BroadcastLogicalAnd) { const bool input2[] = {true}; const bool golden[] = {true, false, false, true}; bool output_data[4]; - tflite::testing::TestLogicalOp(tflite::BuiltinOperator_LOGICAL_AND, + tflite::testing::TestLogicalOp(tflite::ops::micro::Register_LOGICAL_AND(), input1_shape, input1, input2_shape, input2, input1_shape, golden, output_data); }