[XLA] Implement stable argmin and argmax

PiperOrigin-RevId: 307172027
Change-Id: I4c24631968bd7d22d2147f984888ec489c347bdf
This commit is contained in:
Blake Hechtman 2020-04-18 00:05:07 -07:00 committed by TensorFlower Gardener
parent 7ae5b12f47
commit f73e9d61a7
5 changed files with 64 additions and 37 deletions

View File

@ -109,7 +109,7 @@ class CategoricalOp : public XlaOpKernel {
/*axis=*/class_dimension);
} else {
argmax = xla::ArgMax(softmax_entries, xla_output_type,
/*axis=*/class_dimension);
/*axis=*/class_dimension, /*stable=*/true);
}
if (num_samples == 1) {

View File

@ -77,7 +77,7 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
if (is_gpu_) {
output = xla::ArgMaxTwoPass(input, index_xla_type, axis);
} else {
output = xla::ArgMax(input, index_xla_type, axis);
output = xla::ArgMax(input, index_xla_type, axis, /*stable=*/true);
}
}
@ -86,8 +86,7 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
XlaArgMaxOp::XlaArgMaxOp(OpKernelConstruction* ctx)
: XlaArgMinMaxOp(ctx, /*is_min=*/false) {}
REGISTER_XLA_OP(Name("ArgMax")
.CompileTimeConstantInput("dimension"),
REGISTER_XLA_OP(Name("ArgMax").CompileTimeConstantInput("dimension"),
XlaArgMaxOp);
namespace {

View File

@ -114,7 +114,8 @@ namespace {
XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder,
PrimitiveType value_type,
PrimitiveType index_type, bool is_min) {
PrimitiveType index_type, bool is_min,
bool stable, bool tie_low) {
auto sub_builder = outer_builder->CreateSubBuilder("minmax_func");
XlaBuilder* b = sub_builder.get();
XlaOp lhs_value =
@ -126,14 +127,21 @@ XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder,
XlaOp rhs_index =
Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index");
auto cmp = is_min ? Le(lhs_value, rhs_value) : Ge(lhs_value, rhs_value);
XlaOp cmp = is_min ? Le(lhs_value, rhs_value) : Ge(lhs_value, rhs_value);
XlaOp max = Select(cmp, lhs_value, rhs_value);
XlaOp arg_max = Select(cmp, lhs_index, rhs_index);
if (stable) {
XlaOp eq = Eq(lhs_value, rhs_value);
XlaOp tie_id =
tie_low ? Min(lhs_index, rhs_index) : Max(lhs_index, rhs_index);
arg_max = Select(eq, tie_id, arg_max);
}
Tuple(b, {max, arg_max});
return b->Build().ConsumeValueOrDie();
}
XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min,
bool stable, bool tie_low) {
XlaBuilder* builder = input.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
@ -150,8 +158,9 @@ XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
iota_shape.set_element_type(index_type);
XlaOp iota = Iota(builder, iota_shape, axis);
XlaComputation reducer = CreateMinMaxComputation(
builder, input_shape.element_type(), index_type, is_min);
XlaComputation reducer =
CreateMinMaxComputation(builder, input_shape.element_type(), index_type,
is_min, stable, tie_low);
XlaOp max_argmax = Reduce(builder, {input, iota},
{value_init_value, index_init_value}, reducer,
/*dimensions_to_reduce=*/{axis});
@ -208,12 +217,14 @@ XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
}
} // namespace
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) {
return ArgMinMax(input, output_type, axis, /*is_min=*/false);
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis, bool stable,
bool tie_low) {
return ArgMinMax(input, output_type, axis, /*is_min=*/false, stable, tie_low);
}
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) {
return ArgMinMax(input, output_type, axis, /*is_min=*/true);
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis, bool stable,
bool tie_low) {
return ArgMinMax(input, output_type, axis, /*is_min=*/true, stable, tie_low);
}
XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,

View File

@ -79,16 +79,20 @@ XlaOp Any(XlaOp predicates);
// Returns the argmax of `input` along `axis`. `output_type` is the type to
// use for the output. The `tie_low` argument drives the index selection is case
// of same values. If `true` (default behavior) the lowest index will be
// returned, otherwise the higher.
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis);
// returned, otherwise the higher. The tie_low argument only applies if `stable`
// is true or using the ArgMaxTwoPass.
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis,
bool stable = false, bool tie_low = true);
XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
bool tie_low = true);
// Returns the argmin of `input` along `axis`. `output_type` is the type to
// use for the output. The `tie_low` argument drives the index selection is case
// of same values. If `true` (default behavior) the lowest index will be
// returned, otherwise the higher.
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis);
// returned, otherwise the higher. The tie_low argument only applies if `stable`
// is true or using the ArgMinTwoPass.
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis,
bool stable = false, bool tie_low = true);
XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis,
bool tie_low = true);

