[XLA] Use variadic reduce for the argmin/argmax implementation to avoid issues with
floating-point equality checking and excess precision.

PiperOrigin-RevId: 244093452
This commit is contained in: parent f3d310102d, commit 94745ad5bd
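For context on the commit message: a two-pass argmin/argmax first reduces to the extreme value and then locates it by comparing every element against that value for floating-point equality. If the first pass keeps excess precision (or runs at a wider type than the stored elements), the recomputed extreme can compare unequal to every element. A minimal scalar sketch of that failure-prone shape — hypothetical names, not the actual XLA lowering:

    #include <cstddef>
    #include <vector>

    // Hypothetical scalar model of a two-pass argmax. Pass 1 reduces to the
    // max value; pass 2 finds the first element that compares equal to it.
    // If pass 1 ran at higher precision than the stored elements (excess
    // precision), the equality test in pass 2 can fail for every element.
    std::size_t TwoPassArgMax(const std::vector<float>& x) {
      float max = x[0];
      for (float v : x) max = v > max ? v : max;  // pass 1: value reduce
      for (std::size_t i = 0; i < x.size(); ++i) {
        if (x[i] == max) return i;                // pass 2: float equality
      }
      return 0;  // reached only if the equality check in pass 2 misfired
    }

The change below instead carries (value, index) pairs through a single variadic Reduce, so no equality check is needed; the old scheme is kept as ArgMinMaxTwoPass for GPUs, where the one-pass reduce is slow.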
@@ -36,7 +36,9 @@ namespace {
 
 class CategoricalOp : public XlaOpKernel {
  public:
-  explicit CategoricalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  explicit CategoricalOp(OpKernelConstruction* ctx)
+      : XlaOpKernel(ctx),
+        is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {}
 
   void Compile(XlaOpKernelContext* ctx) override {
     // Get the logits
@@ -101,8 +103,15 @@ class CategoricalOp : public XlaOpKernel {
     xla::PrimitiveType xla_output_type;
     OP_REQUIRES_OK(ctx,
                    DataTypeToPrimitiveType(output_type(0), &xla_output_type));
-    xla::XlaOp argmax = xla::ArgMax(softmax_entries, xla_output_type,
-                                    /*axis=*/class_dimension);
+    xla::XlaOp argmax;
+    if (is_gpu_) {
+      argmax = xla::ArgMaxTwoPass(softmax_entries, xla_output_type,
+                                  /*axis=*/class_dimension);
+    } else {
+      argmax = xla::ArgMax(softmax_entries, xla_output_type,
+                           /*axis=*/class_dimension);
+    }
+
     if (num_samples == 1) {
       argmax = xla::Reshape(argmax, {batch_size, 1});
     }
@@ -124,6 +133,7 @@ class CategoricalOp : public XlaOpKernel {
   }
 
  private:
+  bool is_gpu_;
   TF_DISALLOW_COPY_AND_ASSIGN(CategoricalOp);
 };
 
@@ -31,7 +31,9 @@ limitations under the License.
 
 namespace tensorflow {
 XlaArgMinMaxOp::XlaArgMinMaxOp(OpKernelConstruction* ctx, bool is_min)
-    : XlaOpKernel(ctx), is_min_(is_min) {}
+    : XlaOpKernel(ctx),
+      is_min_(is_min),
+      is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {}
 
 void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
   const TensorShape input_shape = ctx->InputShape(0);
@@ -64,10 +66,19 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
 
   xla::XlaOp input = ctx->Input(0);
   xla::XlaOp output;
+  // One pass ArgMin/ArgMax is slow on GPUs.
   if (is_min_) {
-    output = xla::ArgMin(input, index_xla_type, axis);
+    if (is_gpu_) {
+      output = xla::ArgMinTwoPass(input, index_xla_type, axis);
+    } else {
+      output = xla::ArgMin(input, index_xla_type, axis);
+    }
   } else {
-    output = xla::ArgMax(input, index_xla_type, axis);
+    if (is_gpu_) {
+      output = xla::ArgMaxTwoPass(input, index_xla_type, axis);
+    } else {
+      output = xla::ArgMax(input, index_xla_type, axis);
+    }
   }
 
   ctx->SetOutput(0, output);
@@ -30,6 +30,7 @@ class XlaArgMinMaxOp : public XlaOpKernel {
 
  private:
   const bool is_min_;  // Are we computing ArgMin (true) or ArgMax (false)?
+  const bool is_gpu_;
 };
 
 class XlaArgMaxOp : public XlaArgMinMaxOp {
@@ -125,8 +125,60 @@ XlaOp Any(XlaOp predicates) {
 
 namespace {
 
+XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder,
+                                       PrimitiveType value_type,
+                                       PrimitiveType index_type, bool is_min) {
+  auto sub_builder = outer_builder->CreateSubBuilder("minmax_func");
+  XlaBuilder* b = sub_builder.get();
+  XlaOp lhs_value =
+      Parameter(b, 0, ShapeUtil::MakeShape(value_type, {}), "lhs_value");
+  XlaOp lhs_index =
+      Parameter(b, 1, ShapeUtil::MakeShape(index_type, {}), "lhs_index");
+  XlaOp rhs_value =
+      Parameter(b, 2, ShapeUtil::MakeShape(value_type, {}), "rhs_value");
+  XlaOp rhs_index =
+      Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index");
+
+  auto cmp = is_min ? Lt(lhs_value, rhs_value) : Gt(lhs_value, rhs_value);
+  XlaOp max = Select(cmp, lhs_value, rhs_value);
+  XlaOp arg_max = Select(cmp, lhs_index, rhs_index);
+  Tuple(b, {max, arg_max});
+  return b->Build().ConsumeValueOrDie();
+}
+
 XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
+  XlaBuilder* builder = input.builder();
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
+    XlaOp value_init_value;
+    if (is_min) {
+      value_init_value = MaxValue(builder, input_shape.element_type());
+    } else {
+      value_init_value = MinValue(builder, input_shape.element_type());
+    }
+    int64 dimension_size = input_shape.dimensions(axis);
+    auto index_type = dimension_size <= INT32_MAX ? S32 : output_type;
+    XlaOp index_init_value = Zero(builder, index_type);
+    auto iota_shape = input_shape;
+    iota_shape.set_element_type(index_type);
+    XlaOp iota = Iota(builder, iota_shape, axis);
+
+    XlaComputation reducer = CreateMinMaxComputation(
+        builder, input_shape.element_type(), index_type, is_min);
+    XlaOp max_argmax = Reduce(builder, {input, iota},
+                              {value_init_value, index_init_value}, reducer,
+                              /*dimensions_to_reduce=*/{axis});
+    XlaOp argmax = GetTupleElement(max_argmax, 1);
+    if (index_type != output_type) {
+      argmax = ConvertElementType(argmax, output_type);
+    }
+    return argmax;
+  });
+}
+
+XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
+                       bool is_min) {
   XlaBuilder* builder = input.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
     XlaOp init_value;
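A hedged scalar model of the reducer added above: the (value, index) pair is folded with a single Lt/Gt comparison, and that one predicate selects both the winning value and the winning index, so no equality against a recomputed extreme is ever performed. Names below are illustrative, not the XLA implementation:

    #include <limits>
    #include <utility>
    #include <vector>

    // Scalar analogue of CreateMinMaxComputation's reducer for argmax: one
    // comparison (Gt) drives both Selects, then the pair is re-tupled.
    std::pair<float, long> CombineMax(std::pair<float, long> lhs,
                                      std::pair<float, long> rhs) {
      return lhs.first > rhs.first ? lhs : rhs;
    }

    // Scalar analogue of the variadic Reduce over {input, iota}: the init
    // pair mirrors (MinValue, Zero) from the XLA code above, and each
    // element arrives paired with its Iota index.
    long VariadicArgMax(const std::vector<float>& x) {
      std::pair<float, long> acc = {-std::numeric_limits<float>::infinity(), 0};
      for (long i = 0; i < static_cast<long>(x.size()); ++i) {
        acc = CombineMax(acc, {x[i], i});
      }
      return acc.second;
    }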
@@ -172,7 +224,6 @@ XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
                   /*dimensions_to_reduce=*/{axis});
   });
 }
-
 }  // namespace
 
 XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) {
@@ -183,4 +234,11 @@ XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) {
   return ArgMinMax(input, output_type, axis, /*is_min=*/true);
 }
 
+XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis) {
+  return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/false);
+}
+
+XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis) {
+  return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/true);
+}
 }  // namespace xla
@@ -60,10 +60,12 @@ XlaOp Any(XlaOp predicates);
 
 // Returns the argmax of `input` along `axis`. `output_type` is the type to
 // use for the output.
 XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis);
+XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis);
 
 // Returns the argmin of `input` along `axis`. `output_type` is the type to
 // use for the output.
 XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis);
+XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis);
 
 }  // namespace xla
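Call sites choose between the two entry points per device, as the kernel changes above do. A sketch of that pattern, assuming a caller-provided on_gpu flag (not code from this commit):

    // Variadic-reduce ArgMax by default; two-pass variant on GPU, where the
    // one-pass reduce is slow.
    xla::XlaOp result = on_gpu
        ? xla::ArgMaxTwoPass(input, xla::S32, /*axis=*/1)
        : xla::ArgMax(input, xla::S32, /*axis=*/1);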