Move binary functions to a separate header.

PiperOrigin-RevId: 261823955
This commit is contained in:
Juho Ha 2019-08-05 19:44:19 -07:00 committed by TensorFlower Gardener
parent 640b5f2513
commit c7fd77e7d0
5 changed files with 100 additions and 127 deletions

View File

@ -348,6 +348,7 @@ cc_library(
srcs = [],
hdrs = [
"reference/arg_min_max.h",
"reference/binary_function.h",
"reference/comparisons.h",
"reference/conv.h",
"reference/depthwiseconv_float.h",
@ -405,6 +406,7 @@ cc_library(
srcs = [],
hdrs = [
"reference/arg_min_max.h",
"reference/binary_function.h",
"reference/comparisons.h",
"reference/conv.h",
"reference/depthwiseconv_float.h",

View File

@ -0,0 +1,85 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BINARY_FUNCTION_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BINARY_FUNCTION_H_
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more
// generalized and efficient BroadcastBinaryFunction.
//
// Also appears to duplicate MinimumMaximum.
//
// R: Result type. T1: Input 1 type. T2: Input 2 type.
template <typename R, typename T1, typename T2>
inline void BroadcastBinaryFunction4DSlow(
const RuntimeShape& unextended_input1_shape, const T1* input1_data,
const RuntimeShape& unextended_input2_shape, const T2* input2_data,
const RuntimeShape& unextended_output_shape, R* output_data,
R (*func)(T1, T2)) {
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
for (int x = 0; x < output_shape.Dims(2); ++x) {
for (int c = 0; c < output_shape.Dims(3); ++c) {
auto out_idx = Offset(output_shape, b, y, x, c);
auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
auto in1_val = input1_data[in1_idx];
auto in2_val = input2_data[in2_idx];
output_data[out_idx] = func(in1_val, in2_val);
}
}
}
}
}
// R: Result type. T1: Input 1 type. T2: Input 2 type.
// TODO(renjieliu): Refactor other binary functions to use this one.
// Applies `func` to corresponding elements of the two input buffers and stores
// each result in the output buffer, walking the flat buffers in index order.
// The element count is obtained from MatchingFlatSize over the three shapes.
template <typename R, typename T1, typename T2>
inline void BinaryFunction(const RuntimeShape& input1_shape,
                           const T1* input1_data,
                           const RuntimeShape& input2_shape,
                           const T2* input2_data,
                           const RuntimeShape& output_shape, R* output_data,
                           R (*func)(T1, T2)) {
  const int element_count =
      MatchingFlatSize(input1_shape, input2_shape, output_shape);
  int pos = 0;
  while (pos < element_count) {
    output_data[pos] = func(input1_data[pos], input2_data[pos]);
    ++pos;
  }
}
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BINARY_FUNCTION_H_

View File

@ -2192,25 +2192,6 @@ inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
DimsToShape(output_dims), output_data);
}
// Dims<4>-based overload of Logical: converts each Dims<4> argument with
// DimsToShape and forwards to the shape-based Logical overload, passing the
// data pointers and `func` through unchanged.
inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
                    const bool* input2_data, const Dims<4>& input2_dims,
                    bool* output_data, const Dims<4>& output_dims,
                    const std::function<bool(bool, bool)>& func) {
  const auto lhs_shape = DimsToShape(input1_dims);
  const auto rhs_shape = DimsToShape(input2_dims);
  const auto result_shape = DimsToShape(output_dims);
  Logical(lhs_shape, input1_data, rhs_shape, input2_data, result_shape,
          output_data, func);
}
// Dims<4>-based entry point for the broadcasting logical op: converts each
// Dims<4> argument with DimsToShape and forwards to BroadcastLogical4DSlow,
// passing the data pointers and `func` through unchanged.
inline void BroadcastLogical(const bool* input1_data,
                             const Dims<4>& input1_dims,
                             const bool* input2_data,
                             const Dims<4>& input2_dims, bool* output_data,
                             const Dims<4>& output_dims,
                             const std::function<bool(bool, bool)>& func) {
  const auto lhs_shape = DimsToShape(input1_dims);
  const auto rhs_shape = DimsToShape(input2_dims);
  const auto result_shape = DimsToShape(output_dims);
  BroadcastLogical4DSlow(lhs_shape, input1_data, rhs_shape, input2_data,
                         result_shape, output_data, func);
}
// R: Result type. T1: Input 1 type. T2: Input 2 type.
template <typename R, typename T1, typename T2>
inline void BroadcastBinaryFunction(const T1* input1_data,

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/floor.h"
@ -3917,104 +3918,6 @@ inline void BroadcastPow4DSlow(const RuntimeShape& unextended_input1_shape,
}
}
// Element-wise logical operation without broadcasting: applies `func` to each
// pair of corresponding input elements and writes the result to the output.
// The element count is obtained from MatchingFlatSize over the three shapes.
inline void Logical(const RuntimeShape& input1_shape, const bool* input1_data,
                    const RuntimeShape& input2_shape, const bool* input2_data,
                    const RuntimeShape& output_shape, bool* output_data,
                    const std::function<bool(bool, bool)>& func) {
  const int element_count =
      MatchingFlatSize(input1_shape, input2_shape, output_shape);
  int pos = 0;
  while (pos < element_count) {
    output_data[pos] = func(input1_data[pos], input2_data[pos]);
    ++pos;
  }
}
inline void BroadcastLogical4DSlow(
const RuntimeShape& unextended_input1_shape, const bool* input1_data,
const RuntimeShape& unextended_input2_shape, const bool* input2_data,
const RuntimeShape& unextended_output_shape, bool* output_data,
const std::function<bool(bool, bool)>& func) {
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
for (int x = 0; x < output_shape.Dims(2); ++x) {
for (int c = 0; c < output_shape.Dims(3); ++c) {
auto out_idx = Offset(output_shape, b, y, x, c);
auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
auto in1_val = input1_data[in1_idx];
auto in2_val = input2_data[in2_idx];
output_data[out_idx] = func(in1_val, in2_val);
}
}
}
}
}
// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more
// generalized and efficient BroadcastBinaryFunction.
//
// Also appears to duplicate MinimumMaximum.
//
// R: Result type. T1: Input 1 type. T2: Input 2 type.
template <typename R, typename T1, typename T2>
inline void BroadcastBinaryFunction4DSlow(
const RuntimeShape& unextended_input1_shape, const T1* input1_data,
const RuntimeShape& unextended_input2_shape, const T2* input2_data,
const RuntimeShape& unextended_output_shape, R* output_data,
R (*func)(T1, T2)) {
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
for (int x = 0; x < output_shape.Dims(2); ++x) {
for (int c = 0; c < output_shape.Dims(3); ++c) {
auto out_idx = Offset(output_shape, b, y, x, c);
auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
auto in1_val = input1_data[in1_idx];
auto in2_val = input2_data[in2_idx];
output_data[out_idx] = func(in1_val, in2_val);
}
}
}
}
}
// R: Result type. T1: Input 1 type. T2: Input 2 type.
// TODO(renjieliu): Refactor other binary functions to use this one.
// Applies `func` to corresponding elements of the two input buffers and stores
// each result in the output buffer, walking the flat buffers in index order.
// The element count is obtained from MatchingFlatSize over the three shapes.
template <typename R, typename T1, typename T2>
inline void BinaryFunction(const RuntimeShape& input1_shape,
                           const T1* input1_data,
                           const RuntimeShape& input2_shape,
                           const T2* input2_data,
                           const RuntimeShape& output_shape, R* output_data,
                           R (*func)(T1, T2)) {
  const int element_count =
      MatchingFlatSize(input1_shape, input2_shape, output_shape);
  int pos = 0;
  while (pos < element_count) {
    output_data[pos] = func(input1_data[pos], input2_data[pos]);
    ++pos;
  }
}
template <typename T>
inline void ResizeNearestNeighbor(
const tflite::ResizeNearestNeighborParams& op_params,

View File

@ -78,7 +78,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
// Shared evaluation routine for the logical binary ops. Reads the op data from
// node->user_data, fetches the input and output tensors, applies `func` to the
// boolean input data and writes the boolean output data, then returns
// kTfLiteOk.
TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
const std::function<bool(bool, bool)>& func) {
bool (*func)(bool, bool)) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
@ -86,28 +86,30 @@ TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// The requires_broadcast flag on the op data selects the broadcasting kernel.
if (data->requires_broadcast) {
reference_ops::BroadcastLogical4DSlow(
reference_ops::BroadcastBinaryFunction4DSlow<bool, bool, bool>(
GetTensorShape(input1), GetTensorData<bool>(input1),
GetTensorShape(input2), GetTensorData<bool>(input2),
GetTensorShape(output), GetTensorData<bool>(output), func);
} else {
// Otherwise the non-broadcasting element-wise kernel is used.
reference_ops::Logical(GetTensorShape(input1), GetTensorData<bool>(input1),
GetTensorShape(input2), GetTensorData<bool>(input2),
GetTensorShape(output), GetTensorData<bool>(output),
func);
reference_ops::BinaryFunction<bool, bool, bool>(
GetTensorShape(input1), GetTensorData<bool>(input1),
GetTensorShape(input2), GetTensorData<bool>(input2),
GetTensorShape(output), GetTensorData<bool>(output), func);
}
return kTfLiteOk;
}
// Returns true when at least one of the two operands is true.
bool LogicalOr(bool x, bool y) {
  if (x) {
    return true;
  }
  return y;
}
// Evaluates the node as a logical OR by delegating to LogicalImpl and
// returning its status.
TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) {
const auto logical_or_func = std::logical_or<bool>();
return LogicalImpl(context, node, logical_or_func);
return LogicalImpl(context, node, LogicalOr);
}
// Returns true only when both operands are true.
bool LogicalAnd(bool x, bool y) {
  if (!x) {
    return false;
  }
  return y;
}
// Evaluates the node as a logical AND by delegating to LogicalImpl and
// returning its status.
TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) {
const auto logical_and_func = std::logical_and<bool>();
return LogicalImpl(context, node, logical_and_func);
return LogicalImpl(context, node, LogicalAnd);
}
} // namespace