From c7fd77e7d0586b77593f87a664ce1e8b756619b9 Mon Sep 17 00:00:00 2001 From: Juho Ha Date: Mon, 5 Aug 2019 19:44:19 -0700 Subject: [PATCH] Move binary functions to a separate header. PiperOrigin-RevId: 261823955 --- tensorflow/lite/kernels/internal/BUILD | 2 + .../internal/reference/binary_function.h | 85 ++++++++++++++++ .../internal/reference/legacy_reference_ops.h | 19 ---- .../internal/reference/reference_ops.h | 99 +------------------ tensorflow/lite/kernels/logical.cc | 22 +++-- 5 files changed, 100 insertions(+), 127 deletions(-) create mode 100644 tensorflow/lite/kernels/internal/reference/binary_function.h diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 0d96bc01258..e11deb11711 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -348,6 +348,7 @@ cc_library( srcs = [], hdrs = [ "reference/arg_min_max.h", + "reference/binary_function.h", "reference/comparisons.h", "reference/conv.h", "reference/depthwiseconv_float.h", @@ -405,6 +406,7 @@ cc_library( srcs = [], hdrs = [ "reference/arg_min_max.h", + "reference/binary_function.h", "reference/comparisons.h", "reference/conv.h", "reference/depthwiseconv_float.h", diff --git a/tensorflow/lite/kernels/internal/reference/binary_function.h b/tensorflow/lite/kernels/internal/reference/binary_function.h new file mode 100644 index 00000000000..874bf9e9eb9 --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/binary_function.h @@ -0,0 +1,85 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BINARY_FUNCTION_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BINARY_FUNCTION_H_
+
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+namespace reference_ops {
+
+// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more
+// generalized and efficient BroadcastBinaryFunction.
+//
+// Also appears to duplicate MinimumMaximum.
+//
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+template <typename R, typename T1, typename T2>
+inline void BroadcastBinaryFunction4DSlow(
+    const RuntimeShape& unextended_input1_shape, const T1* input1_data,
+    const RuntimeShape& unextended_input2_shape, const T2* input2_data,
+    const RuntimeShape& unextended_output_shape, R* output_data,
+    R (*func)(T1, T2)) {
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+                                      unextended_input2_shape, &desc1, &desc2);
+
+  for (int b = 0; b < output_shape.Dims(0); ++b) {
+    for (int y = 0; y < output_shape.Dims(1); ++y) {
+      for (int x = 0; x < output_shape.Dims(2); ++x) {
+        for (int c = 0; c < output_shape.Dims(3); ++c) {
+          auto out_idx = Offset(output_shape, b, y, x, c);
+          auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
+          auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
+          auto in1_val = input1_data[in1_idx];
+          auto in2_val = input2_data[in2_idx];
+          output_data[out_idx] = func(in1_val, in2_val);
+        }
+      }
+    }
+  }
+}
+
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+// TODO(renjieliu): Refactor other binary functions to use this one.
+template <typename R, typename T1, typename T2>
+inline void BinaryFunction(const RuntimeShape& input1_shape,
+                           const T1* input1_data,
+                           const RuntimeShape& input2_shape,
+                           const T2* input2_data,
+                           const RuntimeShape& output_shape, R* output_data,
+                           R (*func)(T1, T2)) {
+  const int flat_size =
+      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+  for (int i = 0; i < flat_size; ++i) {
+    output_data[i] = func(input1_data[i], input2_data[i]);
+  }
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BINARY_FUNCTION_H_
diff --git a/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h
index 082f86e5c9e..615abdfcfaf 100644
--- a/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -2192,25 +2192,6 @@ inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
                          DimsToShape(output_dims), output_data);
 }
 
-inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
-                    const bool* input2_data, const Dims<4>& input2_dims,
-                    bool* output_data, const Dims<4>& output_dims,
-                    const std::function<bool(bool, bool)>& func) {
-  Logical(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
-          input2_data, DimsToShape(output_dims), output_data, func);
-}
-
-inline void BroadcastLogical(const bool* input1_data,
-                             const Dims<4>& input1_dims,
-                             const bool* input2_data,
-                             const Dims<4>& input2_dims, bool* output_data,
-                             const Dims<4>& output_dims,
-                             const std::function<bool(bool, bool)>& func) {
-  BroadcastLogical4DSlow(DimsToShape(input1_dims), input1_data,
-                         DimsToShape(input2_dims), input2_data,
-                         DimsToShape(output_dims), output_data, func);
-}
-
 // R: Result type. T1: Input 1 type. T2: Input 2 type.
template inline void BroadcastBinaryFunction(const T1* input1_data, diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index 457f8946e66..225fe3cb778 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/arg_min_max.h" +#include "tensorflow/lite/kernels/internal/reference/binary_function.h" #include "tensorflow/lite/kernels/internal/reference/comparisons.h" #include "tensorflow/lite/kernels/internal/reference/conv.h" #include "tensorflow/lite/kernels/internal/reference/floor.h" @@ -3917,104 +3918,6 @@ inline void BroadcastPow4DSlow(const RuntimeShape& unextended_input1_shape, } } -inline void Logical(const RuntimeShape& input1_shape, const bool* input1_data, - const RuntimeShape& input2_shape, const bool* input2_data, - const RuntimeShape& output_shape, bool* output_data, - const std::function& func) { - const int flat_size = - MatchingFlatSize(input1_shape, input2_shape, output_shape); - for (int i = 0; i < flat_size; ++i) { - output_data[i] = func(input1_data[i], input2_data[i]); - } -} - -inline void BroadcastLogical4DSlow( - const RuntimeShape& unextended_input1_shape, const bool* input1_data, - const RuntimeShape& unextended_input2_shape, const bool* input2_data, - const RuntimeShape& unextended_output_shape, bool* output_data, - const std::function& func) { - TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); - const RuntimeShape output_shape = - RuntimeShape::ExtendedShape(4, unextended_output_shape); - - NdArrayDesc<4> desc1; - 
NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, - unextended_input2_shape, &desc1, &desc2); - - for (int b = 0; b < output_shape.Dims(0); ++b) { - for (int y = 0; y < output_shape.Dims(1); ++y) { - for (int x = 0; x < output_shape.Dims(2); ++x) { - for (int c = 0; c < output_shape.Dims(3); ++c) { - auto out_idx = Offset(output_shape, b, y, x, c); - auto in1_idx = SubscriptToIndex(desc1, b, y, x, c); - auto in2_idx = SubscriptToIndex(desc2, b, y, x, c); - auto in1_val = input1_data[in1_idx]; - auto in2_val = input2_data[in2_idx]; - output_data[out_idx] = func(in1_val, in2_val); - } - } - } - } -} - -// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more -// generalized and efficient BroadcastBinaryFunction. -// -// Also appears to duplicte MinimumMaximum. -// -// R: Result type. T1: Input 1 type. T2: Input 2 type. -template -inline void BroadcastBinaryFunction4DSlow( - const RuntimeShape& unextended_input1_shape, const T1* input1_data, - const RuntimeShape& unextended_input2_shape, const T2* input2_data, - const RuntimeShape& unextended_output_shape, R* output_data, - R (*func)(T1, T2)) { - TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); - const RuntimeShape output_shape = - RuntimeShape::ExtendedShape(4, unextended_output_shape); - - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, - unextended_input2_shape, &desc1, &desc2); - - for (int b = 0; b < output_shape.Dims(0); ++b) { - for (int y = 0; y < output_shape.Dims(1); ++y) { - for (int x = 0; x < output_shape.Dims(2); ++x) { - for (int c = 0; c < output_shape.Dims(3); ++c) { - auto out_idx = Offset(output_shape, b, y, x, c); - auto in1_idx = SubscriptToIndex(desc1, b, y, x, c); - auto in2_idx = SubscriptToIndex(desc2, b, y, x, c); - auto in1_val = 
input1_data[in1_idx];
-          auto in2_val = input2_data[in2_idx];
-          output_data[out_idx] = func(in1_val, in2_val);
-        }
-      }
-    }
-  }
-}
-
-// R: Result type. T1: Input 1 type. T2: Input 2 type.
-// TODO(renjieliu): Refactor other binary functions to use this one.
-template <typename R, typename T1, typename T2>
-inline void BinaryFunction(const RuntimeShape& input1_shape,
-                           const T1* input1_data,
-                           const RuntimeShape& input2_shape,
-                           const T2* input2_data,
-                           const RuntimeShape& output_shape, R* output_data,
-                           R (*func)(T1, T2)) {
-  const int flat_size =
-      MatchingFlatSize(input1_shape, input2_shape, output_shape);
-  for (int i = 0; i < flat_size; ++i) {
-    output_data[i] = func(input1_data[i], input2_data[i]);
-  }
-}
-
 template <typename T>
 inline void ResizeNearestNeighbor(
     const tflite::ResizeNearestNeighborParams& op_params,
diff --git a/tensorflow/lite/kernels/logical.cc b/tensorflow/lite/kernels/logical.cc
index 582bcff64a8..7a2805d503b 100644
--- a/tensorflow/lite/kernels/logical.cc
+++ b/tensorflow/lite/kernels/logical.cc
@@ -78,7 +78,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }
 
 TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
