Port comparison operations to separate header

PiperOrigin-RevId: 261627312
This commit is contained in:
Taehee Jeong 2019-08-05 00:00:00 -07:00 committed by TensorFlower Gardener
parent 00581c6347
commit f91e7353f8
3 changed files with 279 additions and 248 deletions

View File

@ -348,6 +348,7 @@ cc_library(
srcs = [],
hdrs = [
"reference/arg_min_max.h",
"reference/comparisons.h",
"reference/conv.h",
"reference/depthwiseconv_float.h",
"reference/depthwiseconv_uint8.h",
@ -404,6 +405,7 @@ cc_library(
srcs = [],
hdrs = [
"reference/arg_min_max.h",
"reference/comparisons.h",
"reference/conv.h",
"reference/depthwiseconv_float.h",
"reference/depthwiseconv_uint8.h",

View File

@ -0,0 +1,276 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
#include "profiling/instrumentation.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
template <typename T>
inline bool EqualFn(T lhs, T rhs) {
return lhs == rhs;
}
template <typename T>
inline bool NotEqualFn(T lhs, T rhs) {
return lhs != rhs;
}
template <typename T>
inline bool GreaterFn(T lhs, T rhs) {
return lhs > rhs;
}
template <typename T>
inline bool GreaterEqualFn(T lhs, T rhs) {
return lhs >= rhs;
}
template <typename T>
inline bool LessFn(T lhs, T rhs) {
return lhs < rhs;
}
template <typename T>
inline bool LessEqualFn(T lhs, T rhs) {
return lhs <= rhs;
}
template <typename T>
using ComparisonFn = bool (*)(T, T);
template <typename T, ComparisonFn<T> F>
inline void ComparisonImpl(
const ComparisonParams& op_params, const RuntimeShape& input1_shape,
const T* input1_data, const RuntimeShape& input2_shape,
const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
const int64_t flatsize =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int64_t i = 0; i < flatsize; ++i) {
output_data[i] = F(input1_data[i], input2_data[i]);
}
}
template <ComparisonFn<float> F>
inline void Comparison(const ComparisonParams& op_params,
const RuntimeShape& input1_shape,
const float* input1_data,
const RuntimeShape& input2_shape,
const float* input2_data,
const RuntimeShape& output_shape, bool* output_data) {
ComparisonImpl<float, F>(op_params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data);
}
template <typename T, ComparisonFn<int32> F>
inline void ComparisonWithScaling(
const ComparisonParams& op_params, const RuntimeShape& input1_shape,
const T* input1_data, const RuntimeShape& input2_shape,
const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
int left_shift = op_params.left_shift;
int32 input1_offset = op_params.input1_offset;
int32 input1_multiplier = op_params.input1_multiplier;
int input1_shift = op_params.input1_shift;
int32 input2_offset = op_params.input2_offset;
int32 input2_multiplier = op_params.input2_multiplier;
int input2_shift = op_params.input2_shift;
const int64_t flatsize =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int64_t i = 0; i < flatsize; ++i) {
const int32 input1_val = input1_offset + input1_data[i];
const int32 input2_val = input2_offset + input2_data[i];
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input1_val, input1_multiplier, input1_shift);
const int32 scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input2_val, input2_multiplier, input2_shift);
output_data[i] = F(scaled_input1_val, scaled_input2_val);
}
}
template <typename T, ComparisonFn<T> F>
inline void BroadcastComparison4DSlowImpl(
const ComparisonParams& op_params,
const RuntimeShape& unextended_input1_shape, const T* input1_data,
const RuntimeShape& unextended_input2_shape, const T* input2_data,
const RuntimeShape& unextended_output_shape, bool* output_data) {
gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlow");
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
for (int x = 0; x < output_shape.Dims(2); ++x) {
for (int c = 0; c < output_shape.Dims(3); ++c) {
output_data[Offset(output_shape, b, y, x, c)] =
F(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
}
}
}
}
}
template <ComparisonFn<float> F>
inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
const RuntimeShape& input1_shape,
const float* input1_data,
const RuntimeShape& input2_shape,
const float* input2_data,
const RuntimeShape& output_shape,
bool* output_data) {
BroadcastComparison4DSlowImpl<float, F>(op_params, input1_shape, input1_data,
input2_shape, input2_data,
output_shape, output_data);
}
template <typename T, ComparisonFn<int32> F>
inline void BroadcastComparison4DSlowWithScaling(
const ComparisonParams& op_params,
const RuntimeShape& unextended_input1_shape, const T* input1_data,
const RuntimeShape& unextended_input2_shape, const T* input2_data,
const RuntimeShape& unextended_output_shape, bool* output_data) {
gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlowWithScaling");
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
unextended_input2_shape, &desc1, &desc2);
int left_shift = op_params.left_shift;
int32 input1_offset = op_params.input1_offset;
int32 input1_multiplier = op_params.input1_multiplier;
int input1_shift = op_params.input1_shift;
int32 input2_offset = op_params.input2_offset;
int32 input2_multiplier = op_params.input2_multiplier;
int input2_shift = op_params.input2_shift;
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
for (int x = 0; x < output_shape.Dims(2); ++x) {
for (int c = 0; c < output_shape.Dims(3); ++c) {
const int32 input1_val =
input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)];
const int32 input2_val =
input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)];
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input1_val, input1_multiplier, input1_shift);
const int32 scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input2_val, input2_multiplier, input2_shift);
output_data[Offset(output_shape, b, y, x, c)] =
F(scaled_input1_val, scaled_input2_val);
}
}
}
}
}
#define TFLITE_COMPARISON_OP(name) \
inline void name(const ComparisonParams& op_params, \
const RuntimeShape& input1_shape, const float* input1_data, \
const RuntimeShape& input2_shape, const float* input2_data, \
const RuntimeShape& output_shape, bool* output_data) { \
gemmlowp::ScopedProfilingLabel label(#name); \
Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape, \
input2_data, output_shape, output_data); \
} \
template <typename T> \
inline void name##NoScaling( \
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
const T* input1_data, const RuntimeShape& input2_shape, \
const T* input2_data, const RuntimeShape& output_shape, \
bool* output_data) { \
gemmlowp::ScopedProfilingLabel label(#name "NoScaling"); \
ComparisonImpl<T, name##Fn>(op_params, input1_shape, input1_data, \
input2_shape, input2_data, output_shape, \
output_data); \
} \
template <typename T> \
inline void name##WithScaling( \
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
const T* input1_data, const RuntimeShape& input2_shape, \
const T* input2_data, const RuntimeShape& output_shape, \
bool* output_data) { \
gemmlowp::ScopedProfilingLabel label(#name "WithScaling/8bit"); \
ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data, \
input2_shape, input2_data, \
output_shape, output_data); \
} \
template <typename T> \
inline void Broadcast4DSlow##name##NoScaling( \
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
const T* input1_data, const RuntimeShape& input2_shape, \
const T* input2_data, const RuntimeShape& output_shape, \
bool* output_data) { \
gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "NoScaling"); \
BroadcastComparison4DSlowImpl<T, name##Fn>( \
op_params, input1_shape, input1_data, input2_shape, input2_data, \
output_shape, output_data); \
} \
inline void Broadcast4DSlow##name( \
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
const float* input1_data, const RuntimeShape& input2_shape, \
const float* input2_data, const RuntimeShape& output_shape, \
bool* output_data) { \
gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name); \
BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data, \
input2_shape, input2_data, \
output_shape, output_data); \
} \
template <typename T> \
inline void Broadcast4DSlow##name##WithScaling( \
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
const T* input1_data, const RuntimeShape& input2_shape, \
const T* input2_data, const RuntimeShape& output_shape, \
bool* output_data) { \
gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "/8bit"); \
BroadcastComparison4DSlowWithScaling<T, name##Fn>( \
op_params, input1_shape, input1_data, input2_shape, input2_data, \
output_shape, output_data); \
}
TFLITE_COMPARISON_OP(Equal);
TFLITE_COMPARISON_OP(NotEqual);
TFLITE_COMPARISON_OP(Greater);
TFLITE_COMPARISON_OP(GreaterEqual);
TFLITE_COMPARISON_OP(Less);
TFLITE_COMPARISON_OP(LessEqual);
#undef TFLITE_COMPARISON_OP
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/floor.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
@ -2471,7 +2472,6 @@ inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
Tanh(input_shape, input_data, output_shape, output_data);
}
inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
const int16* input_data, const RuntimeShape& output_shape,
int16* output_data) {
@ -3768,253 +3768,6 @@ inline void TransposeConv(const ConvParams& params,
}
}
template <typename T>
inline bool EqualFn(T lhs, T rhs) {
return lhs == rhs;
}
template <typename T>
inline bool NotEqualFn(T lhs, T rhs) {
return lhs != rhs;
}
template <typename T>
inline bool GreaterFn(T lhs, T rhs) {
return lhs > rhs;
}
template <typename T>
inline bool GreaterEqualFn(T lhs, T rhs) {
return lhs >= rhs;
}
template <typename T>
inline bool LessFn(T lhs, T rhs) {
return lhs < rhs;
}
template <typename T>
inline bool LessEqualFn(T lhs, T rhs) {
return lhs <= rhs;
}
template <typename T>
using ComparisonFn = bool (*)(T, T);
template <typename T, ComparisonFn<T> F>
inline void ComparisonImpl(
const ComparisonParams& op_params, const RuntimeShape& input1_shape,
const T* input1_data, const RuntimeShape& input2_shape,
const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
const int64_t flatsize =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int64_t i = 0; i < flatsize; ++i) {
output_data[i] = F(input1_data[i], input2_data[i]);
}
}
template <ComparisonFn<float> F>
inline void Comparison(const ComparisonParams& op_params,
const RuntimeShape& input1_shape,
const float* input1_data,
const RuntimeShape& input2_shape,
const float* input2_data,
const RuntimeShape& output_shape, bool* output_data) {
ComparisonImpl<float, F>(op_params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data);
}
template <typename T, ComparisonFn<int32> F>
inline void ComparisonWithScaling(
const ComparisonParams& op_params, const RuntimeShape& input1_shape,
const T* input1_data, const RuntimeShape& input2_shape,
const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
int left_shift = op_params.left_shift;
int32 input1_offset = op_params.input1_offset;
int32 input1_multiplier = op_params.input1_multiplier;
int input1_shift = op_params.input1_shift;
int32 input2_offset = op_params.input2_offset;
int32 input2_multiplier = op_params.input2_multiplier;
int input2_shift = op_params.input2_shift;
const int64_t flatsize =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int64_t i = 0; i < flatsize; ++i) {
const int32 input1_val = input1_offset + input1_data[i];
const int32 input2_val = input2_offset + input2_data[i];
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input1_val, input1_multiplier, input1_shift);
const int32 scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input2_val, input2_multiplier, input2_shift);
output_data[i] = F(scaled_input1_val, scaled_input2_val);
}
}
template <typename T, ComparisonFn<T> F>
inline void BroadcastComparison4DSlowImpl(
const ComparisonParams& op_params,
const RuntimeShape& unextended_input1_shape, const T* input1_data,
const RuntimeShape& unextended_input2_shape, const T* input2_data,
const RuntimeShape& unextended_output_shape, bool* output_data) {
gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlow");
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
for (int x = 0; x < output_shape.Dims(2); ++x) {
for (int c = 0; c < output_shape.Dims(3); ++c) {
output_data[Offset(output_shape, b, y, x, c)] =
F(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
}
}
}
}
}
template <ComparisonFn<float> F>
inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
const RuntimeShape& input1_shape,
const float* input1_data,
const RuntimeShape& input2_shape,
const float* input2_data,
const RuntimeShape& output_shape,
bool* output_data) {
BroadcastComparison4DSlowImpl<float, F>(op_params, input1_shape, input1_data,
input2_shape, input2_data,
output_shape, output_data);
}
template <typename T, ComparisonFn<int32> F>
inline void BroadcastComparison4DSlowWithScaling(
const ComparisonParams& op_params,
const RuntimeShape& unextended_input1_shape, const T* input1_data,
const RuntimeShape& unextended_input2_shape, const T* input2_data,
const RuntimeShape& unextended_output_shape, bool* output_data) {
gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlowWithScaling");
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
unextended_input2_shape, &desc1, &desc2);
int left_shift = op_params.left_shift;
int32 input1_offset = op_params.input1_offset;
int32 input1_multiplier = op_params.input1_multiplier;
int input1_shift = op_params.input1_shift;
int32 input2_offset = op_params.input2_offset;
int32 input2_multiplier = op_params.input2_multiplier;
int input2_shift = op_params.input2_shift;
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
for (int x = 0; x < output_shape.Dims(2); ++x) {
for (int c = 0; c < output_shape.Dims(3); ++c) {
const int32 input1_val =
input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)];
const int32 input2_val =
input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)];
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input1_val, input1_multiplier, input1_shift);
const int32 scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_input2_val, input2_multiplier, input2_shift);
output_data[Offset(output_shape, b, y, x, c)] =
F(scaled_input1_val, scaled_input2_val);
}
}
}
}
}
#define TFLITE_COMPARISON_OP(name) \
inline void name(const ComparisonParams& op_params, \
const RuntimeShape& input1_shape, const float* input1_data, \
const RuntimeShape& input2_shape, const float* input2_data, \
const RuntimeShape& output_shape, bool* output_data) { \
gemmlowp::ScopedProfilingLabel label(#name); \
Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape, \
input2_data, output_shape, output_data); \
} \
template <typename T> \
inline void name##NoScaling( \
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
const T* input1_data, const RuntimeShape& input2_shape, \
const T* input2_data, const RuntimeShape& output_shape, \
bool* output_data) { \
gemmlowp::ScopedProfilingLabel label(#name "NoScaling"); \
ComparisonImpl<T, name##Fn>(op_params, input1_shape, input1_data, \
input2_shape, input2_data, output_shape, \
output_data); \
} \
template <typename T> \
inline void name##WithScaling( \
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
const T* input1_data, const RuntimeShape& input2_shape, \
const T* input2_data, const RuntimeShape& output_shape, \
bool* output_data) { \
gemmlowp::ScopedProfilingLabel label(#name "WithScaling/8bit"); \
ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data, \
input2_shape, input2_data, \
output_shape, output_data); \
} \
template <typename T> \
inline void Broadcast4DSlow##name##NoScaling( \
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
const T* input1_data, const RuntimeShape& input2_shape, \
const T* input2_data, const RuntimeShape& output_shape, \
bool* output_data) { \
gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "NoScaling"); \
BroadcastComparison4DSlowImpl<T, name##Fn>( \
op_params, input1_shape, input1_data, input2_shape, input2_data, \
output_shape, output_data); \
} \
inline void Broadcast4DSlow##name( \
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
const float* input1_data, const RuntimeShape& input2_shape, \
const float* input2_data, const RuntimeShape& output_shape, \
bool* output_data) { \
gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name); \
BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data, \
input2_shape, input2_data, \
output_shape, output_data); \
} \
template <typename T> \
inline void Broadcast4DSlow##name##WithScaling( \
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
const T* input1_data, const RuntimeShape& input2_shape, \
const T* input2_data, const RuntimeShape& output_shape, \
bool* output_data) { \
gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "/8bit"); \
BroadcastComparison4DSlowWithScaling<T, name##Fn>( \
op_params, input1_shape, input1_data, input2_shape, input2_data, \
output_shape, output_data); \
}
TFLITE_COMPARISON_OP(Equal);
TFLITE_COMPARISON_OP(NotEqual);
TFLITE_COMPARISON_OP(Greater);
TFLITE_COMPARISON_OP(GreaterEqual);
TFLITE_COMPARISON_OP(Less);
TFLITE_COMPARISON_OP(LessEqual);
#undef TFLITE_COMPARISON_OP
template <typename D, typename T>
void Select(const RuntimeShape& input_condition_shape,
const D* input_condition_data, const RuntimeShape& input_x_shape,