Internal change

PiperOrigin-RevId: 361284426
Change-Id: I105de3e1953cf94e715130f0e136a2e2dc456b97
A. Unique TensorFlower 2021-03-05 22:15:29 -08:00 committed by TensorFlower Gardener
parent 32459f1fd6
commit 1edc795def
6 changed files with 23 additions and 328 deletions

tensorflow/lite/kernels/arg_min_max.cc

@@ -119,6 +119,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
template <typename T>
std::function<bool(T, T)> GetComparefunction(bool is_arg_max) {
if (is_arg_max) {
return std::greater<T>();
} else {
return std::less<T>();
}
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
@@ -135,7 +144,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
optimized_ops::ArgMinMax( \
GetTensorShape(input), GetTensorData<data_type>(input), \
GetTensorData<axis_type>(axis), GetTensorShape(output), \
GetTensorData<output_type>(output), is_arg_max)
GetTensorData<output_type>(output), \
GetComparefunction<data_type>(is_arg_max))
if (axis->type == kTfLiteInt32) {
switch (output->type) {
case kTfLiteInt32: {

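For context: the new GetComparefunction helper above returns std::greater<T> for arg-max and std::less<T> for arg-min, and the kernel now threads that comparator into optimized_ops::ArgMinMax. A minimal standalone sketch of the pattern (plain C++; BestIndex is a hypothetical helper, not part of TFLite):

#include <functional>
#include <iostream>
#include <vector>

template <typename T>
std::function<bool(T, T)> GetComparefunction(bool is_arg_max) {
  if (is_arg_max) return std::greater<T>();
  return std::less<T>();
}

// Hypothetical helper: scans a flat array with the selected comparator.
template <typename T>
int BestIndex(const std::vector<T>& data, bool is_arg_max) {
  auto cmp = GetComparefunction<T>(is_arg_max);
  int best = 0;
  for (int i = 1; i < static_cast<int>(data.size()); ++i) {
    // The strict comparison keeps the smaller index on ties.
    if (cmp(data[i], data[best])) best = i;
  }
  return best;
}

int main() {
  std::vector<float> v{0.1f, 0.9f, 0.9f, 0.2f};
  std::cout << BestIndex(v, /*is_arg_max=*/true) << "\n";   // prints 1
  std::cout << BestIndex(v, /*is_arg_max=*/false) << "\n";  // prints 0
}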
tensorflow/lite/kernels/arg_min_max_test.cc

@@ -209,19 +209,6 @@ TEST_P(ArgMinMaxOpTest, GetMaxArgOutput64) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}
TEST_P(ArgMinMaxOpTest, GetMaxArgFloatLastAxis) {
std::vector<float> input{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0};
for (int i = 1; i < 10; ++i) {
ArgMaxOpModel model({i}, TensorType_FLOAT32, 0, AxisType(), ConstantAxis(),
OutputType());
model.PopulateTensor<float>(
model.input(), std::vector<float>(input.begin(), input.begin() + i));
model.Invoke();
ValidateOutput(model, {i - 1});
}
}
TEST_P(ArgMinMaxOpTest, GetMinArgFloat) {
ArgMinOpModel model({1, 1, 1, 4}, TensorType_FLOAT32, 3, AxisType(),
ConstantAxis(), OutputType());
@@ -272,18 +259,5 @@ TEST_P(ArgMinMaxOpTest, GetMinArgOutput64) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2}));
}
TEST_P(ArgMinMaxOpTest, GetMinArgFloatLastAxis) {
std::vector<float> input{1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1};
for (int i = 1; i < 10; ++i) {
ArgMinOpModel model({i}, TensorType_FLOAT32, 0, AxisType(), ConstantAxis(),
OutputType());
model.PopulateTensor<float>(
model.input(), std::vector<float>(input.begin(), input.begin() + i));
model.Invoke();
ValidateOutput(model, {i - 1});
}
}
} // namespace
} // namespace tflite

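The two deleted LastAxis tests walked prefixes of length 1 through 9 of a monotone input, so the winning index was always the last element, i - 1. A plain C++ analogue of that expectation, using std::max_element instead of the TFLite test harness (an assumption about intent, not the harness API):

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  // Ascending input, mirroring GetMaxArgFloatLastAxis; the descending
  // variant in GetMinArgFloatLastAxis works the same way with min_element.
  std::vector<float> input{0.1f, 0.2f, 0.3f, 0.4f, 0.5f,
                           0.6f, 0.7f, 0.8f, 0.9f, 1.0f};
  for (int i = 1; i < 10; ++i) {
    auto max_it = std::max_element(input.begin(), input.begin() + i);
    assert(static_cast<int>(max_it - input.begin()) == i - 1);
  }
  return 0;
}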
tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h

@@ -33,6 +33,8 @@ namespace tflite {
namespace optimized_ops {
// Unoptimized reference ops:
using reference_ops::ArgMax;
using reference_ops::ArgMinMax;
using reference_ops::Broadcast4DSlowGreater;
using reference_ops::Broadcast4DSlowGreaterEqual;
using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
@@ -4964,31 +4966,6 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims), output_data);
}
template <typename T1, typename T2, typename T3>
void ArgMax(const T3* axis, const T1* input_data,
const tflite::Dims<4>& input_dims, T2* output_data,
const tflite::Dims<4>& output_dims) {
// Assumes the input always has 4 dimensions, and therefore,
// output always has three dimensions.
auto output_shape = RuntimeShape(
{output_dims.sizes[2], output_dims.sizes[1], output_dims.sizes[0]});
// Another way to interpret this is that output_dims.sizes[3] is always 1.
TFLITE_DCHECK_EQ(output_shape.FlatSize(),
DimsToShape(output_dims).FlatSize());
// Legacy path only supported this.
TFLITE_DCHECK_EQ(axis[0], 3);
ArgMinMax(DimsToShape(input_dims), input_data, axis, output_shape,
output_data, /*is_arg_max=*/true);
}
template <typename T1, typename T2, typename T3>
void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
T2* output_data, const Dims<4>& output_dims,
const bool is_arg_max) {
ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
output_data, is_arg_max);
}
} // namespace optimized_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_

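The deleted legacy wrapper leans on the Dims<4> convention: sizes[0] is the innermost dimension, so building a RuntimeShape means reading the sizes array back to front. A small standalone sketch of that reversal (Dims4 and ToShape are hypothetical stand-ins, not the TFLite types):

#include <array>
#include <vector>

// Mimics tflite::Dims<4>: sizes[0] innermost, sizes[3] outermost.
struct Dims4 {
  std::array<int, 4> sizes;
};

// Stand-in for DimsToShape: returns dimensions outermost-first.
std::vector<int> ToShape(const Dims4& dims) {
  return {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]};
}

int main() {
  Dims4 d{{5, 4, 3, 2}};    // stored innermost-first
  auto shape = ToShape(d);  // yields {2, 3, 4, 5}, outermost-first
  return shape[0] == 2 ? 0 : 1;
}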
tensorflow/lite/kernels/internal/optimized/optimized_ops.h

@@ -64,6 +64,8 @@ namespace tflite {
namespace optimized_ops {
// Unoptimized reference ops:
using reference_ops::ArgMax;
using reference_ops::ArgMinMax;
using reference_ops::Broadcast4DSlowGreater;
using reference_ops::Broadcast4DSlowGreaterEqual;
using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
@@ -7728,238 +7730,6 @@ inline void BroadcastPReluDispatch(
PReluElementWise, PReluScalarBroadcast);
}
// Returns the index with minimum value within `input_data`.
// If there is a tie, returns the smaller index.
template <typename T>
inline int ArgMinVector(const T* input_data, int size) {
T min_value = input_data[0];
int min_index = 0;
for (int i = 1; i < size; ++i) {
const T curr_value = input_data[i];
if (curr_value < min_value) {
min_value = curr_value;
min_index = i;
}
}
return min_index;
}
// Returns the index with maximum value within `input_data`.
// If there is a tie, returns the smaller index.
template <typename T>
inline int ArgMaxVector(const T* input_data, int size) {
T max_value = input_data[0];
int max_index = 0;
for (int i = 1; i < size; ++i) {
const T curr_value = input_data[i];
if (curr_value > max_value) {
max_value = curr_value;
max_index = i;
}
}
return max_index;
}
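A quick usage check of the two scalar templates above, confirming the documented tie-break toward the smaller index (sketch; assumes ArgMinVector and ArgMaxVector are in scope):

#include <cassert>

int main() {
  const float data[] = {3.0f, 7.0f, 7.0f, 1.0f};
  // Indices 1 and 2 tie at 7.0f; the strict '>' keeps the earlier index 1.
  assert(ArgMaxVector(data, 4) == 1);
  assert(ArgMinVector(data, 4) == 3);
  return 0;
}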
template <>
inline int ArgMinVector(const float* input_data, int size) {
int32_t min_index = 0;
float min_value = input_data[0];
int32_t i = 1;
#ifdef USE_NEON
if (size >= 4) {
float32x4_t min_value_f32x4 = vld1q_f32(input_data);
const int32_t index_init[4] = {0, 1, 2, 3};
int32x4_t min_index_s32x4 = vld1q_s32(index_init);
int32x4_t index_s32x4 = min_index_s32x4;
int32x4_t inc = vdupq_n_s32(4);
for (i = 4; i <= size - 4; i += 4) {
// Increase indices by 4.
index_s32x4 = vaddq_s32(index_s32x4, inc);
float32x4_t v = vld1q_f32(&input_data[i]);
uint32x4_t mask = vcltq_f32(v, min_value_f32x4);
min_value_f32x4 = vminq_f32(min_value_f32x4, v);
min_index_s32x4 = vbslq_s32(mask, index_s32x4, min_index_s32x4);
}
// Find min element within float32x4_t.
#ifdef __aarch64__
min_value = vminvq_f32(min_value_f32x4);
#else
float32x2_t min_value_f32x2 = vpmin_f32(vget_low_f32(min_value_f32x4),
vget_high_f32(min_value_f32x4));
min_value_f32x2 = vpmin_f32(min_value_f32x2, min_value_f32x2);
min_value = vget_lane_f32(min_value_f32x2, 0);
#endif // __aarch64__
// Mask indices of non-min values with max int32_t.
float32x4_t fill_min_value_f32x4 = vdupq_n_f32(min_value);
uint32x4_t mask = vceqq_f32(min_value_f32x4, fill_min_value_f32x4);
int32x4_t all_set = vdupq_n_s32(std::numeric_limits<int>::max());
min_index_s32x4 = vbslq_s32(mask, min_index_s32x4, all_set);
// Find min index of min values.
#ifdef __aarch64__
min_index = vminvq_s32(min_index_s32x4);
#else
int32x2_t min_index_s32x2 = vpmin_s32(vget_low_s32(min_index_s32x4),
vget_high_s32(min_index_s32x4));
min_index_s32x2 = vpmin_s32(min_index_s32x2, min_index_s32x2);
min_index = vget_lane_s32(min_index_s32x2, 0);
#endif // __aarch64__
}
#endif // USE_NEON
// Leftover loop.
for (; i < size; ++i) {
const float curr_value = input_data[i];
if (curr_value < min_value) {
min_value = curr_value;
min_index = i;
}
}
return min_index;
}
template <>
inline int ArgMaxVector(const float* input_data, int size) {
int32_t max_index = 0;
float max_value = input_data[0];
int32_t i = 1;
#ifdef USE_NEON
if (size >= 4) {
float32x4_t max_value_f32x4 = vld1q_f32(input_data);
const int32_t index_init[4] = {0, 1, 2, 3};
int32x4_t max_index_s32x4 = vld1q_s32(index_init);
int32x4_t index_s32x4 = max_index_s32x4;
int32x4_t inc = vdupq_n_s32(4);
for (i = 4; i <= size - 4; i += 4) {
// Increase indices by 4.
index_s32x4 = vaddq_s32(index_s32x4, inc);
float32x4_t v = vld1q_f32(&input_data[i]);
uint32x4_t mask = vcgtq_f32(v, max_value_f32x4);
max_value_f32x4 = vmaxq_f32(max_value_f32x4, v);
max_index_s32x4 = vbslq_s32(mask, index_s32x4, max_index_s32x4);
}
// Find max element within float32x4_t.
#ifdef __aarch64__
max_value = vmaxvq_f32(max_value_f32x4);
#else
float32x2_t max_value_f32x2 = vpmax_f32(vget_low_f32(max_value_f32x4),
vget_high_f32(max_value_f32x4));
max_value_f32x2 = vpmax_f32(max_value_f32x2, max_value_f32x2);
max_value = vget_lane_f32(max_value_f32x2, 0);
#endif // __aarch64__
// Mask indices of non-max values with max int32_t.
float32x4_t fill_max_value_f32x4 = vdupq_n_f32(max_value);
uint32x4_t mask = vceqq_f32(max_value_f32x4, fill_max_value_f32x4);
int32x4_t all_set = vdupq_n_s32(std::numeric_limits<int>::max());
max_index_s32x4 = vbslq_s32(mask, max_index_s32x4, all_set);
// Find min index of max values.
#ifdef __aarch64__
max_index = vminvq_s32(max_index_s32x4);
#else
int32x2_t max_index_s32x2 = vpmin_s32(vget_low_s32(max_index_s32x4),
vget_high_s32(max_index_s32x4));
max_index_s32x2 = vpmin_s32(max_index_s32x2, max_index_s32x2);
max_index = vget_lane_s32(max_index_s32x2, 0);
#endif // __aarch64__
}
#endif // USE_NEON
// Leftover loop.
for (; i < size; ++i) {
const float curr_value = input_data[i];
if (curr_value > max_value) {
max_value = curr_value;
max_index = i;
}
}
return max_index;
}
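Both NEON specializations above work in two phases: a vector loop keeps four running extrema plus the index where each lane last improved, then a horizontal reduction masks the lanes not holding the winning value with INT32_MAX and takes the minimum surviving index, so ties still resolve to the smallest index. A scalar emulation of that reduction step (illustrative only; lanes modeled as plain arrays):

#include <climits>

// Emulates the vceqq/vbslq/vminvq sequence for the arg-max case: given each
// lane's best value and the index where it occurred, recover the overall
// arg-max with the smallest index.
int ReduceLanes(const float best_value[4], const int best_index[4]) {
  float overall = best_value[0];
  for (int lane = 1; lane < 4; ++lane) {
    if (best_value[lane] > overall) overall = best_value[lane];
  }
  int result = INT_MAX;
  for (int lane = 0; lane < 4; ++lane) {
    // Lanes that do not hold the overall best are masked to INT_MAX.
    int masked = (best_value[lane] == overall) ? best_index[lane] : INT_MAX;
    if (masked < result) result = masked;
  }
  return result;
}

int main() {
  const float best_value[4] = {7.0f, 9.0f, 9.0f, 3.0f};
  const int best_index[4] = {0, 5, 2, 7};
  return ReduceLanes(best_value, best_index) == 2 ? 0 : 1;  // 2 wins the tie
}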
// Specializes ArgMinMax function with axis=dims-1.
// In this case, ArgMinMax reduction is applied on contiguous memory.
template <typename T1, typename T2, bool is_arg_max>
inline void ArgMinMaxLastAxis(const RuntimeShape& input_shape,
const T1* input_data,
const RuntimeShape& output_shape,
T2* output_data) {
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 1);
TFLITE_DCHECK_EQ(input_shape.Dims(0), output_shape.Dims(0));
int outer_size = input_shape.Dims(0);
int axis_size = input_shape.Dims(1);
for (int outer = 0; outer < outer_size; ++outer) {
if (is_arg_max) {
output_data[outer] = static_cast<T2>(
ArgMaxVector<T1>(input_data + outer * axis_size, axis_size));
} else {
output_data[outer] = static_cast<T2>(
ArgMinVector<T1>(input_data + outer * axis_size, axis_size));
}
}
}
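For example, a 2x3 input viewed as two contiguous rows yields one index per row. A sketch of a call site for the specialization above (hypothetical, not taken from the diff; assumes the brace-initialized RuntimeShape used elsewhere in this file):

const float input[] = {0.1f, 0.5f, 0.2f,   // row 0: max at index 1
                       0.9f, 0.3f, 0.4f};  // row 1: max at index 0
int32_t output[2];
ArgMinMaxLastAxis<float, int32_t, /*is_arg_max=*/true>(
    RuntimeShape({2, 3}), input, RuntimeShape({2}), output);
// output now holds {1, 0}.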
template <typename T1, typename T2, typename T3>
inline void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
const T3* input2_data, const RuntimeShape& output_shape,
T2* output_data, const bool is_arg_max) {
ruy::profiler::ScopeLabel label("ArgMinMax");
TFLITE_DCHECK_GT(input1_shape.DimensionsCount(), 0);
TFLITE_DCHECK_EQ(input1_shape.DimensionsCount() - 1,
output_shape.DimensionsCount());
int axis = input2_data[0];
if (axis < 0) {
axis += input1_shape.DimensionsCount();
}
const int axis_size = input1_shape.Dims(axis);
int outer_size = 1;
for (int i = 0; i < axis; ++i) {
TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i));
outer_size *= input1_shape.Dims(i);
}
int inner_size = 1;
const int dims_count = input1_shape.DimensionsCount();
for (int i = axis + 1; i < dims_count; ++i) {
TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i - 1));
inner_size *= input1_shape.Dims(i);
}
// Call specialized function when axis=dims-1. So far, only float32 is
// optimized so reroute to specialized function only when T1 is float32.
if (inner_size == 1 && std::is_same<T1, float>::value) {
if (is_arg_max) {
ArgMinMaxLastAxis<T1, T2, /*is_arg_max=*/true>(
{outer_size, axis_size}, input1_data, {outer_size}, output_data);
} else {
ArgMinMaxLastAxis<T1, T2, /*is_arg_max=*/false>(
{outer_size, axis_size}, input1_data, {outer_size}, output_data);
}
return;
}
reference_ops::ArgMinMax(input1_shape, input1_data, input2_data, output_shape,
output_data, is_arg_max);
}
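To make the outer/inner decomposition concrete: for an input of shape {2, 3, 4} with axis = 1, the loops above give outer_size = 2, axis_size = 3, inner_size = 4, and the fast path is skipped because inner_size != 1. A tiny standalone computation of those sizes (plain C++, independent of RuntimeShape):

#include <cstdio>
#include <vector>

int main() {
  const std::vector<int> shape{2, 3, 4};
  int axis = 1;
  if (axis < 0) axis += static_cast<int>(shape.size());
  int outer = 1, inner = 1;
  for (int i = 0; i < axis; ++i) outer *= shape[i];
  for (int i = axis + 1; i < static_cast<int>(shape.size()); ++i) {
    inner *= shape[i];
  }
  // Prints: outer=2 axis=3 inner=4
  std::printf("outer=%d axis=%d inner=%d\n", outer, shape[axis], inner);
  return 0;
}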
template <typename T1, typename T2, typename T3>
void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
const T3* input2_data, const RuntimeShape& output_shape,
T2* output_data) {
ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
/*is_arg_max=*/true);
}
// Convenience version that allows, for example, generated-code calls to be
// the same as other binary ops.
// For backward compatibility, reference_ops has ArgMax function.
template <typename T1, typename T2, typename T3>
inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
const RuntimeShape& input2_shape, const T3* input2_data,
const RuntimeShape& output_shape, T2* output_data) {
// Drop shape of second input: not needed.
ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
}
} // namespace optimized_ops
} // namespace tflite

tensorflow/lite/kernels/internal/reference/arg_min_max.h

@@ -15,23 +15,12 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_
#include <functional>
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
template <typename T>
std::function<bool(T, T)> GetComparefunction(bool is_arg_max) {
if (is_arg_max) {
return std::greater<T>();
} else {
return std::less<T>();
}
}
template <typename T1, typename T2, typename T3, typename Cmp>
void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
const T3* input2_data, const RuntimeShape& output_shape,
@@ -73,15 +62,6 @@ void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
}
}
}
template <typename T1, typename T2, typename T3>
void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
const T3* input2_data, const RuntimeShape& output_shape,
T2* output_data, const bool is_arg_max) {
ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
GetComparefunction<T1>(is_arg_max));
}
} // namespace reference_ops
} // namespace tflite

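With the bool overload and the comparator factory removed here, callers build the comparator themselves and pass it to the remaining Cmp-templated overload, which is what the arg_min_max.cc change at the top of this commit does. A hedged sketch of such a call site (RunArgMax is hypothetical; assumes the trailing-comparator signature shown above):

#include <functional>

// Hypothetical wrapper: arg-max along the axis named by `axis`, calling the
// comparator-taking reference kernel directly.
void RunArgMax(const RuntimeShape& input_shape, const float* input_data,
               const int32_t* axis, const RuntimeShape& output_shape,
               int32_t* output_data) {
  reference_ops::ArgMinMax(input_shape, input_data, axis, output_shape,
                           output_data, std::greater<float>());
}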
tensorflow/lite/testing/op_tests/arg_min_max.py

@@ -29,26 +29,13 @@ from tensorflow.lite.testing.zip_test_utils import register_make_test_function
def make_arg_min_max_tests(options):
"""Make a set of tests to do arg_max."""
test_parameters = [
{
"input_dtype": [tf.float32, tf.int32],
"input_shape": [[], [1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5],
[10]],
"output_type": [tf.int32, tf.int64],
"is_arg_max": [True],
"is_last_axis": [False],
"dynamic_range_quantize": [False, True],
},
{
"input_dtype": [tf.float32, tf.int32],
"input_shape": [[1], [2], [3], [4], [5], [6], [7], [8], [9], [10],
[2, 10], [3, 4, 50], [2, 3, 5, 100]],
"output_type": [tf.int32, tf.int64],
"is_arg_max": [False, True],
"is_last_axis": [True],
"dynamic_range_quantize": [False, True],
},
]
test_parameters = [{
"input_dtype": [tf.float32, tf.int32],
"input_shape": [[], [1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]],
"output_type": [tf.int32, tf.int64],
"is_arg_max": [True],
"dynamic_range_quantize": [False, True],
}]
def build_graph(parameters):
"""Build the topk op testing graph."""
@@ -56,10 +43,7 @@ def make_arg_min_max_tests(options):
dtype=parameters["input_dtype"],
name="input",
shape=parameters["input_shape"])
if not parameters["is_last_axis"]:
axis = random.randint(0, max(len(parameters["input_shape"]) - 1, 0))
else:
axis = -1
axis = random.randint(0, max(len(parameters["input_shape"]) - 1, 0))
if parameters["is_arg_max"]:
out = tf.math.argmax(
input_value, axis, output_type=parameters["output_type"])