[XLA] Use variadic reduce for argmin/argmax implementation to avoid issues with floating point equality checking and to allow excess precision.

PiperOrigin-RevId: 244093452
Blake Hechtman 2019-04-17 16:48:43 -07:00 committed by TensorFlower Gardener
parent f3d310102d
commit 94745ad5bd
5 changed files with 89 additions and 7 deletions
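Motivation, in short: the previous two-pass formulation reduced to the extreme value first and then recovered its index by comparing the input against that value with a floating point equality test, which can miss every element when the reduction keeps excess precision. The variadic reduce instead carries a (value, index) pair through a single reduction, so no equality comparison against the input is needed. A plain C++ sketch of that single-pass idea (illustrative only, not code from this commit; the actual XLA builder code is in the arithmetic library diff below):

#include <cstdint>
#include <utility>
#include <vector>

// Single-pass argmax over a non-empty vector: the reducer carries a
// (value, index) pair, mirroring the variadic Reduce over {input, iota}
// below, so the winning index is tracked alongside the running max and the
// input is never re-compared for equality.
std::pair<float, int64_t> ArgMaxPair(const std::vector<float>& values) {
  std::pair<float, int64_t> best = {values[0], 0};
  for (int64_t i = 1; i < static_cast<int64_t>(values.size()); ++i) {
    if (values[i] > best.first) best = {values[i], i};
  }
  return best;
}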

@@ -36,7 +36,9 @@ namespace {
class CategoricalOp : public XlaOpKernel {
public:
explicit CategoricalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
explicit CategoricalOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx),
is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {}
void Compile(XlaOpKernelContext* ctx) override {
// Get the logits
@@ -101,8 +103,15 @@ class CategoricalOp : public XlaOpKernel {
xla::PrimitiveType xla_output_type;
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(output_type(0), &xla_output_type));
xla::XlaOp argmax = xla::ArgMax(softmax_entries, xla_output_type,
/*axis=*/class_dimension);
xla::XlaOp argmax;
if (is_gpu_) {
argmax = xla::ArgMaxTwoPass(softmax_entries, xla_output_type,
/*axis=*/class_dimension);
} else {
argmax = xla::ArgMax(softmax_entries, xla_output_type,
/*axis=*/class_dimension);
}
if (num_samples == 1) {
argmax = xla::Reshape(argmax, {batch_size, 1});
}
@@ -124,6 +133,7 @@ class CategoricalOp : public XlaOpKernel {
}
private:
bool is_gpu_;
TF_DISALLOW_COPY_AND_ASSIGN(CategoricalOp);
};

@@ -31,7 +31,9 @@ limitations under the License.
namespace tensorflow {
XlaArgMinMaxOp::XlaArgMinMaxOp(OpKernelConstruction* ctx, bool is_min)
: XlaOpKernel(ctx), is_min_(is_min) {}
: XlaOpKernel(ctx),
is_min_(is_min),
is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {}
void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
const TensorShape input_shape = ctx->InputShape(0);
@@ -64,10 +66,19 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
xla::XlaOp input = ctx->Input(0);
xla::XlaOp output;
// One pass ArgMin/ArgMax is slow on GPUs.
if (is_min_) {
output = xla::ArgMin(input, index_xla_type, axis);
if (is_gpu_) {
output = xla::ArgMinTwoPass(input, index_xla_type, axis);
} else {
output = xla::ArgMin(input, index_xla_type, axis);
}
} else {
output = xla::ArgMax(input, index_xla_type, axis);
if (is_gpu_) {
output = xla::ArgMaxTwoPass(input, index_xla_type, axis);
} else {
output = xla::ArgMax(input, index_xla_type, axis);
}
}
ctx->SetOutput(0, output);

@@ -30,6 +30,7 @@ class XlaArgMinMaxOp : public XlaOpKernel {
private:
const bool is_min_; // Are we computing ArgMin (true) or ArgMax (false)?
const bool is_gpu_;
};
class XlaArgMaxOp : public XlaArgMinMaxOp {

@@ -125,8 +125,60 @@ XlaOp Any(XlaOp predicates) {
namespace {
XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder,
PrimitiveType value_type,
PrimitiveType index_type, bool is_min) {
auto sub_builder = outer_builder->CreateSubBuilder("minmax_func");
XlaBuilder* b = sub_builder.get();
XlaOp lhs_value =
Parameter(b, 0, ShapeUtil::MakeShape(value_type, {}), "lhs_value");
XlaOp lhs_index =
Parameter(b, 1, ShapeUtil::MakeShape(index_type, {}), "lhs_index");
XlaOp rhs_value =
Parameter(b, 2, ShapeUtil::MakeShape(value_type, {}), "rhs_value");
XlaOp rhs_index =
Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index");
auto cmp = is_min ? Lt(lhs_value, rhs_value) : Gt(lhs_value, rhs_value);
XlaOp max = Select(cmp, lhs_value, rhs_value);
XlaOp arg_max = Select(cmp, lhs_index, rhs_index);
Tuple(b, {max, arg_max});
return b->Build().ConsumeValueOrDie();
}
XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
XlaBuilder* builder = input.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
XlaOp value_init_value;
if (is_min) {
value_init_value = MaxValue(builder, input_shape.element_type());
} else {
value_init_value = MinValue(builder, input_shape.element_type());
}
int64 dimension_size = input_shape.dimensions(axis);
auto index_type = dimension_size <= INT32_MAX ? S32 : output_type;
XlaOp index_init_value = Zero(builder, index_type);
auto iota_shape = input_shape;
iota_shape.set_element_type(index_type);
XlaOp iota = Iota(builder, iota_shape, axis);
XlaComputation reducer = CreateMinMaxComputation(
builder, input_shape.element_type(), index_type, is_min);
XlaOp max_argmax = Reduce(builder, {input, iota},
{value_init_value, index_init_value}, reducer,
/*dimensions_to_reduce=*/{axis});
XlaOp argmax = GetTupleElement(max_argmax, 1);
if (index_type != output_type) {
argmax = ConvertElementType(argmax, output_type);
}
return argmax;
});
}
XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
bool is_min) {
XlaBuilder* builder = input.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
XlaOp init_value;
@@ -172,7 +224,6 @@ XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
/*dimensions_to_reduce=*/{axis});
});
}
} // namespace
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) {
@@ -183,4 +234,11 @@ XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) {
return ArgMinMax(input, output_type, axis, /*is_min=*/true);
}
XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis) {
return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/false);
}
XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis) {
return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/true);
}
} // namespace xla

@@ -60,10 +60,12 @@ XlaOp Any(XlaOp predicates);
// Returns the argmax of `input` along `axis`. `output_type` is the type to
// use for the output.
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis);
XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis);
// Returns the argmin of `input` along `axis`. `output_type` is the type to
// use for the output.
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis);
XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis);
} // namespace xla
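For reference, a usage sketch of the entry points declared above (illustrative only and not part of this commit; the builder name, shape, and axis are assumptions):

// Builds both variants of argmax over the last dimension of an f32 input.
// ArgMax lowers to the single-pass variadic reduce; ArgMaxTwoPass keeps the
// older two-reduction formulation that the GPU kernels above still select.
xla::XlaBuilder b("argmax_example");
xla::XlaOp input = xla::Parameter(
    &b, 0, xla::ShapeUtil::MakeShape(xla::F32, {8, 128}), "input");
xla::XlaOp one_pass = xla::ArgMax(input, xla::S32, /*axis=*/1);
xla::XlaOp two_pass = xla::ArgMaxTwoPass(input, xla::S32, /*axis=*/1);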