diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 6fcdef46f29..38ddbd5abf7 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -41,6 +41,7 @@ xla_test( deps = [ ":arithmetic", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto_cc", diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index de573429fdc..a24f110fd7a 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -126,7 +126,7 @@ XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder, XlaOp rhs_index = Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index"); - auto cmp = is_min ? Lt(lhs_value, rhs_value) : Gt(lhs_value, rhs_value); + auto cmp = is_min ? Le(lhs_value, rhs_value) : Ge(lhs_value, rhs_value); XlaOp max = Select(cmp, lhs_value, rhs_value); XlaOp arg_max = Select(cmp, lhs_index, rhs_index); Tuple(b, {max, arg_max}); @@ -178,36 +178,22 @@ XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis, reducer = CreateScalarMaxComputation(input_shape.element_type(), builder); } + XlaOp iota = Iota( + builder, ShapeUtil::ChangeElementType(input_shape, output_type), axis); XlaOp input_max = Reduce(input, init_value, reducer, /*dimensions_to_reduce=*/{axis}); std::vector<int64> broadcast_dims(input_shape.rank() - 1); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - // Compute a mask that has 1s for elements equal to the maximum. 
- XlaOp partial_mask = - ConvertElementType(Eq(input, input_max, broadcast_dims), output_type); - // In order to make identity elements for a bitwise And, we: - // Left shift the 1 to the leftmost bit, yielding 0x10...0 - // Arithmetic right shift the 1 back to the rightmost bit, yielding - // 0xFF...F - int32 bits_in_type = - ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1; - XlaOp shift_amount = ConstantR0WithType(builder, output_type, bits_in_type); - XlaOp full_mask = ShiftRightArithmetic( - ShiftLeft(partial_mask, shift_amount), shift_amount); + XlaOp max_idx = MaxValue(builder, output_type); + XlaOp select_mask = Select(Eq(input, input_max, broadcast_dims), + /*on_true=*/iota, + /*on_false=*/ + max_idx); - // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its - // index. - - const int64 axis_size = ShapeUtil::GetDimension(input_shape, axis); - XlaOp iota = Iota(builder, output_type, axis_size); - XlaOp product = And(full_mask, iota, /*broadcast_dimensions=*/{axis}); - - // If there are multiple maximum elements, choose the one with the highest - // index. - return Reduce(product, MinValue(builder, output_type), - CreateScalarMaxComputation(output_type, builder), + return Reduce(select_mask, max_idx, + CreateScalarMinComputation(output_type, builder), /*dimensions_to_reduce=*/{axis}); }); } diff --git a/tensorflow/compiler/xla/client/lib/arithmetic_test.cc b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc index a13839f9db8..d3ff14d8a9b 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic_test.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic_test.cc @@ -14,8 +14,12 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/arithmetic.h" + +#include <functional> + + #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -25,42 +29,65 @@ limitations under the License. namespace xla { namespace { -using ArithmeticTest = ClientLibraryTestBase; +class ArithmeticTest : public ClientLibraryTestBase { + public: + template <typename T> + void TestArgMin(std::initializer_list<std::initializer_list<T>> input, + absl::Span<const T> expected_output, int axis) { + return TestArgMinMax(input, expected_output, axis, /*is_min=*/true); + } + + template <typename T> + void TestArgMax(std::initializer_list<std::initializer_list<T>> input, + absl::Span<const T> expected_output, int axis) { + return TestArgMinMax(input, expected_output, axis, /*is_min=*/false); + } + + private: + // Test ArgMin/ArgMax implementation, both single- and two- pass. 
+ template <typename T> + void TestArgMinMax( + std::initializer_list<std::initializer_list<T>> input, + absl::Span<const T> expected_output, int axis, bool is_min) { + if (is_min) { + TestArgMinMaxImpl(input, expected_output, axis, &ArgMin); + TestArgMinMaxImpl(input, expected_output, axis, &ArgMinTwoPass); + } else { + TestArgMinMaxImpl(input, expected_output, axis, &ArgMax); + TestArgMinMaxImpl(input, expected_output, axis, &ArgMaxTwoPass); + } + } + + template <typename T> + void TestArgMinMaxImpl( + std::initializer_list<std::initializer_list<T>> input, + absl::Span<const T> expected_output, int axis, + std::function<XlaOp(XlaOp, PrimitiveType, int)> MinMaxImpl) { + XlaBuilder builder(TestName()); + XlaOp x = ConstantR2<T>(&builder, input); + MinMaxImpl(x, primitive_util::NativeToPrimitiveType<T>(), axis); + ComputeAndCompareR1<T>(&builder, expected_output, {}); + } +}; XLA_TEST_F(ArithmeticTest, ArgMinR2Axis0) { - XlaBuilder builder(TestName()); - auto x = ConstantR2<int32>(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); - ArgMin(x, S32, /*axis=*/0); - - std::vector<int32> expected = {0, 2, 2}; - ComputeAndCompareR1<int32>(&builder, expected, {}); + TestArgMin<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 2}, + /*axis=*/0); } XLA_TEST_F(ArithmeticTest, ArgMinR2Axis1) { - XlaBuilder builder(TestName()); - auto x = ConstantR2<int32>(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); - ArgMin(x, S32, /*axis=*/1); - - std::vector<int32> expected = {0, 1, 2}; - ComputeAndCompareR1<int32>(&builder, expected, {}); + TestArgMin<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 1}, + /*axis=*/1); } XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis0) { - XlaBuilder builder(TestName()); - auto x = ConstantR2<int32>(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); - ArgMax(x, S32, /*axis=*/0); - - std::vector<int32> expected = {2, 0, 1}; - ComputeAndCompareR1<int32>(&builder, expected, {}); + TestArgMax<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {2, 0, 1}, + /*axis=*/0); } XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis1) { - XlaBuilder builder(TestName()); - auto x = ConstantR2<int32>(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}); - ArgMax(x, S32, /*axis=*/1); - - std::vector<int32> expected = {1, 0, 0}; - 
ComputeAndCompareR1<int32>(&builder, expected, {}); + TestArgMax<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {1, 0, 0}, + /*axis=*/1); } } // namespace diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index 3a311f2b2c5..f17033d126c 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test @@ -254,6 +255,21 @@ class DefFunctionTest(test.TestCase): z() + def testArgMinMax(self): + + @def_function.function(experimental_compile=True) + def argmax(x): + return math_ops.argmax(x) + + @def_function.function(experimental_compile=True) + def argmin(x): + return math_ops.argmin(x) + + self.assertAllClose(0, argmax(array_ops.ones([10], dtype=dtypes.float32))) + self.assertAllClose(0, argmax(array_ops.ones([10]))) + self.assertAllClose(0, argmin(array_ops.ones([10], dtype=dtypes.float32))) + self.assertAllClose(0, argmin(array_ops.ones([10]))) + if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/kernel_tests/argmax_op_test.py b/tensorflow/python/kernel_tests/argmax_op_test.py index 86d2941b8d3..023766c899d 100644 --- a/tensorflow/python/kernel_tests/argmax_op_test.py +++ b/tensorflow/python/kernel_tests/argmax_op_test.py @@ -68,6 +68,14 @@ class ArgMaxTest(test.TestCase): self._testBothArg(math_ops.argmax, x, 0, x.argmax()) self._testBothArg(math_ops.argmin, x, 0, x.argmin()) + def _testTieBreaking(self, dtype): + x = np.zeros(200, dtype=dtype) + + # Check that argmin and argmax match numpy along the primary axis for + # breaking ties. 
+ self._testBothArg(math_ops.argmax, x, 0, x.argmax()) + self._testBothArg(math_ops.argmin, x, 0, x.argmin()) + def _testDim(self, dtype): shape = (3, 2, 4, 5, 6, 3, 7) x = np.arange(functools.reduce(lambda x, y: x * y, shape), dtype=dtype) @@ -81,6 +89,7 @@ class ArgMaxTest(test.TestCase): def testFloat(self): self._testBasic(np.float32) + self._testTieBreaking(np.float32) self._testDim(np.float32) def testFloatInt32Output(self): @@ -102,14 +111,17 @@ class ArgMaxTest(test.TestCase): def testDouble(self): self._testBasic(np.float64) + self._testTieBreaking(np.float64) self._testDim(np.float64) def testInt32(self): self._testBasic(np.int32) + self._testTieBreaking(np.int32) self._testDim(np.int32) def testInt64(self): self._testBasic(np.int64) + self._testTieBreaking(np.int64) self._testDim(np.int64) def testEmpty(self):