Internal change
PiperOrigin-RevId: 361284426
Change-Id: I105de3e1953cf94e715130f0e136a2e2dc456b97
This commit is contained in:
  parent 32459f1fd6
  commit 1edc795def
@@ -119,6 +119,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}

template <typename T>
std::function<bool(T, T)> GetComparefunction(bool is_arg_max) {
  if (is_arg_max) {
    return std::greater<T>();
  } else {
    return std::less<T>();
  }
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
@@ -135,7 +144,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
  optimized_ops::ArgMinMax( \
      GetTensorShape(input), GetTensorData<data_type>(input), \
      GetTensorData<axis_type>(axis), GetTensorShape(output), \
      GetTensorData<output_type>(output), is_arg_max)
      GetTensorData<output_type>(output), \
      GetComparefunction<data_type>(is_arg_max))
  if (axis->type == kTfLiteInt32) {
    switch (output->type) {
      case kTfLiteInt32: {
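Aside (not part of the diff): the kernel now hands optimized_ops::ArgMinMax a comparator from GetComparefunction instead of a raw is_arg_max flag, so a single reduction loop can serve both ArgMin and ArgMax. A minimal standalone sketch of that pattern, using a hypothetical ArgBest helper rather than the TFLite implementation:

#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical helper: returns the index of the "best" element under cmp.
// With std::greater<T> this is an arg-max, with std::less<T> an arg-min.
// Ties keep the smaller index because we only update on a strict win.
template <typename T>
std::size_t ArgBest(const std::vector<T>& values,
                    const std::function<bool(T, T)>& cmp) {
  std::size_t best = 0;
  for (std::size_t i = 1; i < values.size(); ++i) {
    if (cmp(values[i], values[best])) best = i;
  }
  return best;
}

int main() {
  std::vector<float> v = {0.1f, 0.9f, 0.4f};
  assert(ArgBest<float>(v, std::greater<float>()) == 1);  // arg-max
  assert(ArgBest<float>(v, std::less<float>()) == 0);     // arg-min
}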
@@ -209,19 +209,6 @@ TEST_P(ArgMinMaxOpTest, GetMaxArgOutput64) {
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}

TEST_P(ArgMinMaxOpTest, GetMaxArgFloatLastAxis) {
  std::vector<float> input{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0};
  for (int i = 1; i < 10; ++i) {
    ArgMaxOpModel model({i}, TensorType_FLOAT32, 0, AxisType(), ConstantAxis(),
                        OutputType());
    model.PopulateTensor<float>(
        model.input(), std::vector<float>(input.begin(), input.begin() + i));
    model.Invoke();

    ValidateOutput(model, {i - 1});
  }
}

TEST_P(ArgMinMaxOpTest, GetMinArgFloat) {
  ArgMinOpModel model({1, 1, 1, 4}, TensorType_FLOAT32, 3, AxisType(),
                      ConstantAxis(), OutputType());
@@ -272,18 +259,5 @@ TEST_P(ArgMinMaxOpTest, GetMinArgOutput64) {
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}

TEST_P(ArgMinMaxOpTest, GetMinArgFloatLastAxis) {
  std::vector<float> input{1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1};
  for (int i = 1; i < 10; ++i) {
    ArgMinOpModel model({i}, TensorType_FLOAT32, 0, AxisType(), ConstantAxis(),
                        OutputType());
    model.PopulateTensor<float>(
        model.input(), std::vector<float>(input.begin(), input.begin() + i));
    model.Invoke();

    ValidateOutput(model, {i - 1});
  }
}

} // namespace
} // namespace tflite
@@ -33,6 +33,8 @@ namespace tflite {
namespace optimized_ops {

// Unoptimized reference ops:
using reference_ops::ArgMax;
using reference_ops::ArgMinMax;
using reference_ops::Broadcast4DSlowGreater;
using reference_ops::Broadcast4DSlowGreaterEqual;
using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
@@ -4964,31 +4966,6 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
               DimsToShape(output_dims), output_data);
}

template <typename T1, typename T2, typename T3>
void ArgMax(const T3* axis, const T1* input_data,
            const tflite::Dims<4>& input_dims, T2* output_data,
            const tflite::Dims<4>& output_dims) {
  // Assumes the input always has 4 dimensions, and therefore,
  // output always has three dimensions.
  auto output_shape = RuntimeShape(
      {output_dims.sizes[2], output_dims.sizes[1], output_dims.sizes[0]});
  // Another way to interpret this is that output_dims.sizes[4] is always 1.
  TFLITE_DCHECK_EQ(output_shape.FlatSize(),
                   DimsToShape(output_dims).FlatSize());
  // Legacy path only supported this.
  TFLITE_DCHECK_EQ(axis[0], 3);
  ArgMinMax(DimsToShape(input_dims), input_data, axis, output_shape,
            output_data, /*is_arg_max=*/true);
}

template <typename T1, typename T2, typename T3>
void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
               T2* output_data, const Dims<4>& output_dims,
               const bool is_arg_max) {
  ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
            output_data, is_arg_max);
}

} // namespace optimized_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
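Aside (illustration only, my reading of the legacy layout, not TFLite code): Dims<4>.sizes stores extents innermost-first while RuntimeShape lists dimensions outermost-first, which is why the removed legacy ArgMax reverses output_dims.sizes when it builds the 3-D output shape. A tiny sketch of that reversal with a stand-in struct:

#include <array>
#include <cassert>

// Stand-in for the legacy layout: sizes[0] is the innermost extent and the
// unused trailing slot stays 1 for a 3-D shape.
struct LegacyDims4 {
  int sizes[4];
};

// Drop the always-1 slot and reverse to outermost-first order, mirroring
// RuntimeShape({sizes[2], sizes[1], sizes[0]}) in the removed code.
std::array<int, 3> ToOutermostFirst3D(const LegacyDims4& d) {
  return {d.sizes[2], d.sizes[1], d.sizes[0]};
}

int main() {
  LegacyDims4 out{{7, 5, 2, 1}};  // innermost-first: 7 varies fastest
  std::array<int, 3> shape = ToOutermostFirst3D(out);
  assert(shape[0] == 2 && shape[1] == 5 && shape[2] == 7);
}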
@@ -64,6 +64,8 @@ namespace tflite {
namespace optimized_ops {

// Unoptimized reference ops:
using reference_ops::ArgMax;
using reference_ops::ArgMinMax;
using reference_ops::Broadcast4DSlowGreater;
using reference_ops::Broadcast4DSlowGreaterEqual;
using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
@@ -7728,238 +7730,6 @@ inline void BroadcastPReluDispatch(
      PReluElementWise, PReluScalarBroadcast);
}

// Returns the index with minimum value within `input_data`.
// If there is a tie, returns the smaller index.
template <typename T>
inline int ArgMinVector(const T* input_data, int size) {
  T min_value = input_data[0];
  int min_index = 0;
  for (int i = 1; i < size; ++i) {
    const T curr_value = input_data[i];
    if (curr_value < min_value) {
      min_value = curr_value;
      min_index = i;
    }
  }
  return min_index;
}

// Returns the index with maximum value within `input_data`.
// If there is a tie, returns the smaller index.
template <typename T>
inline int ArgMaxVector(const T* input_data, int size) {
  T max_value = input_data[0];
  int max_index = 0;
  for (int i = 1; i < size; ++i) {
    const T curr_value = input_data[i];
    if (curr_value > max_value) {
      max_value = curr_value;
      max_index = i;
    }
  }
  return max_index;
}

template <>
inline int ArgMinVector(const float* input_data, int size) {
  int32_t min_index = 0;
  float min_value = input_data[0];
  int32_t i = 1;
#ifdef USE_NEON
  if (size >= 4) {
    float32x4_t min_value_f32x4 = vld1q_f32(input_data);
    const int32_t index_init[4] = {0, 1, 2, 3};
    int32x4_t min_index_s32x4 = vld1q_s32(index_init);
    int32x4_t index_s32x4 = min_index_s32x4;
    int32x4_t inc = vdupq_n_s32(4);
    for (i = 4; i <= size - 4; i += 4) {
      // Increase indices by 4.
      index_s32x4 = vaddq_s32(index_s32x4, inc);
      float32x4_t v = vld1q_f32(&input_data[i]);
      uint32x4_t mask = vcltq_f32(v, min_value_f32x4);
      min_value_f32x4 = vminq_f32(min_value_f32x4, v);
      min_index_s32x4 = vbslq_s32(mask, index_s32x4, min_index_s32x4);
    }
    // Find min element within float32x4_t.
#ifdef __aarch64__
    min_value = vminvq_f32(min_value_f32x4);
#else
    float32x2_t min_value_f32x2 = vpmin_f32(vget_low_f32(min_value_f32x4),
                                            vget_high_f32(min_value_f32x4));
    min_value_f32x2 = vpmin_f32(min_value_f32x2, min_value_f32x2);
    min_value = vget_lane_f32(min_value_f32x2, 0);
#endif // __aarch64__
    // Mask indices of non-min values with max int32_t.
    float32x4_t fill_min_value_f32x4 = vdupq_n_f32(min_value);
    uint32x4_t mask = vceqq_f32(min_value_f32x4, fill_min_value_f32x4);
    int32x4_t all_set = vdupq_n_s32(std::numeric_limits<int>::max());
    min_index_s32x4 = vbslq_s32(mask, min_index_s32x4, all_set);
    // Find min index of min values.
#ifdef __aarch64__
    min_index = vminvq_s32(min_index_s32x4);
#else
    uint32x2_t min_index_s32x2 = vpmin_s32(vget_low_s32(min_index_s32x4),
                                           vget_high_s32(min_index_s32x4));
    min_index_s32x2 = vpmin_s32(min_index_s32x2, min_index_s32x2);
    min_index = vget_lane_s32(min_index_s32x2, 0);
#endif // __aarch64__
  }
#endif // USE_NEON
  // Leftover loop.
  for (; i < size; ++i) {
    const float curr_value = input_data[i];
    if (curr_value < min_value) {
      min_value = curr_value;
      min_index = i;
    }
  }
  return min_index;
}

template <>
inline int ArgMaxVector(const float* input_data, int size) {
  int32_t max_index = 0;
  float max_value = input_data[0];
  int32_t i = 1;
#ifdef USE_NEON
  if (size >= 4) {
    float32x4_t max_value_f32x4 = vld1q_f32(input_data);
    const int32_t index_init[4] = {0, 1, 2, 3};
    int32x4_t max_index_s32x4 = vld1q_s32(index_init);
    int32x4_t index_s32x4 = max_index_s32x4;
    int32x4_t inc = vdupq_n_s32(4);
    for (i = 4; i <= size - 4; i += 4) {
      // Increase indices by 4.
      index_s32x4 = vaddq_s32(index_s32x4, inc);
      float32x4_t v = vld1q_f32(&input_data[i]);
      uint32x4_t mask = vcgtq_f32(v, max_value_f32x4);
      max_value_f32x4 = vmaxq_f32(max_value_f32x4, v);
      max_index_s32x4 = vbslq_s32(mask, index_s32x4, max_index_s32x4);
    }
    // Find max element within float32x4_t.
#ifdef __aarch64__
    max_value = vmaxvq_f32(max_value_f32x4);
#else
    float32x2_t max_value_f32x2 = vpmax_f32(vget_low_f32(max_value_f32x4),
                                            vget_high_f32(max_value_f32x4));
    max_value_f32x2 = vpmax_f32(max_value_f32x2, max_value_f32x2);
    max_value = vget_lane_f32(max_value_f32x2, 0);
#endif // __aarch64__
    // Mask indices of non-max values with max int32_t.
    float32x4_t fill_max_value_f32x4 = vdupq_n_f32(max_value);
    uint32x4_t mask = vceqq_f32(max_value_f32x4, fill_max_value_f32x4);
    int32x4_t all_set = vdupq_n_s32(std::numeric_limits<int>::max());
    max_index_s32x4 = vbslq_s32(mask, max_index_s32x4, all_set);
    // Find min index of max values.
#ifdef __aarch64__
    max_index = vminvq_s32(max_index_s32x4);
#else
    uint32x2_t max_index_s32x2 = vpmin_s32(vget_low_s32(max_index_s32x4),
                                           vget_high_s32(max_index_s32x4));
    max_index_s32x2 = vpmin_s32(max_index_s32x2, max_index_s32x2);
    max_index = vget_lane_s32(max_index_s32x2, 0);
#endif // __aarch64__
  }
#endif // USE_NEON
  // Leftover loop.
  for (; i < size; ++i) {
    const float curr_value = input_data[i];
    if (curr_value > max_value) {
      max_value = curr_value;
      max_index = i;
    }
  }
  return max_index;
}
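Aside (not part of the diff): the NEON path above preserves the "smallest index wins ties" contract by masking lanes that do not hold the extreme value with INT_MAX and then min-reducing the surviving indices. A scalar sketch of the equivalent two-pass formulation, for illustration only:

#include <algorithm>
#include <cassert>
#include <limits>
#include <vector>

// Two-pass arg-max: find the maximum value, then the smallest index that
// attains it. This mirrors what the vectorized code computes with
// vceqq/vbslq (non-max lanes forced to INT_MAX) followed by a min-reduce.
int ArgMaxTwoPass(const std::vector<float>& v) {
  float max_value = v[0];
  for (float x : v) max_value = std::max(max_value, x);
  int best = std::numeric_limits<int>::max();
  for (int i = 0; i < static_cast<int>(v.size()); ++i) {
    if (v[i] == max_value) best = std::min(best, i);
  }
  return best;
}

int main() {
  assert(ArgMaxTwoPass({1.f, 3.f, 3.f, 2.f}) == 1);  // tie -> smaller index
}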

// Specializes ArgMinMax function with axis=dims-1.
// In this case, ArgMinMax reduction is applied on contiguous memory.
template <typename T1, typename T2, bool is_arg_max>
inline void ArgMinMaxLastAxis(const RuntimeShape& input_shape,
                              const T1* input_data,
                              const RuntimeShape& output_shape,
                              T2* output_data) {
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 1);
  TFLITE_DCHECK_EQ(input_shape.Dims(0), output_shape.Dims(0));

  int outer_size = input_shape.Dims(0);
  int axis_size = input_shape.Dims(1);
  for (int outer = 0; outer < outer_size; ++outer) {
    if (is_arg_max) {
      output_data[outer] = static_cast<T2>(
          ArgMaxVector<T1>(input_data + outer * axis_size, axis_size));
    } else {
      output_data[outer] = static_cast<T2>(
          ArgMinVector<T1>(input_data + outer * axis_size, axis_size));
    }
  }
}
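Aside (illustration only): ArgMinMaxLastAxis above views the input as an [outer_size, axis_size] matrix in row-major order and reduces each contiguous row to a single index. A small standalone sketch of that row-wise layout, using a hypothetical RowwiseArgMax helper:

#include <cassert>
#include <vector>

// Row-wise arg-max over a row-major [rows, cols] buffer: each contiguous
// row of `cols` elements produces one output index; ties keep the smaller
// index because the update only happens on a strict '>'.
std::vector<int> RowwiseArgMax(const std::vector<float>& data, int rows,
                               int cols) {
  std::vector<int> out(rows, 0);
  for (int r = 0; r < rows; ++r) {
    const float* row = data.data() + r * cols;
    for (int c = 1; c < cols; ++c) {
      if (row[c] > row[out[r]]) out[r] = c;
    }
  }
  return out;
}

int main() {
  // 2 x 3 matrix: rows {1, 5, 2} and {9, 0, 9}.
  std::vector<int> idx = RowwiseArgMax({1.f, 5.f, 2.f, 9.f, 0.f, 9.f}, 2, 3);
  assert(idx[0] == 1 && idx[1] == 0);
}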

template <typename T1, typename T2, typename T3>
inline void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
                      const T3* input2_data, const RuntimeShape& output_shape,
                      T2* output_data, const bool is_arg_max) {
  ruy::profiler::ScopeLabel label("ArgMinMax");

  TFLITE_DCHECK_GT(input1_shape.DimensionsCount(), 0);
  TFLITE_DCHECK_EQ(input1_shape.DimensionsCount() - 1,
                   output_shape.DimensionsCount());
  int axis = input2_data[0];
  if (axis < 0) {
    axis += input1_shape.DimensionsCount();
  }
  const int axis_size = input1_shape.Dims(axis);

  int outer_size = 1;
  for (int i = 0; i < axis; ++i) {
    TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i));
    outer_size *= input1_shape.Dims(i);
  }

  int inner_size = 1;
  const int dims_count = input1_shape.DimensionsCount();
  for (int i = axis + 1; i < dims_count; ++i) {
    TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i - 1));
    inner_size *= input1_shape.Dims(i);
  }

  // Call specialized function when axis=dims-1. So far, only float32 is
  // optimized so reroute to specialized function only when T1 is float32.
  if (inner_size == 1 && std::is_same<T1, float>::value) {
    if (is_arg_max) {
      ArgMinMaxLastAxis<T1, T2, /*is_arg_max=*/true>(
          {outer_size, axis_size}, input1_data, {outer_size}, output_data);
    } else {
      ArgMinMaxLastAxis<T1, T2, /*is_arg_max=*/false>(
          {outer_size, axis_size}, input1_data, {outer_size}, output_data);
    }
    return;
  }

  reference_ops::ArgMinMax(input1_shape, input1_data, input2_data, output_shape,
                           output_data, is_arg_max);
}
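Aside (not part of the diff): to make the dispatch in ArgMinMax above concrete, an input of shape {2, 3, 4, 5} with axis = 3 gives outer_size = 2*3*4 = 24, axis_size = 5 and inner_size = 1, so the contiguous last-axis fast path is taken; with axis = 1 it gives outer_size = 2, axis_size = 3, inner_size = 4*5 = 20 and the call falls back to reference_ops::ArgMinMax. A self-contained sketch of that bookkeeping with a hypothetical SplitAroundAxis helper:

#include <cassert>
#include <vector>

struct AxisSplit {
  int outer_size;  // product of dims before `axis`
  int axis_size;   // dims[axis]
  int inner_size;  // product of dims after `axis`; 1 means the axis is contiguous
};

AxisSplit SplitAroundAxis(const std::vector<int>& dims, int axis) {
  if (axis < 0) axis += static_cast<int>(dims.size());
  AxisSplit s{1, dims[axis], 1};
  for (int i = 0; i < axis; ++i) s.outer_size *= dims[i];
  for (int i = axis + 1; i < static_cast<int>(dims.size()); ++i) {
    s.inner_size *= dims[i];
  }
  return s;
}

int main() {
  AxisSplit a = SplitAroundAxis({2, 3, 4, 5}, 3);
  assert(a.outer_size == 24 && a.axis_size == 5 && a.inner_size == 1);
  AxisSplit b = SplitAroundAxis({2, 3, 4, 5}, 1);
  assert(b.outer_size == 2 && b.axis_size == 3 && b.inner_size == 20);
}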

template <typename T1, typename T2, typename T3>
void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
            const T3* input2_data, const RuntimeShape& output_shape,
            T2* output_data) {
  ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
            /*is_arg_max=*/true);
}

// Convenience version that allows, for example, generated-code calls to be
// the same as other binary ops.
// For backward compatibility, reference_ops has ArgMax function.
template <typename T1, typename T2, typename T3>
inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
                   const RuntimeShape& input2_shape, const T3* input2_data,
                   const RuntimeShape& output_shape, T2* output_data) {
  // Drop shape of second input: not needed.
  ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
}

} // namespace optimized_ops
} // namespace tflite
@@ -15,23 +15,12 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_

#include <functional>

#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {

namespace reference_ops {

template <typename T>
std::function<bool(T, T)> GetComparefunction(bool is_arg_max) {
  if (is_arg_max) {
    return std::greater<T>();
  } else {
    return std::less<T>();
  }
}

template <typename T1, typename T2, typename T3, typename Cmp>
void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
               const T3* input2_data, const RuntimeShape& output_shape,
@@ -73,15 +62,6 @@ void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
    }
  }
}

template <typename T1, typename T2, typename T3>
void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
               const T3* input2_data, const RuntimeShape& output_shape,
               T2* output_data, const bool is_arg_max) {
  ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
            GetComparefunction<T1>(is_arg_max));
}

} // namespace reference_ops
} // namespace tflite
@@ -29,26 +29,13 @@ from tensorflow.lite.testing.zip_test_utils import register_make_test_function
def make_arg_min_max_tests(options):
  """Make a set of tests to do arg_max."""

  test_parameters = [
      {
          "input_dtype": [tf.float32, tf.int32],
          "input_shape": [[], [1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5],
                          [10]],
          "output_type": [tf.int32, tf.int64],
          "is_arg_max": [True],
          "is_last_axis": [False],
          "dynamic_range_quantize": [False, True],
      },
      {
          "input_dtype": [tf.float32, tf.int32],
          "input_shape": [[1], [2], [3], [4], [5], [6], [7], [8], [9], [10],
                          [2, 10], [3, 4, 50], [2, 3, 5, 100]],
          "output_type": [tf.int32, tf.int64],
          "is_arg_max": [False, True],
          "is_last_axis": [True],
          "dynamic_range_quantize": [False, True],
      },
  ]
  test_parameters = [{
      "input_dtype": [tf.float32, tf.int32],
      "input_shape": [[], [1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]],
      "output_type": [tf.int32, tf.int64],
      "is_arg_max": [True],
      "dynamic_range_quantize": [False, True],
  }]

  def build_graph(parameters):
    """Build the topk op testing graph."""
@@ -56,10 +43,7 @@ def make_arg_min_max_tests(options):
        dtype=parameters["input_dtype"],
        name="input",
        shape=parameters["input_shape"])
    if not parameters["is_last_axis"]:
      axis = random.randint(0, max(len(parameters["input_shape"]) - 1, 0))
    else:
      axis = -1
    axis = random.randint(0, max(len(parameters["input_shape"]) - 1, 0))
    if parameters["is_arg_max"]:
      out = tf.math.argmax(
          input_value, axis, output_type=parameters["output_type"])