Port comparison operations to separate header
PiperOrigin-RevId: 261627312
This commit is contained in:
parent
00581c6347
commit
f91e7353f8
@ -348,6 +348,7 @@ cc_library(
|
||||
srcs = [],
|
||||
hdrs = [
|
||||
"reference/arg_min_max.h",
|
||||
"reference/comparisons.h",
|
||||
"reference/conv.h",
|
||||
"reference/depthwiseconv_float.h",
|
||||
"reference/depthwiseconv_uint8.h",
|
||||
@ -404,6 +405,7 @@ cc_library(
|
||||
srcs = [],
|
||||
hdrs = [
|
||||
"reference/arg_min_max.h",
|
||||
"reference/comparisons.h",
|
||||
"reference/conv.h",
|
||||
"reference/depthwiseconv_float.h",
|
||||
"reference/depthwiseconv_uint8.h",
|
||||
|
276
tensorflow/lite/kernels/internal/reference/comparisons.h
Normal file
276
tensorflow/lite/kernels/internal/reference/comparisons.h
Normal file
@ -0,0 +1,276 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
|
||||
|
||||
#include "profiling/instrumentation.h"
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
namespace reference_ops {
|
||||
|
||||
template <typename T>
|
||||
inline bool EqualFn(T lhs, T rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool NotEqualFn(T lhs, T rhs) {
|
||||
return lhs != rhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool GreaterFn(T lhs, T rhs) {
|
||||
return lhs > rhs;
|
||||
}
|
||||
template <typename T>
|
||||
inline bool GreaterEqualFn(T lhs, T rhs) {
|
||||
return lhs >= rhs;
|
||||
}
|
||||
template <typename T>
|
||||
inline bool LessFn(T lhs, T rhs) {
|
||||
return lhs < rhs;
|
||||
}
|
||||
template <typename T>
|
||||
inline bool LessEqualFn(T lhs, T rhs) {
|
||||
return lhs <= rhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using ComparisonFn = bool (*)(T, T);
|
||||
|
||||
template <typename T, ComparisonFn<T> F>
|
||||
inline void ComparisonImpl(
|
||||
const ComparisonParams& op_params, const RuntimeShape& input1_shape,
|
||||
const T* input1_data, const RuntimeShape& input2_shape,
|
||||
const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
|
||||
const int64_t flatsize =
|
||||
MatchingFlatSize(input1_shape, input2_shape, output_shape);
|
||||
for (int64_t i = 0; i < flatsize; ++i) {
|
||||
output_data[i] = F(input1_data[i], input2_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <ComparisonFn<float> F>
|
||||
inline void Comparison(const ComparisonParams& op_params,
|
||||
const RuntimeShape& input1_shape,
|
||||
const float* input1_data,
|
||||
const RuntimeShape& input2_shape,
|
||||
const float* input2_data,
|
||||
const RuntimeShape& output_shape, bool* output_data) {
|
||||
ComparisonImpl<float, F>(op_params, input1_shape, input1_data, input2_shape,
|
||||
input2_data, output_shape, output_data);
|
||||
}
|
||||
|
||||
template <typename T, ComparisonFn<int32> F>
|
||||
inline void ComparisonWithScaling(
|
||||
const ComparisonParams& op_params, const RuntimeShape& input1_shape,
|
||||
const T* input1_data, const RuntimeShape& input2_shape,
|
||||
const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
|
||||
int left_shift = op_params.left_shift;
|
||||
int32 input1_offset = op_params.input1_offset;
|
||||
int32 input1_multiplier = op_params.input1_multiplier;
|
||||
int input1_shift = op_params.input1_shift;
|
||||
int32 input2_offset = op_params.input2_offset;
|
||||
int32 input2_multiplier = op_params.input2_multiplier;
|
||||
int input2_shift = op_params.input2_shift;
|
||||
|
||||
const int64_t flatsize =
|
||||
MatchingFlatSize(input1_shape, input2_shape, output_shape);
|
||||
for (int64_t i = 0; i < flatsize; ++i) {
|
||||
const int32 input1_val = input1_offset + input1_data[i];
|
||||
const int32 input2_val = input2_offset + input2_data[i];
|
||||
const int32 shifted_input1_val = input1_val * (1 << left_shift);
|
||||
const int32 shifted_input2_val = input2_val * (1 << left_shift);
|
||||
const int32 scaled_input1_val =
|
||||
MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
||||
shifted_input1_val, input1_multiplier, input1_shift);
|
||||
const int32 scaled_input2_val =
|
||||
MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
||||
shifted_input2_val, input2_multiplier, input2_shift);
|
||||
output_data[i] = F(scaled_input1_val, scaled_input2_val);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, ComparisonFn<T> F>
|
||||
inline void BroadcastComparison4DSlowImpl(
|
||||
const ComparisonParams& op_params,
|
||||
const RuntimeShape& unextended_input1_shape, const T* input1_data,
|
||||
const RuntimeShape& unextended_input2_shape, const T* input2_data,
|
||||
const RuntimeShape& unextended_output_shape, bool* output_data) {
|
||||
gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlow");
|
||||
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
|
||||
const RuntimeShape output_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_output_shape);
|
||||
|
||||
NdArrayDesc<4> desc1;
|
||||
NdArrayDesc<4> desc2;
|
||||
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
|
||||
unextended_input2_shape, &desc1, &desc2);
|
||||
|
||||
for (int b = 0; b < output_shape.Dims(0); ++b) {
|
||||
for (int y = 0; y < output_shape.Dims(1); ++y) {
|
||||
for (int x = 0; x < output_shape.Dims(2); ++x) {
|
||||
for (int c = 0; c < output_shape.Dims(3); ++c) {
|
||||
output_data[Offset(output_shape, b, y, x, c)] =
|
||||
F(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
|
||||
input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template <ComparisonFn<float> F>
|
||||
inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
|
||||
const RuntimeShape& input1_shape,
|
||||
const float* input1_data,
|
||||
const RuntimeShape& input2_shape,
|
||||
const float* input2_data,
|
||||
const RuntimeShape& output_shape,
|
||||
bool* output_data) {
|
||||
BroadcastComparison4DSlowImpl<float, F>(op_params, input1_shape, input1_data,
|
||||
input2_shape, input2_data,
|
||||
output_shape, output_data);
|
||||
}
|
||||
|
||||
template <typename T, ComparisonFn<int32> F>
|
||||
inline void BroadcastComparison4DSlowWithScaling(
|
||||
const ComparisonParams& op_params,
|
||||
const RuntimeShape& unextended_input1_shape, const T* input1_data,
|
||||
const RuntimeShape& unextended_input2_shape, const T* input2_data,
|
||||
const RuntimeShape& unextended_output_shape, bool* output_data) {
|
||||
gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlowWithScaling");
|
||||
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
|
||||
const RuntimeShape output_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_output_shape);
|
||||
|
||||
NdArrayDesc<4> desc1;
|
||||
NdArrayDesc<4> desc2;
|
||||
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
|
||||
unextended_input2_shape, &desc1, &desc2);
|
||||
|
||||
int left_shift = op_params.left_shift;
|
||||
int32 input1_offset = op_params.input1_offset;
|
||||
int32 input1_multiplier = op_params.input1_multiplier;
|
||||
int input1_shift = op_params.input1_shift;
|
||||
int32 input2_offset = op_params.input2_offset;
|
||||
int32 input2_multiplier = op_params.input2_multiplier;
|
||||
int input2_shift = op_params.input2_shift;
|
||||
|
||||
for (int b = 0; b < output_shape.Dims(0); ++b) {
|
||||
for (int y = 0; y < output_shape.Dims(1); ++y) {
|
||||
for (int x = 0; x < output_shape.Dims(2); ++x) {
|
||||
for (int c = 0; c < output_shape.Dims(3); ++c) {
|
||||
const int32 input1_val =
|
||||
input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)];
|
||||
const int32 input2_val =
|
||||
input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)];
|
||||
const int32 shifted_input1_val = input1_val * (1 << left_shift);
|
||||
const int32 shifted_input2_val = input2_val * (1 << left_shift);
|
||||
const int32 scaled_input1_val =
|
||||
MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
||||
shifted_input1_val, input1_multiplier, input1_shift);
|
||||
const int32 scaled_input2_val =
|
||||
MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
||||
shifted_input2_val, input2_multiplier, input2_shift);
|
||||
output_data[Offset(output_shape, b, y, x, c)] =
|
||||
F(scaled_input1_val, scaled_input2_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define TFLITE_COMPARISON_OP(name) \
|
||||
inline void name(const ComparisonParams& op_params, \
|
||||
const RuntimeShape& input1_shape, const float* input1_data, \
|
||||
const RuntimeShape& input2_shape, const float* input2_data, \
|
||||
const RuntimeShape& output_shape, bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label(#name); \
|
||||
Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape, \
|
||||
input2_data, output_shape, output_data); \
|
||||
} \
|
||||
template <typename T> \
|
||||
inline void name##NoScaling( \
|
||||
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
|
||||
const T* input1_data, const RuntimeShape& input2_shape, \
|
||||
const T* input2_data, const RuntimeShape& output_shape, \
|
||||
bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label(#name "NoScaling"); \
|
||||
ComparisonImpl<T, name##Fn>(op_params, input1_shape, input1_data, \
|
||||
input2_shape, input2_data, output_shape, \
|
||||
output_data); \
|
||||
} \
|
||||
template <typename T> \
|
||||
inline void name##WithScaling( \
|
||||
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
|
||||
const T* input1_data, const RuntimeShape& input2_shape, \
|
||||
const T* input2_data, const RuntimeShape& output_shape, \
|
||||
bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label(#name "WithScaling/8bit"); \
|
||||
ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data, \
|
||||
input2_shape, input2_data, \
|
||||
output_shape, output_data); \
|
||||
} \
|
||||
template <typename T> \
|
||||
inline void Broadcast4DSlow##name##NoScaling( \
|
||||
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
|
||||
const T* input1_data, const RuntimeShape& input2_shape, \
|
||||
const T* input2_data, const RuntimeShape& output_shape, \
|
||||
bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "NoScaling"); \
|
||||
BroadcastComparison4DSlowImpl<T, name##Fn>( \
|
||||
op_params, input1_shape, input1_data, input2_shape, input2_data, \
|
||||
output_shape, output_data); \
|
||||
} \
|
||||
inline void Broadcast4DSlow##name( \
|
||||
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
|
||||
const float* input1_data, const RuntimeShape& input2_shape, \
|
||||
const float* input2_data, const RuntimeShape& output_shape, \
|
||||
bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name); \
|
||||
BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data, \
|
||||
input2_shape, input2_data, \
|
||||
output_shape, output_data); \
|
||||
} \
|
||||
template <typename T> \
|
||||
inline void Broadcast4DSlow##name##WithScaling( \
|
||||
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
|
||||
const T* input1_data, const RuntimeShape& input2_shape, \
|
||||
const T* input2_data, const RuntimeShape& output_shape, \
|
||||
bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "/8bit"); \
|
||||
BroadcastComparison4DSlowWithScaling<T, name##Fn>( \
|
||||
op_params, input1_shape, input1_data, input2_shape, input2_data, \
|
||||
output_shape, output_data); \
|
||||
}
|
||||
TFLITE_COMPARISON_OP(Equal);
|
||||
TFLITE_COMPARISON_OP(NotEqual);
|
||||
TFLITE_COMPARISON_OP(Greater);
|
||||
TFLITE_COMPARISON_OP(GreaterEqual);
|
||||
TFLITE_COMPARISON_OP(Less);
|
||||
TFLITE_COMPARISON_OP(LessEqual);
|
||||
#undef TFLITE_COMPARISON_OP
|
||||
|
||||
} // namespace reference_ops
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/conv.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/floor.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
|
||||
@ -2471,7 +2472,6 @@ inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
|
||||
Tanh(input_shape, input_data, output_shape, output_data);
|
||||
}
|
||||
|
||||
|
||||
inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
|
||||
const int16* input_data, const RuntimeShape& output_shape,
|
||||
int16* output_data) {
|
||||
@ -3768,253 +3768,6 @@ inline void TransposeConv(const ConvParams& params,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool EqualFn(T lhs, T rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool NotEqualFn(T lhs, T rhs) {
|
||||
return lhs != rhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool GreaterFn(T lhs, T rhs) {
|
||||
return lhs > rhs;
|
||||
}
|
||||
template <typename T>
|
||||
inline bool GreaterEqualFn(T lhs, T rhs) {
|
||||
return lhs >= rhs;
|
||||
}
|
||||
template <typename T>
|
||||
inline bool LessFn(T lhs, T rhs) {
|
||||
return lhs < rhs;
|
||||
}
|
||||
template <typename T>
|
||||
inline bool LessEqualFn(T lhs, T rhs) {
|
||||
return lhs <= rhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using ComparisonFn = bool (*)(T, T);
|
||||
|
||||
template <typename T, ComparisonFn<T> F>
|
||||
inline void ComparisonImpl(
|
||||
const ComparisonParams& op_params, const RuntimeShape& input1_shape,
|
||||
const T* input1_data, const RuntimeShape& input2_shape,
|
||||
const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
|
||||
const int64_t flatsize =
|
||||
MatchingFlatSize(input1_shape, input2_shape, output_shape);
|
||||
for (int64_t i = 0; i < flatsize; ++i) {
|
||||
output_data[i] = F(input1_data[i], input2_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <ComparisonFn<float> F>
|
||||
inline void Comparison(const ComparisonParams& op_params,
|
||||
const RuntimeShape& input1_shape,
|
||||
const float* input1_data,
|
||||
const RuntimeShape& input2_shape,
|
||||
const float* input2_data,
|
||||
const RuntimeShape& output_shape, bool* output_data) {
|
||||
ComparisonImpl<float, F>(op_params, input1_shape, input1_data, input2_shape,
|
||||
input2_data, output_shape, output_data);
|
||||
}
|
||||
|
||||
template <typename T, ComparisonFn<int32> F>
|
||||
inline void ComparisonWithScaling(
|
||||
const ComparisonParams& op_params, const RuntimeShape& input1_shape,
|
||||
const T* input1_data, const RuntimeShape& input2_shape,
|
||||
const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
|
||||
int left_shift = op_params.left_shift;
|
||||
int32 input1_offset = op_params.input1_offset;
|
||||
int32 input1_multiplier = op_params.input1_multiplier;
|
||||
int input1_shift = op_params.input1_shift;
|
||||
int32 input2_offset = op_params.input2_offset;
|
||||
int32 input2_multiplier = op_params.input2_multiplier;
|
||||
int input2_shift = op_params.input2_shift;
|
||||
|
||||
const int64_t flatsize =
|
||||
MatchingFlatSize(input1_shape, input2_shape, output_shape);
|
||||
for (int64_t i = 0; i < flatsize; ++i) {
|
||||
const int32 input1_val = input1_offset + input1_data[i];
|
||||
const int32 input2_val = input2_offset + input2_data[i];
|
||||
const int32 shifted_input1_val = input1_val * (1 << left_shift);
|
||||
const int32 shifted_input2_val = input2_val * (1 << left_shift);
|
||||
const int32 scaled_input1_val =
|
||||
MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
||||
shifted_input1_val, input1_multiplier, input1_shift);
|
||||
const int32 scaled_input2_val =
|
||||
MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
||||
shifted_input2_val, input2_multiplier, input2_shift);
|
||||
output_data[i] = F(scaled_input1_val, scaled_input2_val);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, ComparisonFn<T> F>
|
||||
inline void BroadcastComparison4DSlowImpl(
|
||||
const ComparisonParams& op_params,
|
||||
const RuntimeShape& unextended_input1_shape, const T* input1_data,
|
||||
const RuntimeShape& unextended_input2_shape, const T* input2_data,
|
||||
const RuntimeShape& unextended_output_shape, bool* output_data) {
|
||||
gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlow");
|
||||
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
|
||||
const RuntimeShape output_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_output_shape);
|
||||
|
||||
NdArrayDesc<4> desc1;
|
||||
NdArrayDesc<4> desc2;
|
||||
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
|
||||
unextended_input2_shape, &desc1, &desc2);
|
||||
|
||||
for (int b = 0; b < output_shape.Dims(0); ++b) {
|
||||
for (int y = 0; y < output_shape.Dims(1); ++y) {
|
||||
for (int x = 0; x < output_shape.Dims(2); ++x) {
|
||||
for (int c = 0; c < output_shape.Dims(3); ++c) {
|
||||
output_data[Offset(output_shape, b, y, x, c)] =
|
||||
F(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
|
||||
input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template <ComparisonFn<float> F>
|
||||
inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
|
||||
const RuntimeShape& input1_shape,
|
||||
const float* input1_data,
|
||||
const RuntimeShape& input2_shape,
|
||||
const float* input2_data,
|
||||
const RuntimeShape& output_shape,
|
||||
bool* output_data) {
|
||||
BroadcastComparison4DSlowImpl<float, F>(op_params, input1_shape, input1_data,
|
||||
input2_shape, input2_data,
|
||||
output_shape, output_data);
|
||||
}
|
||||
|
||||
template <typename T, ComparisonFn<int32> F>
|
||||
inline void BroadcastComparison4DSlowWithScaling(
|
||||
const ComparisonParams& op_params,
|
||||
const RuntimeShape& unextended_input1_shape, const T* input1_data,
|
||||
const RuntimeShape& unextended_input2_shape, const T* input2_data,
|
||||
const RuntimeShape& unextended_output_shape, bool* output_data) {
|
||||
gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlowWithScaling");
|
||||
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
|
||||
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
|
||||
const RuntimeShape output_shape =
|
||||
RuntimeShape::ExtendedShape(4, unextended_output_shape);
|
||||
|
||||
NdArrayDesc<4> desc1;
|
||||
NdArrayDesc<4> desc2;
|
||||
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
|
||||
unextended_input2_shape, &desc1, &desc2);
|
||||
|
||||
int left_shift = op_params.left_shift;
|
||||
int32 input1_offset = op_params.input1_offset;
|
||||
int32 input1_multiplier = op_params.input1_multiplier;
|
||||
int input1_shift = op_params.input1_shift;
|
||||
int32 input2_offset = op_params.input2_offset;
|
||||
int32 input2_multiplier = op_params.input2_multiplier;
|
||||
int input2_shift = op_params.input2_shift;
|
||||
|
||||
for (int b = 0; b < output_shape.Dims(0); ++b) {
|
||||
for (int y = 0; y < output_shape.Dims(1); ++y) {
|
||||
for (int x = 0; x < output_shape.Dims(2); ++x) {
|
||||
for (int c = 0; c < output_shape.Dims(3); ++c) {
|
||||
const int32 input1_val =
|
||||
input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)];
|
||||
const int32 input2_val =
|
||||
input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)];
|
||||
const int32 shifted_input1_val = input1_val * (1 << left_shift);
|
||||
const int32 shifted_input2_val = input2_val * (1 << left_shift);
|
||||
const int32 scaled_input1_val =
|
||||
MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
||||
shifted_input1_val, input1_multiplier, input1_shift);
|
||||
const int32 scaled_input2_val =
|
||||
MultiplyByQuantizedMultiplierSmallerThanOneExp(
|
||||
shifted_input2_val, input2_multiplier, input2_shift);
|
||||
output_data[Offset(output_shape, b, y, x, c)] =
|
||||
F(scaled_input1_val, scaled_input2_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define TFLITE_COMPARISON_OP(name) \
|
||||
inline void name(const ComparisonParams& op_params, \
|
||||
const RuntimeShape& input1_shape, const float* input1_data, \
|
||||
const RuntimeShape& input2_shape, const float* input2_data, \
|
||||
const RuntimeShape& output_shape, bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label(#name); \
|
||||
Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape, \
|
||||
input2_data, output_shape, output_data); \
|
||||
} \
|
||||
template <typename T> \
|
||||
inline void name##NoScaling( \
|
||||
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
|
||||
const T* input1_data, const RuntimeShape& input2_shape, \
|
||||
const T* input2_data, const RuntimeShape& output_shape, \
|
||||
bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label(#name "NoScaling"); \
|
||||
ComparisonImpl<T, name##Fn>(op_params, input1_shape, input1_data, \
|
||||
input2_shape, input2_data, output_shape, \
|
||||
output_data); \
|
||||
} \
|
||||
template <typename T> \
|
||||
inline void name##WithScaling( \
|
||||
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
|
||||
const T* input1_data, const RuntimeShape& input2_shape, \
|
||||
const T* input2_data, const RuntimeShape& output_shape, \
|
||||
bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label(#name "WithScaling/8bit"); \
|
||||
ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data, \
|
||||
input2_shape, input2_data, \
|
||||
output_shape, output_data); \
|
||||
} \
|
||||
template <typename T> \
|
||||
inline void Broadcast4DSlow##name##NoScaling( \
|
||||
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
|
||||
const T* input1_data, const RuntimeShape& input2_shape, \
|
||||
const T* input2_data, const RuntimeShape& output_shape, \
|
||||
bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "NoScaling"); \
|
||||
BroadcastComparison4DSlowImpl<T, name##Fn>( \
|
||||
op_params, input1_shape, input1_data, input2_shape, input2_data, \
|
||||
output_shape, output_data); \
|
||||
} \
|
||||
inline void Broadcast4DSlow##name( \
|
||||
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
|
||||
const float* input1_data, const RuntimeShape& input2_shape, \
|
||||
const float* input2_data, const RuntimeShape& output_shape, \
|
||||
bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name); \
|
||||
BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data, \
|
||||
input2_shape, input2_data, \
|
||||
output_shape, output_data); \
|
||||
} \
|
||||
template <typename T> \
|
||||
inline void Broadcast4DSlow##name##WithScaling( \
|
||||
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
|
||||
const T* input1_data, const RuntimeShape& input2_shape, \
|
||||
const T* input2_data, const RuntimeShape& output_shape, \
|
||||
bool* output_data) { \
|
||||
gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "/8bit"); \
|
||||
BroadcastComparison4DSlowWithScaling<T, name##Fn>( \
|
||||
op_params, input1_shape, input1_data, input2_shape, input2_data, \
|
||||
output_shape, output_data); \
|
||||
}
|
||||
TFLITE_COMPARISON_OP(Equal);
|
||||
TFLITE_COMPARISON_OP(NotEqual);
|
||||
TFLITE_COMPARISON_OP(Greater);
|
||||
TFLITE_COMPARISON_OP(GreaterEqual);
|
||||
TFLITE_COMPARISON_OP(Less);
|
||||
TFLITE_COMPARISON_OP(LessEqual);
|
||||
#undef TFLITE_COMPARISON_OP
|
||||
|
||||
template <typename D, typename T>
|
||||
void Select(const RuntimeShape& input_condition_shape,
|
||||
const D* input_condition_data, const RuntimeShape& input_x_shape,
|
||||
|
Loading…
x
Reference in New Issue
Block a user