[TF/XLA] Change XLA:GPU and XLA:CPU argmax/argmin implementation to return the smallest index corresponding to the largest/smallest number

The change makes it consistent with numpy and TF.

PiperOrigin-RevId: 305754325
Change-Id: I2fae3c973b53802af6184c3b465ebd592f73692d
This commit is contained in:
George Karpenkov 2020-04-09 13:57:14 -07:00 committed by TensorFlower Gardener
parent d74dcb30ce
commit 536fc4c919
5 changed files with 91 additions and 49 deletions

View File

@@ -41,6 +41,7 @@ xla_test(
deps = [
":arithmetic",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto_cc",

View File

@@ -126,7 +126,7 @@ XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder,
XlaOp rhs_index =
Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index");
auto cmp = is_min ? Lt(lhs_value, rhs_value) : Gt(lhs_value, rhs_value);
auto cmp = is_min ? Le(lhs_value, rhs_value) : Ge(lhs_value, rhs_value);
XlaOp max = Select(cmp, lhs_value, rhs_value);
XlaOp arg_max = Select(cmp, lhs_index, rhs_index);
Tuple(b, {max, arg_max});
@@ -178,36 +178,22 @@ XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
reducer = CreateScalarMaxComputation(input_shape.element_type(), builder);
}
XlaOp iota = Iota(
builder, ShapeUtil::ChangeElementType(input_shape, output_type), axis);
XlaOp input_max = Reduce(input, init_value, reducer,
/*dimensions_to_reduce=*/{axis});
std::vector<int64> broadcast_dims(input_shape.rank() - 1);
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
// Compute a mask that has 1s for elements equal to the maximum.
XlaOp partial_mask =
ConvertElementType(Eq(input, input_max, broadcast_dims), output_type);
// In order to make identity elements for a bitwise And, we:
// Left shift the 1 to the leftmost bit, yielding 0x10...0
// Arithmetic right shift the 1 back to the rightmost bit, yielding
// 0xFF...F
int32 bits_in_type =
ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1;
XlaOp shift_amount = ConstantR0WithType(builder, output_type, bits_in_type);
XlaOp full_mask = ShiftRightArithmetic(
ShiftLeft(partial_mask, shift_amount), shift_amount);
XlaOp max_idx = MaxValue(builder, output_type);
XlaOp select_mask = Select(Eq(input, input_max, broadcast_dims),
/*on_true=*/iota,
/*on_false=*/
max_idx);
// And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its
// index.
const int64 axis_size = ShapeUtil::GetDimension(input_shape, axis);
XlaOp iota = Iota(builder, output_type, axis_size);
XlaOp product = And(full_mask, iota, /*broadcast_dimensions=*/{axis});
// If there are multiple maximum elements, choose the one with the highest
// index.
return Reduce(product, MinValue(builder, output_type),
CreateScalarMaxComputation(output_type, builder),
return Reduce(select_mask, max_idx,
CreateScalarMinComputation(output_type, builder),
/*dimensions_to_reduce=*/{axis});
});
}

View File

@@ -14,8 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include <initializer_list>
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
@@ -25,42 +29,65 @@ limitations under the License.
namespace xla {
namespace {
// Test fixture for the ArgMin/ArgMax client-library helpers.
// Replaces the previous plain alias (`using ArithmeticTest =
// ClientLibraryTestBase;`) so the fixture can carry the templated test
// helpers below; the diff rendering had kept both the removed alias and the
// added class, which would redeclare `ArithmeticTest`.
class ArithmeticTest : public ClientLibraryTestBase {
 public:
  // Verifies that ArgMin over `input` along `axis` yields `expected_output`,
  // for both the single-pass and two-pass implementations.
  template <typename NativeT>
  void TestArgMin(std::initializer_list<std::initializer_list<NativeT>> input,
                  absl::Span<NativeT const> expected_output, int axis) {
    return TestArgMinMax(input, expected_output, axis, /*is_min=*/true);
  }

  // Same as TestArgMin, but for ArgMax.
  template <typename NativeT>
  void TestArgMax(std::initializer_list<std::initializer_list<NativeT>> input,
                  absl::Span<NativeT const> expected_output, int axis) {
    return TestArgMinMax(input, expected_output, axis, /*is_min=*/false);
  }

 private:
  // Test ArgMin/ArgMax implementation, both single- and two- pass.
  template <typename NativeT>
  void TestArgMinMax(
      std::initializer_list<std::initializer_list<NativeT>> input,
      absl::Span<NativeT const> expected_output, int axis, bool is_min) {
    if (is_min) {
      TestArgMinMaxImpl(input, expected_output, axis, &ArgMin);
      TestArgMinMaxImpl(input, expected_output, axis, &ArgMinTwoPass);
    } else {
      TestArgMinMaxImpl(input, expected_output, axis, &ArgMax);
      TestArgMinMaxImpl(input, expected_output, axis, &ArgMaxTwoPass);
    }
  }

