From 1edc795defb760129ebb5042ca36b92e670cbaa2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 5 Mar 2021 22:15:29 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 361284426 Change-Id: I105de3e1953cf94e715130f0e136a2e2dc456b97 --- tensorflow/lite/kernels/arg_min_max.cc | 12 +- tensorflow/lite/kernels/arg_min_max_test.cc | 26 -- .../internal/optimized/legacy_optimized_ops.h | 27 +- .../internal/optimized/optimized_ops.h | 234 +----------------- .../kernels/internal/reference/arg_min_max.h | 20 -- .../lite/testing/op_tests/arg_min_max.py | 32 +-- 6 files changed, 23 insertions(+), 328 deletions(-) diff --git a/tensorflow/lite/kernels/arg_min_max.cc b/tensorflow/lite/kernels/arg_min_max.cc index a0ba8cb9f8b..03a6961e609 100644 --- a/tensorflow/lite/kernels/arg_min_max.cc +++ b/tensorflow/lite/kernels/arg_min_max.cc @@ -119,6 +119,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +template +std::function GetComparefunction(bool is_arg_max) { + if (is_arg_max) { + return std::greater(); + } else { + return std::less(); + } +} + TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) { const TfLiteTensor* input; TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); @@ -135,7 +144,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) { optimized_ops::ArgMinMax( \ GetTensorShape(input), GetTensorData(input), \ GetTensorData(axis), GetTensorShape(output), \ - GetTensorData(output), is_arg_max) + GetTensorData(output), \ + GetComparefunction(is_arg_max)) if (axis->type == kTfLiteInt32) { switch (output->type) { case kTfLiteInt32: { diff --git a/tensorflow/lite/kernels/arg_min_max_test.cc b/tensorflow/lite/kernels/arg_min_max_test.cc index b3159ed7981..957d3473b8d 100644 --- a/tensorflow/lite/kernels/arg_min_max_test.cc +++ b/tensorflow/lite/kernels/arg_min_max_test.cc @@ -209,19 +209,6 @@ TEST_P(ArgMinMaxOpTest, GetMaxArgOutput64) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2})); } -TEST_P(ArgMinMaxOpTest, GetMaxArgFloatLastAxis) { - std::vector input{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}; - for (int i = 1; i < 10; ++i) { - ArgMaxOpModel model({i}, TensorType_FLOAT32, 0, AxisType(), ConstantAxis(), - OutputType()); - model.PopulateTensor( - model.input(), std::vector(input.begin(), input.begin() + i)); - model.Invoke(); - - ValidateOutput(model, {i - 1}); - } -} - TEST_P(ArgMinMaxOpTest, GetMinArgFloat) { ArgMinOpModel model({1, 1, 1, 4}, TensorType_FLOAT32, 3, AxisType(), ConstantAxis(), OutputType()); @@ -272,18 +259,5 @@ TEST_P(ArgMinMaxOpTest, GetMinArgOutput64) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2})); } -TEST_P(ArgMinMaxOpTest, GetMinArgFloatLastAxis) { - std::vector input{1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1}; - for (int i = 1; i < 10; ++i) { - ArgMinOpModel model({i}, TensorType_FLOAT32, 0, AxisType(), ConstantAxis(), - OutputType()); - model.PopulateTensor( - model.input(), std::vector(input.begin(), input.begin() + i)); - model.Invoke(); - - ValidateOutput(model, {i - 1}); - } -} - } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h index 44a553c527a..60a5de4491d 100644 --- a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -33,6 +33,8 @@ namespace tflite { namespace optimized_ops { // Unoptimized reference ops: +using reference_ops::ArgMax; +using reference_ops::ArgMinMax; using reference_ops::Broadcast4DSlowGreater; using reference_ops::Broadcast4DSlowGreaterEqual; using reference_ops::Broadcast4DSlowGreaterEqualWithScaling; @@ -4964,31 +4966,6 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, DimsToShape(output_dims), output_data); } -template -void ArgMax(const T3* axis, const T1* input_data, - const tflite::Dims<4>& input_dims, T2* output_data, - const tflite::Dims<4>& output_dims) { - // Assumes the input always has 4 dimensions, and therefore, - // output always has three dimensions. - auto output_shape = RuntimeShape( - {output_dims.sizes[2], output_dims.sizes[1], output_dims.sizes[0]}); - // Another way to interpret this is that output_dims.sizes[4] is always 1. - TFLITE_DCHECK_EQ(output_shape.FlatSize(), - DimsToShape(output_dims).FlatSize()); - // Legacy path only supported this. - TFLITE_DCHECK_EQ(axis[0], 3); - ArgMinMax(DimsToShape(input_dims), input_data, axis, output_shape, - output_data, /*is_arg_max=*/true); -} - -template -void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, - T2* output_data, const Dims<4>& output_dims, - const bool is_arg_max) { - ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims), - output_data, is_arg_max); -} - } // namespace optimized_ops } // namespace tflite #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index f3b6e15c8e7..aca89ae657b 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -64,6 +64,8 @@ namespace tflite { namespace optimized_ops { // Unoptimized reference ops: +using reference_ops::ArgMax; +using reference_ops::ArgMinMax; using reference_ops::Broadcast4DSlowGreater; using reference_ops::Broadcast4DSlowGreaterEqual; using reference_ops::Broadcast4DSlowGreaterEqualWithScaling; @@ -7728,238 +7730,6 @@ inline void BroadcastPReluDispatch( PReluElementWise, PReluScalarBroadcast); } -// Returns the index with minimum value within `input_data`. -// If there is a tie, returns the smaller index. -template -inline int ArgMinVector(const T* input_data, int size) { - T min_value = input_data[0]; - int min_index = 0; - for (int i = 1; i < size; ++i) { - const T curr_value = input_data[i]; - if (curr_value < min_value) { - min_value = curr_value; - min_index = i; - } - } - return min_index; -} - -// Returns the index with maximum value within `input_data`. -// If there is a tie, returns the smaller index. -template -inline int ArgMaxVector(const T* input_data, int size) { - T max_value = input_data[0]; - int max_index = 0; - for (int i = 1; i < size; ++i) { - const T curr_value = input_data[i]; - if (curr_value > max_value) { - max_value = curr_value; - max_index = i; - } - } - return max_index; -} - -template <> -inline int ArgMinVector(const float* input_data, int size) { - int32_t min_index = 0; - float min_value = input_data[0]; - int32_t i = 1; -#ifdef USE_NEON - if (size >= 4) { - float32x4_t min_value_f32x4 = vld1q_f32(input_data); - const int32_t index_init[4] = {0, 1, 2, 3}; - int32x4_t min_index_s32x4 = vld1q_s32(index_init); - int32x4_t index_s32x4 = min_index_s32x4; - int32x4_t inc = vdupq_n_s32(4); - for (i = 4; i <= size - 4; i += 4) { - // Increase indices by 4. - index_s32x4 = vaddq_s32(index_s32x4, inc); - float32x4_t v = vld1q_f32(&input_data[i]); - uint32x4_t mask = vcltq_f32(v, min_value_f32x4); - min_value_f32x4 = vminq_f32(min_value_f32x4, v); - min_index_s32x4 = vbslq_s32(mask, index_s32x4, min_index_s32x4); - } - // Find min element within float32x4_t. -#ifdef __aarch64__ - min_value = vminvq_f32(min_value_f32x4); -#else - float32x2_t min_value_f32x2 = vpmin_f32(vget_low_f32(min_value_f32x4), - vget_high_f32(min_value_f32x4)); - min_value_f32x2 = vpmin_f32(min_value_f32x2, min_value_f32x2); - min_value = vget_lane_f32(min_value_f32x2, 0); -#endif // __aarch64__ - // Mask indices of non-min values with max int32_t. - float32x4_t fill_min_value_f32x4 = vdupq_n_f32(min_value); - uint32x4_t mask = vceqq_f32(min_value_f32x4, fill_min_value_f32x4); - int32x4_t all_set = vdupq_n_s32(std::numeric_limits::max()); - min_index_s32x4 = vbslq_s32(mask, min_index_s32x4, all_set); - // Find min index of min values. -#ifdef __aarch64__ - min_index = vminvq_s32(min_index_s32x4); -#else - uint32x2_t min_index_s32x2 = vpmin_s32(vget_low_s32(min_index_s32x4), - vget_high_s32(min_index_s32x4)); - min_index_s32x2 = vpmin_s32(min_index_s32x2, min_index_s32x2); - min_index = vget_lane_s32(min_index_s32x2, 0); -#endif // __aarch64__ - } -#endif // USE_NEON - // Leftover loop. - for (; i < size; ++i) { - const float curr_value = input_data[i]; - if (curr_value < min_value) { - min_value = curr_value; - min_index = i; - } - } - return min_index; -} - -template <> -inline int ArgMaxVector(const float* input_data, int size) { - int32_t max_index = 0; - float max_value = input_data[0]; - int32_t i = 1; -#ifdef USE_NEON - if (size >= 4) { - float32x4_t max_value_f32x4 = vld1q_f32(input_data); - const int32_t index_init[4] = {0, 1, 2, 3}; - int32x4_t max_index_s32x4 = vld1q_s32(index_init); - int32x4_t index_s32x4 = max_index_s32x4; - int32x4_t inc = vdupq_n_s32(4); - for (i = 4; i <= size - 4; i += 4) { - // Increase indices by 4. - index_s32x4 = vaddq_s32(index_s32x4, inc); - float32x4_t v = vld1q_f32(&input_data[i]); - uint32x4_t mask = vcgtq_f32(v, max_value_f32x4); - max_value_f32x4 = vmaxq_f32(max_value_f32x4, v); - max_index_s32x4 = vbslq_s32(mask, index_s32x4, max_index_s32x4); - } - // Find max element within float32x4_t. -#ifdef __aarch64__ - max_value = vmaxvq_f32(max_value_f32x4); -#else - float32x2_t max_value_f32x2 = vpmax_f32(vget_low_f32(max_value_f32x4), - vget_high_f32(max_value_f32x4)); - max_value_f32x2 = vpmax_f32(max_value_f32x2, max_value_f32x2); - max_value = vget_lane_f32(max_value_f32x2, 0); -#endif // __aarch64__ - // Mask indices of non-max values with max int32_t. - float32x4_t fill_max_value_f32x4 = vdupq_n_f32(max_value); - uint32x4_t mask = vceqq_f32(max_value_f32x4, fill_max_value_f32x4); - int32x4_t all_set = vdupq_n_s32(std::numeric_limits::max()); - max_index_s32x4 = vbslq_s32(mask, max_index_s32x4, all_set); - // Find min index of max values. -#ifdef __aarch64__ - max_index = vminvq_s32(max_index_s32x4); -#else - uint32x2_t max_index_s32x2 = vpmin_s32(vget_low_s32(max_index_s32x4), - vget_high_s32(max_index_s32x4)); - max_index_s32x2 = vpmin_s32(max_index_s32x2, max_index_s32x2); - max_index = vget_lane_s32(max_index_s32x2, 0); -#endif // __aarch64__ - } -#endif // USE_NEON - // Leftover loop. - for (; i < size; ++i) { - const float curr_value = input_data[i]; - if (curr_value > max_value) { - max_value = curr_value; - max_index = i; - } - } - return max_index; -} - -// Specializes ArgMinMax function with axis=dims-1. -// In this case, ArgMinMax reduction is applied on contiguous memory. -template -inline void ArgMinMaxLastAxis(const RuntimeShape& input_shape, - const T1* input_data, - const RuntimeShape& output_shape, - T2* output_data) { - TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2); - TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 1); - TFLITE_DCHECK_EQ(input_shape.Dims(0), output_shape.Dims(0)); - - int outer_size = input_shape.Dims(0); - int axis_size = input_shape.Dims(1); - for (int outer = 0; outer < outer_size; ++outer) { - if (is_arg_max) { - output_data[outer] = static_cast( - ArgMaxVector(input_data + outer * axis_size, axis_size)); - } else { - output_data[outer] = static_cast( - ArgMinVector(input_data + outer * axis_size, axis_size)); - } - } -} - -template -inline void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data, - const T3* input2_data, const RuntimeShape& output_shape, - T2* output_data, const bool is_arg_max) { - ruy::profiler::ScopeLabel label("ArgMinMax"); - - TFLITE_DCHECK_GT(input1_shape.DimensionsCount(), 0); - TFLITE_DCHECK_EQ(input1_shape.DimensionsCount() - 1, - output_shape.DimensionsCount()); - int axis = input2_data[0]; - if (axis < 0) { - axis += input1_shape.DimensionsCount(); - } - const int axis_size = input1_shape.Dims(axis); - - int outer_size = 1; - for (int i = 0; i < axis; ++i) { - TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i)); - outer_size *= input1_shape.Dims(i); - } - - int inner_size = 1; - const int dims_count = input1_shape.DimensionsCount(); - for (int i = axis + 1; i < dims_count; ++i) { - TFLITE_DCHECK_EQ(input1_shape.Dims(i), output_shape.Dims(i - 1)); - inner_size *= input1_shape.Dims(i); - } - - // Call specialized function when axis=dims-1. So far, only float32 is - // optimized so reroute to specialized function only when T1 is float32. - if (inner_size == 1 && std::is_same::value) { - if (is_arg_max) { - ArgMinMaxLastAxis( - {outer_size, axis_size}, input1_data, {outer_size}, output_data); - } else { - ArgMinMaxLastAxis( - {outer_size, axis_size}, input1_data, {outer_size}, output_data); - } - return; - } - - reference_ops::ArgMinMax(input1_shape, input1_data, input2_data, output_shape, - output_data, is_arg_max); -} - -template -void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data, - const T3* input2_data, const RuntimeShape& output_shape, - T2* output_data) { - ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data, - /*is_arg_max=*/true); -} - -// Convenience version that allows, for example, generated-code calls to be -// the same as other binary ops. -// For backward compatibility, reference_ops has ArgMax function. -template -inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data, - const RuntimeShape& input2_shape, const T3* input2_data, - const RuntimeShape& output_shape, T2* output_data) { - // Drop shape of second input: not needed. - ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data); -} - } // namespace optimized_ops } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/reference/arg_min_max.h b/tensorflow/lite/kernels/internal/reference/arg_min_max.h index 8154fbf71e3..e6f34fd73f4 100644 --- a/tensorflow/lite/kernels/internal/reference/arg_min_max.h +++ b/tensorflow/lite/kernels/internal/reference/arg_min_max.h @@ -15,23 +15,12 @@ limitations under the License. #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_ #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ARG_MIN_MAX_H_ -#include - #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { namespace reference_ops { -template -std::function GetComparefunction(bool is_arg_max) { - if (is_arg_max) { - return std::greater(); - } else { - return std::less(); - } -} - template void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data, const T3* input2_data, const RuntimeShape& output_shape, @@ -73,15 +62,6 @@ void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data, } } } - -template -void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data, - const T3* input2_data, const RuntimeShape& output_shape, - T2* output_data, const bool is_arg_max) { - ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data, - GetComparefunction(is_arg_max)); -} - } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/lite/testing/op_tests/arg_min_max.py b/tensorflow/lite/testing/op_tests/arg_min_max.py index d9645ffa213..ec0013225e0 100644 --- a/tensorflow/lite/testing/op_tests/arg_min_max.py +++ b/tensorflow/lite/testing/op_tests/arg_min_max.py @@ -29,26 +29,13 @@ from tensorflow.lite.testing.zip_test_utils import register_make_test_function def make_arg_min_max_tests(options): """Make a set of tests to do arg_max.""" - test_parameters = [ - { - "input_dtype": [tf.float32, tf.int32], - "input_shape": [[], [1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], - [10]], - "output_type": [tf.int32, tf.int64], - "is_arg_max": [True], - "is_last_axis": [False], - "dynamic_range_quantize": [False, True], - }, - { - "input_dtype": [tf.float32, tf.int32], - "input_shape": [[1], [2], [3], [4], [5], [6], [7], [8], [9], [10], - [2, 10], [3, 4, 50], [2, 3, 5, 100]], - "output_type": [tf.int32, tf.int64], - "is_arg_max": [False, True], - "is_last_axis": [True], - "dynamic_range_quantize": [False, True], - }, - ] + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32], + "input_shape": [[], [1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]], + "output_type": [tf.int32, tf.int64], + "is_arg_max": [True], + "dynamic_range_quantize": [False, True], + }] def build_graph(parameters): """Build the topk op testing graph.""" @@ -56,10 +43,7 @@ def make_arg_min_max_tests(options): dtype=parameters["input_dtype"], name="input", shape=parameters["input_shape"]) - if not parameters["is_last_axis"]: - axis = random.randint(0, max(len(parameters["input_shape"]) - 1, 0)) - else: - axis = -1 + axis = random.randint(0, max(len(parameters["input_shape"]) - 1, 0)) if parameters["is_arg_max"]: out = tf.math.argmax( input_value, axis, output_type=parameters["output_type"])