[TF/XLA] Change XLA:GPU and XLA:CPU argmax/argmin implementation to return the smallest index corresponding to the largest/smallest number
The change makes it consistent with numpy and TF. PiperOrigin-RevId: 305754325 Change-Id: I2fae3c973b53802af6184c3b465ebd592f73692d
This commit is contained in:
parent
d74dcb30ce
commit
536fc4c919
@ -41,6 +41,7 @@ xla_test(
|
||||
deps = [
|
||||
":arithmetic",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
|
||||
@ -126,7 +126,7 @@ XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder,
|
||||
XlaOp rhs_index =
|
||||
Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index");
|
||||
|
||||
auto cmp = is_min ? Lt(lhs_value, rhs_value) : Gt(lhs_value, rhs_value);
|
||||
auto cmp = is_min ? Le(lhs_value, rhs_value) : Ge(lhs_value, rhs_value);
|
||||
XlaOp max = Select(cmp, lhs_value, rhs_value);
|
||||
XlaOp arg_max = Select(cmp, lhs_index, rhs_index);
|
||||
Tuple(b, {max, arg_max});
|
||||
@ -178,36 +178,22 @@ XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
|
||||
reducer = CreateScalarMaxComputation(input_shape.element_type(), builder);
|
||||
}
|
||||
|
||||
XlaOp iota = Iota(
|
||||
builder, ShapeUtil::ChangeElementType(input_shape, output_type), axis);
|
||||
XlaOp input_max = Reduce(input, init_value, reducer,
|
||||
/*dimensions_to_reduce=*/{axis});
|
||||
std::vector<int64> broadcast_dims(input_shape.rank() - 1);
|
||||
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
|
||||
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
|
||||
// Compute a mask that has 1s for elements equal to the maximum.
|
||||
XlaOp partial_mask =
|
||||
ConvertElementType(Eq(input, input_max, broadcast_dims), output_type);
|
||||
|
||||
// In order to make identity elements for a bitwise And, we:
|
||||
// Left shift the 1 to the leftmost bit, yielding 0x10...0
|
||||
// Arithmetic right shift the 1 back to the rightmost bit, yielding
|
||||
// 0xFF...F
|
||||
int32 bits_in_type =
|
||||
ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1;
|
||||
XlaOp shift_amount = ConstantR0WithType(builder, output_type, bits_in_type);
|
||||
XlaOp full_mask = ShiftRightArithmetic(
|
||||
ShiftLeft(partial_mask, shift_amount), shift_amount);
|
||||
XlaOp max_idx = MaxValue(builder, output_type);
|
||||
XlaOp select_mask = Select(Eq(input, input_max, broadcast_dims),
|
||||
/*on_true=*/iota,
|
||||
/*on_false=*/
|
||||
max_idx);
|
||||
|
||||
// And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its
|
||||
// index.
|
||||
|
||||
const int64 axis_size = ShapeUtil::GetDimension(input_shape, axis);
|
||||
XlaOp iota = Iota(builder, output_type, axis_size);
|
||||
XlaOp product = And(full_mask, iota, /*broadcast_dimensions=*/{axis});
|
||||
|
||||
// If there are multiple maximum elements, choose the one with the highest
|
||||
// index.
|
||||
return Reduce(product, MinValue(builder, output_type),
|
||||
CreateScalarMaxComputation(output_type, builder),
|
||||
return Reduce(select_mask, max_idx,
|
||||
CreateScalarMinComputation(output_type, builder),
|
||||
/*dimensions_to_reduce=*/{axis});
|
||||
});
|
||||
}
|
||||
|
||||
@ -14,8 +14,12 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
|
||||
#include <initializer_list>
|
||||
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
@ -25,42 +29,65 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using ArithmeticTest = ClientLibraryTestBase;
|
||||
class ArithmeticTest : public ClientLibraryTestBase {
|
||||
public:
|
||||
template <typename NativeT>
|
||||
void TestArgMin(std::initializer_list<std::initializer_list<NativeT>> input,
|
||||
absl::Span<NativeT const> expected_output, int axis) {
|
||||
return TestArgMinMax(input, expected_output, axis, /*is_min=*/true);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void TestArgMax(std::initializer_list<std::initializer_list<NativeT>> input,
|
||||
absl::Span<NativeT const> expected_output, int axis) {
|
||||
return TestArgMinMax(input, expected_output, axis, /*is_min=*/false);
|
||||
}
|
||||
|
||||
private:
|
||||
// Test ArgMin/ArgMax implementation, both single- and two- pass.
|
||||
template <typename NativeT>
|
||||
void TestArgMinMax(
|
||||
std::initializer_list<std::initializer_list<NativeT>> input,
|
||||
absl::Span<NativeT const> expected_output, int axis, bool is_min) {
|
||||
if (is_min) {
|
||||
TestArgMinMaxImpl(input, expected_output, axis, &ArgMin);
|
||||
TestArgMinMaxImpl(input, expected_output, axis, &ArgMinTwoPass);
|
||||
} else {
|
||||
TestArgMinMaxImpl(input, expected_output, axis, &ArgMax);
|
||||
TestArgMinMaxImpl(input, expected_output, axis, &ArgMaxTwoPass);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void TestArgMinMaxImpl(
|
||||
std::initializer_list<std::initializer_list<NativeT>> input,
|
||||
absl::Span<NativeT const> expected_output, int axis,
|
||||
std::function<void(XlaOp, PrimitiveType, int)> MinMaxImpl) {
|
||||
XlaBuilder builder(TestName());
|
||||
XlaOp x = ConstantR2<NativeT>(&builder, input);
|
||||
MinMaxImpl(x, primitive_util::NativeToPrimitiveType<NativeT>(), axis);
|
||||
ComputeAndCompareR1<NativeT>(&builder, expected_output, {});
|
||||
}
|
||||
};
|
||||
|
||||
XLA_TEST_F(ArithmeticTest, ArgMinR2Axis0) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x = ConstantR2<int32>(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}});
|
||||
ArgMin(x, S32, /*axis=*/0);
|
||||
|
||||
std::vector<int32> expected = {0, 2, 2};
|
||||
ComputeAndCompareR1<int32>(&builder, expected, {});
|
||||
TestArgMin<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 2},
|
||||
/*axis=*/0);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArithmeticTest, ArgMinR2Axis1) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x = ConstantR2<int32>(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}});
|
||||
ArgMin(x, S32, /*axis=*/1);
|
||||
|
||||
std::vector<int32> expected = {0, 1, 2};
|
||||
ComputeAndCompareR1<int32>(&builder, expected, {});
|
||||
TestArgMin<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 1},
|
||||
/*axis=*/1);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis0) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x = ConstantR2<int32>(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}});
|
||||
ArgMax(x, S32, /*axis=*/0);
|
||||
|
||||
std::vector<int32> expected = {2, 0, 1};
|
||||
ComputeAndCompareR1<int32>(&builder, expected, {});
|
||||
TestArgMax<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {2, 0, 1},
|
||||
/*axis=*/0);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis1) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto x = ConstantR2<int32>(&builder, {{1, 7, 4}, {6, 3, 5}, {8, 3, 3}});
|
||||
ArgMax(x, S32, /*axis=*/1);
|
||||
|
||||
std::vector<int32> expected = {1, 0, 0};
|
||||
ComputeAndCompareR1<int32>(&builder, expected, {});
|
||||
TestArgMax<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {1, 0, 0},
|
||||
/*axis=*/1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -28,6 +28,7 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -254,6 +255,21 @@ class DefFunctionTest(test.TestCase):
|
||||
|
||||
z()
|
||||
|
||||
def testArgMinMax(self):
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def argmax(x):
|
||||
return math_ops.argmax(x)
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def argmin(x):
|
||||
return math_ops.argmin(x)
|
||||
|
||||
self.assertAllClose(0, argmax(array_ops.ones([10], dtype=dtypes.float32)))
|
||||
self.assertAllClose(0, argmax(array_ops.ones([10])))
|
||||
self.assertAllClose(0, argmin(array_ops.ones([10], dtype=dtypes.float32)))
|
||||
self.assertAllClose(0, argmin(array_ops.ones([10])))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
||||
@ -68,6 +68,14 @@ class ArgMaxTest(test.TestCase):
|
||||
self._testBothArg(math_ops.argmax, x, 0, x.argmax())
|
||||
self._testBothArg(math_ops.argmin, x, 0, x.argmin())
|
||||
|
||||
def _testTieBreaking(self, dtype):
|
||||
x = np.zeros(200, dtype=dtype)
|
||||
|
||||
# Check that argmin and argmax match numpy along the primary axis for
|
||||
# breaking ties.
|
||||
self._testBothArg(math_ops.argmax, x, 0, x.argmax())
|
||||
self._testBothArg(math_ops.argmin, x, 0, x.argmin())
|
||||
|
||||
def _testDim(self, dtype):
|
||||
shape = (3, 2, 4, 5, 6, 3, 7)
|
||||
x = np.arange(functools.reduce(lambda x, y: x * y, shape), dtype=dtype)
|
||||
@ -81,6 +89,7 @@ class ArgMaxTest(test.TestCase):
|
||||
|
||||
def testFloat(self):
|
||||
self._testBasic(np.float32)
|
||||
self._testTieBreaking(np.float32)
|
||||
self._testDim(np.float32)
|
||||
|
||||
def testFloatInt32Output(self):
|
||||
@ -102,14 +111,17 @@ class ArgMaxTest(test.TestCase):
|
||||
|
||||
def testDouble(self):
|
||||
self._testBasic(np.float64)
|
||||
self._testTieBreaking(np.float64)
|
||||
self._testDim(np.float64)
|
||||
|
||||
def testInt32(self):
|
||||
self._testBasic(np.int32)
|
||||
self._testTieBreaking(np.int32)
|
||||
self._testDim(np.int32)
|
||||
|
||||
def testInt64(self):
|
||||
self._testBasic(np.int64)
|
||||
self._testTieBreaking(np.int64)
|
||||
self._testDim(np.int64)
|
||||
|
||||
def testEmpty(self):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user