diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index f1e91450fe1..0d96bc01258 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -348,6 +348,7 @@ cc_library( srcs = [], hdrs = [ "reference/arg_min_max.h", + "reference/comparisons.h", "reference/conv.h", "reference/depthwiseconv_float.h", "reference/depthwiseconv_uint8.h", @@ -404,6 +405,7 @@ cc_library( srcs = [], hdrs = [ "reference/arg_min_max.h", + "reference/comparisons.h", "reference/conv.h", "reference/depthwiseconv_float.h", "reference/depthwiseconv_uint8.h", diff --git a/tensorflow/lite/kernels/internal/reference/comparisons.h b/tensorflow/lite/kernels/internal/reference/comparisons.h new file mode 100644 index 00000000000..7f8072fa820 --- /dev/null +++ b/tensorflow/lite/kernels/internal/reference/comparisons.h @@ -0,0 +1,276 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_ +#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_ + +#include "profiling/instrumentation.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/types.h" + +namespace tflite { + +namespace reference_ops { + +template +inline bool EqualFn(T lhs, T rhs) { + return lhs == rhs; +} + +template +inline bool NotEqualFn(T lhs, T rhs) { + return lhs != rhs; +} + +template +inline bool GreaterFn(T lhs, T rhs) { + return lhs > rhs; +} +template +inline bool GreaterEqualFn(T lhs, T rhs) { + return lhs >= rhs; +} +template +inline bool LessFn(T lhs, T rhs) { + return lhs < rhs; +} +template +inline bool LessEqualFn(T lhs, T rhs) { + return lhs <= rhs; +} + +template +using ComparisonFn = bool (*)(T, T); + +template F> +inline void ComparisonImpl( + const ComparisonParams& op_params, const RuntimeShape& input1_shape, + const T* input1_data, const RuntimeShape& input2_shape, + const T* input2_data, const RuntimeShape& output_shape, bool* output_data) { + const int64_t flatsize = + MatchingFlatSize(input1_shape, input2_shape, output_shape); + for (int64_t i = 0; i < flatsize; ++i) { + output_data[i] = F(input1_data[i], input2_data[i]); + } +} + +template F> +inline void Comparison(const ComparisonParams& op_params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, bool* output_data) { + ComparisonImpl(op_params, input1_shape, input1_data, input2_shape, + input2_data, output_shape, output_data); +} + +template F> +inline void ComparisonWithScaling( + const ComparisonParams& op_params, const RuntimeShape& input1_shape, + const T* input1_data, const RuntimeShape& input2_shape, + const T* input2_data, const RuntimeShape& output_shape, bool* output_data) { + int left_shift = op_params.left_shift; + int32 input1_offset = op_params.input1_offset; + int32 input1_multiplier = op_params.input1_multiplier; + int input1_shift = op_params.input1_shift; + int32 input2_offset = op_params.input2_offset; + int32 input2_multiplier = op_params.input2_multiplier; + int input2_shift = op_params.input2_shift; + + const int64_t flatsize = + MatchingFlatSize(input1_shape, input2_shape, output_shape); + for (int64_t i = 0; i < flatsize; ++i) { + const int32 input1_val = input1_offset + input1_data[i]; + const int32 input2_val = input2_offset + input2_data[i]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, input2_shift); + output_data[i] = F(scaled_input1_val, scaled_input2_val); + } +} + +template F> +inline void BroadcastComparison4DSlowImpl( + const ComparisonParams& op_params, + const RuntimeShape& unextended_input1_shape, const T* input1_data, + const RuntimeShape& unextended_input2_shape, const T* input2_data, + const RuntimeShape& unextended_output_shape, bool* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlow"); + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + const RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, + unextended_input2_shape, &desc1, &desc2); + + for (int b = 0; b < output_shape.Dims(0); ++b) { + for (int y = 0; y < output_shape.Dims(1); ++y) { + for (int x = 0; x < output_shape.Dims(2); ++x) { + for (int c = 0; c < output_shape.Dims(3); ++c) { + output_data[Offset(output_shape, b, y, x, c)] = + F(input1_data[SubscriptToIndex(desc1, b, y, x, c)], + input2_data[SubscriptToIndex(desc2, b, y, x, c)]); + } + } + } + } +} +template F> +inline void BroadcastComparison4DSlow(const ComparisonParams& op_params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, + bool* output_data) { + BroadcastComparison4DSlowImpl(op_params, input1_shape, input1_data, + input2_shape, input2_data, + output_shape, output_data); +} + +template F> +inline void BroadcastComparison4DSlowWithScaling( + const ComparisonParams& op_params, + const RuntimeShape& unextended_input1_shape, const T* input1_data, + const RuntimeShape& unextended_input2_shape, const T* input2_data, + const RuntimeShape& unextended_output_shape, bool* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlowWithScaling"); + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + const RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, + unextended_input2_shape, &desc1, &desc2); + + int left_shift = op_params.left_shift; + int32 input1_offset = op_params.input1_offset; + int32 input1_multiplier = op_params.input1_multiplier; + int input1_shift = op_params.input1_shift; + int32 input2_offset = op_params.input2_offset; + int32 input2_multiplier = op_params.input2_multiplier; + int input2_shift = op_params.input2_shift; + + for (int b = 0; b < output_shape.Dims(0); ++b) { + for (int y = 0; y < output_shape.Dims(1); ++y) { + for (int x = 0; x < output_shape.Dims(2); ++x) { + for (int c = 0; c < output_shape.Dims(3); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, input2_shift); + output_data[Offset(output_shape, b, y, x, c)] = + F(scaled_input1_val, scaled_input2_val); + } + } + } + } +} + +#define TFLITE_COMPARISON_OP(name) \ + inline void name(const ComparisonParams& op_params, \ + const RuntimeShape& input1_shape, const float* input1_data, \ + const RuntimeShape& input2_shape, const float* input2_data, \ + const RuntimeShape& output_shape, bool* output_data) { \ + gemmlowp::ScopedProfilingLabel label(#name); \ + Comparison(op_params, input1_shape, input1_data, input2_shape, \ + input2_data, output_shape, output_data); \ + } \ + template \ + inline void name##NoScaling( \ + const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ + const T* input1_data, const RuntimeShape& input2_shape, \ + const T* input2_data, const RuntimeShape& output_shape, \ + bool* output_data) { \ + gemmlowp::ScopedProfilingLabel label(#name "NoScaling"); \ + ComparisonImpl(op_params, input1_shape, input1_data, \ + input2_shape, input2_data, output_shape, \ + output_data); \ + } \ + template \ + inline void name##WithScaling( \ + const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ + const T* input1_data, const RuntimeShape& input2_shape, \ + const T* input2_data, const RuntimeShape& output_shape, \ + bool* output_data) { \ + gemmlowp::ScopedProfilingLabel label(#name "WithScaling/8bit"); \ + ComparisonWithScaling(op_params, input1_shape, input1_data, \ + input2_shape, input2_data, \ + output_shape, output_data); \ + } \ + template \ + inline void Broadcast4DSlow##name##NoScaling( \ + const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ + const T* input1_data, const RuntimeShape& input2_shape, \ + const T* input2_data, const RuntimeShape& output_shape, \ + bool* output_data) { \ + gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "NoScaling"); \ + BroadcastComparison4DSlowImpl( \ + op_params, input1_shape, input1_data, input2_shape, input2_data, \ + output_shape, output_data); \ + } \ + inline void Broadcast4DSlow##name( \ + const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ + const float* input1_data, const RuntimeShape& input2_shape, \ + const float* input2_data, const RuntimeShape& output_shape, \ + bool* output_data) { \ + gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name); \ + BroadcastComparison4DSlow(op_params, input1_shape, input1_data, \ + input2_shape, input2_data, \ + output_shape, output_data); \ + } \ + template \ + inline void Broadcast4DSlow##name##WithScaling( \ + const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ + const T* input1_data, const RuntimeShape& input2_shape, \ + const T* input2_data, const RuntimeShape& output_shape, \ + bool* output_data) { \ + gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "/8bit"); \ + BroadcastComparison4DSlowWithScaling( \ + op_params, input1_shape, input1_data, input2_shape, input2_data, \ + output_shape, output_data); \ + } +TFLITE_COMPARISON_OP(Equal); +TFLITE_COMPARISON_OP(NotEqual); +TFLITE_COMPARISON_OP(Greater); +TFLITE_COMPARISON_OP(GreaterEqual); +TFLITE_COMPARISON_OP(Less); +TFLITE_COMPARISON_OP(LessEqual); +#undef TFLITE_COMPARISON_OP + +} // namespace reference_ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_ diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index 919e7e86ae7..457f8946e66 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/arg_min_max.h" +#include "tensorflow/lite/kernels/internal/reference/comparisons.h" #include "tensorflow/lite/kernels/internal/reference/conv.h" #include "tensorflow/lite/kernels/internal/reference/floor.h" #include "tensorflow/lite/kernels/internal/reference/fully_connected.h" @@ -2471,7 +2472,6 @@ inline void Tanh(const TanhParams&, const RuntimeShape& input_shape, Tanh(input_shape, input_data, output_shape, output_data); } - inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape, const int16* input_data, const RuntimeShape& output_shape, int16* output_data) { @@ -3768,253 +3768,6 @@ inline void TransposeConv(const ConvParams& params, } } -template -inline bool EqualFn(T lhs, T rhs) { - return lhs == rhs; -} - -template -inline bool NotEqualFn(T lhs, T rhs) { - return lhs != rhs; -} - -template -inline bool GreaterFn(T lhs, T rhs) { - return lhs > rhs; -} -template -inline bool GreaterEqualFn(T lhs, T rhs) { - return lhs >= rhs; -} -template -inline bool LessFn(T lhs, T rhs) { - return lhs < rhs; -} -template -inline bool LessEqualFn(T lhs, T rhs) { - return lhs <= rhs; -} - -template -using ComparisonFn = bool (*)(T, T); - -template F> -inline void ComparisonImpl( - const ComparisonParams& op_params, const RuntimeShape& input1_shape, - const T* input1_data, const RuntimeShape& input2_shape, - const T* input2_data, const RuntimeShape& output_shape, bool* output_data) { - const int64_t flatsize = - MatchingFlatSize(input1_shape, input2_shape, output_shape); - for (int64_t i = 0; i < flatsize; ++i) { - output_data[i] = F(input1_data[i], input2_data[i]); - } -} - -template F> -inline void Comparison(const ComparisonParams& op_params, - const RuntimeShape& input1_shape, - const float* input1_data, - const RuntimeShape& input2_shape, - const float* input2_data, - const RuntimeShape& output_shape, bool* output_data) { - ComparisonImpl(op_params, input1_shape, input1_data, input2_shape, - input2_data, output_shape, output_data); -} - -template F> -inline void ComparisonWithScaling( - const ComparisonParams& op_params, const RuntimeShape& input1_shape, - const T* input1_data, const RuntimeShape& input2_shape, - const T* input2_data, const RuntimeShape& output_shape, bool* output_data) { - int left_shift = op_params.left_shift; - int32 input1_offset = op_params.input1_offset; - int32 input1_multiplier = op_params.input1_multiplier; - int input1_shift = op_params.input1_shift; - int32 input2_offset = op_params.input2_offset; - int32 input2_multiplier = op_params.input2_multiplier; - int input2_shift = op_params.input2_shift; - - const int64_t flatsize = - MatchingFlatSize(input1_shape, input2_shape, output_shape); - for (int64_t i = 0; i < flatsize; ++i) { - const int32 input1_val = input1_offset + input1_data[i]; - const int32 input2_val = input2_offset + input2_data[i]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, input1_shift); - const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, input2_shift); - output_data[i] = F(scaled_input1_val, scaled_input2_val); - } -} - -template F> -inline void BroadcastComparison4DSlowImpl( - const ComparisonParams& op_params, - const RuntimeShape& unextended_input1_shape, const T* input1_data, - const RuntimeShape& unextended_input2_shape, const T* input2_data, - const RuntimeShape& unextended_output_shape, bool* output_data) { - gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlow"); - TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); - const RuntimeShape output_shape = - RuntimeShape::ExtendedShape(4, unextended_output_shape); - - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, - unextended_input2_shape, &desc1, &desc2); - - for (int b = 0; b < output_shape.Dims(0); ++b) { - for (int y = 0; y < output_shape.Dims(1); ++y) { - for (int x = 0; x < output_shape.Dims(2); ++x) { - for (int c = 0; c < output_shape.Dims(3); ++c) { - output_data[Offset(output_shape, b, y, x, c)] = - F(input1_data[SubscriptToIndex(desc1, b, y, x, c)], - input2_data[SubscriptToIndex(desc2, b, y, x, c)]); - } - } - } - } -} -template F> -inline void BroadcastComparison4DSlow(const ComparisonParams& op_params, - const RuntimeShape& input1_shape, - const float* input1_data, - const RuntimeShape& input2_shape, - const float* input2_data, - const RuntimeShape& output_shape, - bool* output_data) { - BroadcastComparison4DSlowImpl(op_params, input1_shape, input1_data, - input2_shape, input2_data, - output_shape, output_data); -} - -template F> -inline void BroadcastComparison4DSlowWithScaling( - const ComparisonParams& op_params, - const RuntimeShape& unextended_input1_shape, const T* input1_data, - const RuntimeShape& unextended_input2_shape, const T* input2_data, - const RuntimeShape& unextended_output_shape, bool* output_data) { - gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlowWithScaling"); - TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); - TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); - const RuntimeShape output_shape = - RuntimeShape::ExtendedShape(4, unextended_output_shape); - - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, - unextended_input2_shape, &desc1, &desc2); - - int left_shift = op_params.left_shift; - int32 input1_offset = op_params.input1_offset; - int32 input1_multiplier = op_params.input1_multiplier; - int input1_shift = op_params.input1_shift; - int32 input2_offset = op_params.input2_offset; - int32 input2_multiplier = op_params.input2_multiplier; - int input2_shift = op_params.input2_shift; - - for (int b = 0; b < output_shape.Dims(0); ++b) { - for (int y = 0; y < output_shape.Dims(1); ++y) { - for (int x = 0; x < output_shape.Dims(2); ++x) { - for (int c = 0; c < output_shape.Dims(3); ++c) { - const int32 input1_val = - input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)]; - const int32 input2_val = - input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, input1_shift); - const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, input2_shift); - output_data[Offset(output_shape, b, y, x, c)] = - F(scaled_input1_val, scaled_input2_val); - } - } - } - } -} - -#define TFLITE_COMPARISON_OP(name) \ - inline void name(const ComparisonParams& op_params, \ - const RuntimeShape& input1_shape, const float* input1_data, \ - const RuntimeShape& input2_shape, const float* input2_data, \ - const RuntimeShape& output_shape, bool* output_data) { \ - gemmlowp::ScopedProfilingLabel label(#name); \ - Comparison(op_params, input1_shape, input1_data, input2_shape, \ - input2_data, output_shape, output_data); \ - } \ - template \ - inline void name##NoScaling( \ - const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ - const T* input1_data, const RuntimeShape& input2_shape, \ - const T* input2_data, const RuntimeShape& output_shape, \ - bool* output_data) { \ - gemmlowp::ScopedProfilingLabel label(#name "NoScaling"); \ - ComparisonImpl(op_params, input1_shape, input1_data, \ - input2_shape, input2_data, output_shape, \ - output_data); \ - } \ - template \ - inline void name##WithScaling( \ - const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ - const T* input1_data, const RuntimeShape& input2_shape, \ - const T* input2_data, const RuntimeShape& output_shape, \ - bool* output_data) { \ - gemmlowp::ScopedProfilingLabel label(#name "WithScaling/8bit"); \ - ComparisonWithScaling(op_params, input1_shape, input1_data, \ - input2_shape, input2_data, \ - output_shape, output_data); \ - } \ - template \ - inline void Broadcast4DSlow##name##NoScaling( \ - const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ - const T* input1_data, const RuntimeShape& input2_shape, \ - const T* input2_data, const RuntimeShape& output_shape, \ - bool* output_data) { \ - gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "NoScaling"); \ - BroadcastComparison4DSlowImpl( \ - op_params, input1_shape, input1_data, input2_shape, input2_data, \ - output_shape, output_data); \ - } \ - inline void Broadcast4DSlow##name( \ - const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ - const float* input1_data, const RuntimeShape& input2_shape, \ - const float* input2_data, const RuntimeShape& output_shape, \ - bool* output_data) { \ - gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name); \ - BroadcastComparison4DSlow(op_params, input1_shape, input1_data, \ - input2_shape, input2_data, \ - output_shape, output_data); \ - } \ - template \ - inline void Broadcast4DSlow##name##WithScaling( \ - const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ - const T* input1_data, const RuntimeShape& input2_shape, \ - const T* input2_data, const RuntimeShape& output_shape, \ - bool* output_data) { \ - gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "/8bit"); \ - BroadcastComparison4DSlowWithScaling( \ - op_params, input1_shape, input1_data, input2_shape, input2_data, \ - output_shape, output_data); \ - } -TFLITE_COMPARISON_OP(Equal); -TFLITE_COMPARISON_OP(NotEqual); -TFLITE_COMPARISON_OP(Greater); -TFLITE_COMPARISON_OP(GreaterEqual); -TFLITE_COMPARISON_OP(Less); -TFLITE_COMPARISON_OP(LessEqual); -#undef TFLITE_COMPARISON_OP - template void Select(const RuntimeShape& input_condition_shape, const D* input_condition_data, const RuntimeShape& input_x_shape,