View File

@ -33,14 +33,16 @@ class ArithmeticTest : public ClientLibraryTestBase {
public:
template <typename NativeT>
void TestArgMin(std::initializer_list<std::initializer_list<NativeT>> input,
absl::Span<NativeT const> expected_output, int axis) {
return TestArgMinMax(input, expected_output, axis, /*is_min=*/true);
absl::Span<NativeT const> expected_output, int axis,
bool tie_low) {
TestArgMinMax(input, expected_output, axis, /*is_min=*/true, tie_low);
}
template <typename NativeT>
void TestArgMax(std::initializer_list<std::initializer_list<NativeT>> input,
absl::Span<NativeT const> expected_output, int axis) {
return TestArgMinMax(input, expected_output, axis, /*is_min=*/false);
absl::Span<NativeT const> expected_output, int axis,
bool tie_low) {
TestArgMinMax(input, expected_output, axis, /*is_min=*/false, tie_low);
}
private:
@ -48,18 +50,25 @@ class ArithmeticTest : public ClientLibraryTestBase {
template <typename NativeT>
void TestArgMinMax(
std::initializer_list<std::initializer_list<NativeT>> input,
absl::Span<NativeT const> expected_output, int axis, bool is_min) {
absl::Span<NativeT const> expected_output, int axis, bool is_min,
bool tie_low) {
if (is_min) {
TestArgMinMaxImpl(input, expected_output, axis, &ArgMin);
TestArgMinMaxImpl(input, expected_output, axis,
[](XlaOp op, PrimitiveType type, int axis) {
return ArgMinTwoPass(op, type, axis);
TestArgMinMaxImpl(
input, expected_output, [=](XlaOp op, PrimitiveType type) {
return ArgMin(op, type, axis, /*stable=*/true, tie_low);
});
TestArgMinMaxImpl(input, expected_output,
[=](XlaOp op, PrimitiveType type) {
return ArgMinTwoPass(op, type, axis, tie_low);
});
} else {
TestArgMinMaxImpl(input, expected_output, axis, &ArgMax);
TestArgMinMaxImpl(input, expected_output, axis,
[](XlaOp op, PrimitiveType type, int axis) {
return ArgMaxTwoPass(op, type, axis);
TestArgMinMaxImpl(
input, expected_output, [=](XlaOp op, PrimitiveType type) {
return ArgMax(op, type, axis, /*stable=*/true, tie_low);
});
TestArgMinMaxImpl(input, expected_output,
[=](XlaOp op, PrimitiveType type) {
return ArgMaxTwoPass(op, type, axis, tie_low);
});
}
}
@ -67,33 +76,37 @@ class ArithmeticTest : public ClientLibraryTestBase {
template <typename NativeT>
void TestArgMinMaxImpl(
std::initializer_list<std::initializer_list<NativeT>> input,
absl::Span<NativeT const> expected_output, int axis,
std::function<void(XlaOp, PrimitiveType, int)> MinMaxImpl) {
absl::Span<NativeT const> expected_output,
std::function<void(XlaOp, PrimitiveType)> MinMaxImpl) {
XlaBuilder builder(TestName());
XlaOp x = ConstantR2<NativeT>(&builder, input);
MinMaxImpl(x, primitive_util::NativeToPrimitiveType<NativeT>(), axis);
MinMaxImpl(x, primitive_util::NativeToPrimitiveType<NativeT>());
ComputeAndCompareR1<NativeT>(&builder, expected_output, {});
}
};
XLA_TEST_F(ArithmeticTest, ArgMinR2Axis0) {
TestArgMin<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 2},
/*axis=*/0);
/*axis=*/0, /*tie_low=*/true);
TestArgMin<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 2, 2},
/*axis=*/0, /*tie_low=*/false);
}
XLA_TEST_F(ArithmeticTest, ArgMinR2Axis1) {
TestArgMin<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 1},
/*axis=*/1);
/*axis=*/1, /*tie_low=*/true);
TestArgMin<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 2},
/*axis=*/1, /*tie_low=*/false);
}
XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis0) {
TestArgMax<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {2, 0, 1},
/*axis=*/0);
/*axis=*/0, /*tie_low=*/true);
}
XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis1) {
TestArgMax<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {1, 0, 0},
/*axis=*/1);
/*axis=*/1, /*tie_low=*/true);
}
} // namespace