[XLA] Implement stable argmin and argmax
PiperOrigin-RevId: 307172027 Change-Id: I4c24631968bd7d22d2147f984888ec489c347bdf
This commit is contained in:
parent
7ae5b12f47
commit
f73e9d61a7
@ -109,7 +109,7 @@ class CategoricalOp : public XlaOpKernel {
|
||||
/*axis=*/class_dimension);
|
||||
} else {
|
||||
argmax = xla::ArgMax(softmax_entries, xla_output_type,
|
||||
/*axis=*/class_dimension);
|
||||
/*axis=*/class_dimension, /*stable=*/true);
|
||||
}
|
||||
|
||||
if (num_samples == 1) {
|
||||
|
@ -77,7 +77,7 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
|
||||
if (is_gpu_) {
|
||||
output = xla::ArgMaxTwoPass(input, index_xla_type, axis);
|
||||
} else {
|
||||
output = xla::ArgMax(input, index_xla_type, axis);
|
||||
output = xla::ArgMax(input, index_xla_type, axis, /*stable=*/true);
|
||||
}
|
||||
}
|
||||
|
||||
@ -86,8 +86,7 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
|
||||
|
||||
XlaArgMaxOp::XlaArgMaxOp(OpKernelConstruction* ctx)
|
||||
: XlaArgMinMaxOp(ctx, /*is_min=*/false) {}
|
||||
REGISTER_XLA_OP(Name("ArgMax")
|
||||
.CompileTimeConstantInput("dimension"),
|
||||
REGISTER_XLA_OP(Name("ArgMax").CompileTimeConstantInput("dimension"),
|
||||
XlaArgMaxOp);
|
||||
|
||||
namespace {
|
||||
|
@ -114,7 +114,8 @@ namespace {
|
||||
|
||||
XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder,
|
||||
PrimitiveType value_type,
|
||||
PrimitiveType index_type, bool is_min) {
|
||||
PrimitiveType index_type, bool is_min,
|
||||
bool stable, bool tie_low) {
|
||||
auto sub_builder = outer_builder->CreateSubBuilder("minmax_func");
|
||||
XlaBuilder* b = sub_builder.get();
|
||||
XlaOp lhs_value =
|
||||
@ -126,14 +127,21 @@ XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder,
|
||||
XlaOp rhs_index =
|
||||
Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index");
|
||||
|
||||
auto cmp = is_min ? Le(lhs_value, rhs_value) : Ge(lhs_value, rhs_value);
|
||||
XlaOp cmp = is_min ? Le(lhs_value, rhs_value) : Ge(lhs_value, rhs_value);
|
||||
XlaOp max = Select(cmp, lhs_value, rhs_value);
|
||||
XlaOp arg_max = Select(cmp, lhs_index, rhs_index);
|
||||
if (stable) {
|
||||
XlaOp eq = Eq(lhs_value, rhs_value);
|
||||
XlaOp tie_id =
|
||||
tie_low ? Min(lhs_index, rhs_index) : Max(lhs_index, rhs_index);
|
||||
arg_max = Select(eq, tie_id, arg_max);
|
||||
}
|
||||
Tuple(b, {max, arg_max});
|
||||
return b->Build().ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
|
||||
XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min,
|
||||
bool stable, bool tie_low) {
|
||||
XlaBuilder* builder = input.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
|
||||
@ -150,8 +158,9 @@ XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
|
||||
iota_shape.set_element_type(index_type);
|
||||
XlaOp iota = Iota(builder, iota_shape, axis);
|
||||
|
||||
XlaComputation reducer = CreateMinMaxComputation(
|
||||
builder, input_shape.element_type(), index_type, is_min);
|
||||
XlaComputation reducer =
|
||||
CreateMinMaxComputation(builder, input_shape.element_type(), index_type,
|
||||
is_min, stable, tie_low);
|
||||
XlaOp max_argmax = Reduce(builder, {input, iota},
|
||||
{value_init_value, index_init_value}, reducer,
|
||||
/*dimensions_to_reduce=*/{axis});
|
||||
@ -208,12 +217,14 @@ XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) {
|
||||
return ArgMinMax(input, output_type, axis, /*is_min=*/false);
|
||||
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis, bool stable,
|
||||
bool tie_low) {
|
||||
return ArgMinMax(input, output_type, axis, /*is_min=*/false, stable, tie_low);
|
||||
}
|
||||
|
||||
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) {
|
||||
return ArgMinMax(input, output_type, axis, /*is_min=*/true);
|
||||
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis, bool stable,
|
||||
bool tie_low) {
|
||||
return ArgMinMax(input, output_type, axis, /*is_min=*/true, stable, tie_low);
|
||||
}
|
||||
|
||||
XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
|
||||
|
@ -79,16 +79,20 @@ XlaOp Any(XlaOp predicates);
|
||||
// Returns the argmax of `input` along `axis`. `output_type` is the type to
|
||||
// use for the output. The `tie_low` argument drives the index selection in case
|
||||
// of same values. If `true` (default behavior) the lowest index will be
|
||||
// returned, otherwise the higher.
|
||||
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis);
|
||||
// returned, otherwise the higher. The tie_low argument only applies if `stable`
|
||||
// is true or using the ArgMaxTwoPass.
|
||||
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis,
|
||||
bool stable = false, bool tie_low = true);
|
||||
XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
|
||||
bool tie_low = true);
|
||||
|
||||
// Returns the argmin of `input` along `axis`. `output_type` is the type to
|
||||
// use for the output. The `tie_low` argument drives the index selection in case
|
||||
// of same values. If `true` (default behavior) the lowest index will be
|
||||
// returned, otherwise the higher.
|
||||
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis);
|
||||
// returned, otherwise the higher. The tie_low argument only applies if `stable`
|
||||
// is true or using the ArgMinTwoPass.
|
||||
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis,
|
||||
bool stable = false, bool tie_low = true);
|
||||
XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis,
|
||||
bool tie_low = true);
|
||||
|
||||
|
@ -33,14 +33,16 @@ class ArithmeticTest : public ClientLibraryTestBase {
|
||||
public:
|
||||
template <typename NativeT>
|
||||
void TestArgMin(std::initializer_list<std::initializer_list<NativeT>> input,
|
||||
absl::Span<NativeT const> expected_output, int axis) {
|
||||
return TestArgMinMax(input, expected_output, axis, /*is_min=*/true);
|
||||
absl::Span<NativeT const> expected_output, int axis,
|
||||
bool tie_low) {
|
||||
TestArgMinMax(input, expected_output, axis, /*is_min=*/true, tie_low);
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
void TestArgMax(std::initializer_list<std::initializer_list<NativeT>> input,
|
||||
absl::Span<NativeT const> expected_output, int axis) {
|
||||
return TestArgMinMax(input, expected_output, axis, /*is_min=*/false);
|
||||
absl::Span<NativeT const> expected_output, int axis,
|
||||
bool tie_low) {
|
||||
TestArgMinMax(input, expected_output, axis, /*is_min=*/false, tie_low);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -48,18 +50,25 @@ class ArithmeticTest : public ClientLibraryTestBase {
|
||||
template <typename NativeT>
|
||||
void TestArgMinMax(
|
||||
std::initializer_list<std::initializer_list<NativeT>> input,
|
||||
absl::Span<NativeT const> expected_output, int axis, bool is_min) {
|
||||
absl::Span<NativeT const> expected_output, int axis, bool is_min,
|
||||
bool tie_low) {
|
||||
if (is_min) {
|
||||
TestArgMinMaxImpl(input, expected_output, axis, &ArgMin);
|
||||
TestArgMinMaxImpl(input, expected_output, axis,
|
||||
[](XlaOp op, PrimitiveType type, int axis) {
|
||||
return ArgMinTwoPass(op, type, axis);
|
||||
TestArgMinMaxImpl(
|
||||
input, expected_output, [=](XlaOp op, PrimitiveType type) {
|
||||
return ArgMin(op, type, axis, /*stable=*/true, tie_low);
|
||||
});
|
||||
TestArgMinMaxImpl(input, expected_output,
|
||||
[=](XlaOp op, PrimitiveType type) {
|
||||
return ArgMinTwoPass(op, type, axis, tie_low);
|
||||
});
|
||||
} else {
|
||||
TestArgMinMaxImpl(input, expected_output, axis, &ArgMax);
|
||||
TestArgMinMaxImpl(input, expected_output, axis,
|
||||
[](XlaOp op, PrimitiveType type, int axis) {
|
||||
return ArgMaxTwoPass(op, type, axis);
|
||||
TestArgMinMaxImpl(
|
||||
input, expected_output, [=](XlaOp op, PrimitiveType type) {
|
||||
return ArgMax(op, type, axis, /*stable=*/true, tie_low);
|
||||
});
|
||||
TestArgMinMaxImpl(input, expected_output,
|
||||
[=](XlaOp op, PrimitiveType type) {
|
||||
return ArgMaxTwoPass(op, type, axis, tie_low);
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -67,33 +76,37 @@ class ArithmeticTest : public ClientLibraryTestBase {
|
||||
template <typename NativeT>
|
||||
void TestArgMinMaxImpl(
|
||||
std::initializer_list<std::initializer_list<NativeT>> input,
|
||||
absl::Span<NativeT const> expected_output, int axis,
|
||||
std::function<void(XlaOp, PrimitiveType, int)> MinMaxImpl) {
|
||||
absl::Span<NativeT const> expected_output,
|
||||
std::function<void(XlaOp, PrimitiveType)> MinMaxImpl) {
|
||||
XlaBuilder builder(TestName());
|
||||
XlaOp x = ConstantR2<NativeT>(&builder, input);
|
||||
MinMaxImpl(x, primitive_util::NativeToPrimitiveType<NativeT>(), axis);
|
||||
MinMaxImpl(x, primitive_util::NativeToPrimitiveType<NativeT>());
|
||||
ComputeAndCompareR1<NativeT>(&builder, expected_output, {});
|
||||
}
|
||||
};
|
||||
|
||||
XLA_TEST_F(ArithmeticTest, ArgMinR2Axis0) {
|
||||
TestArgMin<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 2},
|
||||
/*axis=*/0);
|
||||
/*axis=*/0, /*tie_low=*/true);
|
||||
TestArgMin<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 2, 2},
|
||||
/*axis=*/0, /*tie_low=*/false);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArithmeticTest, ArgMinR2Axis1) {
|
||||
TestArgMin<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 1},
|
||||
/*axis=*/1);
|
||||
/*axis=*/1, /*tie_low=*/true);
|
||||
TestArgMin<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {0, 1, 2},
|
||||
/*axis=*/1, /*tie_low=*/false);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis0) {
|
||||
TestArgMax<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {2, 0, 1},
|
||||
/*axis=*/0);
|
||||
/*axis=*/0, /*tie_low=*/true);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArithmeticTest, ArgMaxR2Axis1) {
|
||||
TestArgMax<int32>({{1, 7, 4}, {6, 3, 5}, {8, 3, 3}}, {1, 0, 0},
|
||||
/*axis=*/1);
|
||||
/*axis=*/1, /*tie_low=*/true);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user