  // Builds a computation that applies `MinMaxImpl` to a rank-2 constant made
  // from `input`, reduces along `axis`, and compares against
  // `expected_output`.
  template <typename NativeT>
  void TestArgMinMaxImpl(
      std::initializer_list<std::initializer_list<NativeT>> input,
      absl::Span<NativeT const> expected_output, int axis,
      std::function<void(XlaOp, PrimitiveType, int)> MinMaxImpl) {
    XlaBuilder builder(TestName());
    XlaOp x = ConstantR2<NativeT>(&builder, input);
    MinMaxImpl(x, primitive_util::NativeToPrimitiveType<NativeT>(), axis);
    ComputeAndCompareR1<NativeT>(&builder, expected_output, {});
  }
};
// ArgMin along axis 0. Column 1 ties (two 3s): the smallest row index (1)
// must win, matching numpy/TF semantics. The diff rendering had fused the
// removed hand-rolled builder code with the added TestArgMin call; only the
// post-change body is kept.
XLA_TEST_F(ArithmeticTest, ArgMinR2Axis0) {
  TestArgMin<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 2},
                    /*axis=*/0);
}
// ArgMin along axis 1. Row 2 ties (two 3s): the smallest column index (1)
// must win. The diff rendering had fused the removed old body with the added
// TestArgMin call; only the post-change body is kept.
XLA_TEST_F(ArithmeticTest, ArgMinR2Axis1) {
  TestArgMin<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 1},
                    /*axis=*/1);
}
// ArgMax along axis 0 (no ties: column maxima are 8, 7, 5). The diff
// rendering had fused the removed old body with the added TestArgMax call;
// only the post-change body is kept.
XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis0) {
  TestArgMax<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {2, 0, 1},
                    /*axis=*/0);
}
// ArgMax along axis 1 (no ties: row maxima are 7, 6, 8). The diff rendering
// had fused the removed old body with the added TestArgMax call; only the
// post-change body is kept.
XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis1) {
  TestArgMax<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {1, 0, 0},
                    /*axis=*/1);
}
} // namespace

View File

@@ -28,6 +28,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
@@ -254,6 +255,21 @@ class DefFunctionTest(test.TestCase):
z()
def testArgMinMax(self):
  """Checks XLA-compiled argmax/argmin break ties toward the smallest index.

  On an all-ones vector every position ties, so both reductions must return
  index 0 to agree with numpy and non-compiled TF. (Indentation here was
  flattened by the diff rendering and is reconstructed to TF's 2-space style.)
  """

  @def_function.function(experimental_compile=True)
  def argmax(x):
    return math_ops.argmax(x)

  @def_function.function(experimental_compile=True)
  def argmin(x):
    return math_ops.argmin(x)

  # Exercise both an explicit float32 input and the default dtype.
  self.assertAllClose(0, argmax(array_ops.ones([10], dtype=dtypes.float32)))
  self.assertAllClose(0, argmax(array_ops.ones([10])))
  self.assertAllClose(0, argmin(array_ops.ones([10], dtype=dtypes.float32)))
  self.assertAllClose(0, argmin(array_ops.ones([10])))
if __name__ == '__main__':
ops.enable_eager_execution()

View File

@@ -68,6 +68,14 @@ class ArgMaxTest(test.TestCase):
self._testBothArg(math_ops.argmax, x, 0, x.argmax())
self._testBothArg(math_ops.argmin, x, 0, x.argmin())
def _testTieBreaking(self, dtype):
  """Checks argmin/argmax tie-breaking matches numpy for `dtype`.

  All 200 elements are equal, so the smallest index (0) must be returned —
  `x.argmax()`/`x.argmin()` supply the numpy reference answers. (Indentation
  was flattened by the diff rendering and is reconstructed to TF's 2-space
  style.)
  """
  x = np.zeros(200, dtype=dtype)
  # Check that argmin and argmax match numpy along the primary axis for
  # breaking ties.
  self._testBothArg(math_ops.argmax, x, 0, x.argmax())
  self._testBothArg(math_ops.argmin, x, 0, x.argmin())
def _testDim(self, dtype):
shape = (3, 2, 4, 5, 6, 3, 7)
x = np.arange(functools.reduce(lambda x, y: x * y, shape), dtype=dtype)
@@ -81,6 +89,7 @@ class ArgMaxTest(test.TestCase):
def testFloat(self):
self._testBasic(np.float32)
self._testTieBreaking(np.float32)
self._testDim(np.float32)
def testFloatInt32Output(self):
@@ -102,14 +111,17 @@ class ArgMaxTest(test.TestCase):
def testDouble(self):
self._testBasic(np.float64)
self._testTieBreaking(np.float64)
self._testDim(np.float64)
def testInt32(self):
self._testBasic(np.int32)
self._testTieBreaking(np.int32)
self._testDim(np.int32)
def testInt64(self):
self._testBasic(np.int64)
self._testTieBreaking(np.int64)
self._testDim(np.int64)
def testEmpty(self):