-                         const std::function<bool(bool, bool)>& func) {
+                         bool (*func)(bool, bool)) {
   OpData* data = reinterpret_cast<OpData*>(node->user_data);
 
   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
@@ -86,28 +86,30 @@ TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
   if (data->requires_broadcast) {
-    reference_ops::BroadcastLogical4DSlow(
+    reference_ops::BroadcastBinaryFunction4DSlow(
         GetTensorShape(input1), GetTensorData<bool>(input1),
         GetTensorShape(input2), GetTensorData<bool>(input2),
         GetTensorShape(output), GetTensorData<bool>(output), func);
   } else {
-    reference_ops::Logical(GetTensorShape(input1), GetTensorData<bool>(input1),
-                           GetTensorShape(input2), GetTensorData<bool>(input2),
-                           GetTensorShape(output), GetTensorData<bool>(output),
-                           func);
+    reference_ops::BinaryFunction(
+        GetTensorShape(input1), GetTensorData<bool>(input1),
+        GetTensorShape(input2), GetTensorData<bool>(input2),
+        GetTensorShape(output), GetTensorData<bool>(output), func);
   }
 
   return kTfLiteOk;
 }
 
+bool LogicalOr(bool x, bool y) { return x || y; }
+
 TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) {
-  const auto logical_or_func = std::logical_or<bool>();
-  return LogicalImpl(context, node, logical_or_func);
+  return LogicalImpl(context, node, LogicalOr);
 }
 
+bool LogicalAnd(bool x, bool y) { return x && y; }
+
 TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) {
-  const auto logical_and_func = std::logical_and<bool>();
-  return LogicalImpl(context, node, logical_and_func);
+  return LogicalImpl(context, node, LogicalAnd);
 }
 
 }  // namespace