[XLA] Add an environment variable that allows XLA to increase the precision inside fusion instructions.

PiperOrigin-RevId: 240923609
This commit is contained in:
Blake Hechtman 2019-03-28 22:52:34 -07:00 committed by TensorFlower Gardener
parent 4dd52fbdc6
commit d11a7f8d20
4 changed files with 30 additions and 11 deletions

View File

@ -112,7 +112,17 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
if (DataTypeIsFloating(dtype)) {
return xla::Floor(xla::Div(x, y));
if (dtype == DataType::DT_BFLOAT16) {
// The result of a BF16 division may produce the Ceil of what was
// computed by F32 division, so avoid end user confusion by doing the
// intermediate divide in F32.
return xla::ConvertElementType(
xla::Floor(xla::Div(xla::ConvertElementType(x, xla::F32),
xla::ConvertElementType(y, xla::F32))),
xla::BF16);
} else {
return xla::Floor(xla::Div(x, y));
}
}
if (DataTypeIsUnsigned(dtype)) {
return xla::Div(x, y);

View File

@ -183,8 +183,7 @@ class MaxPoolOp : public PoolingOp {
class MaxPool2DOp : public MaxPoolOp {
public:
explicit MaxPool2DOp(OpKernelConstruction* ctx)
: MaxPoolOp(ctx, /*num_spatial_dims=*/2) {
}
: MaxPoolOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
REGISTER_XLA_OP(Name("MaxPoolV2")
@ -245,8 +244,7 @@ class AvgPoolOp : public PoolingOp {
class AvgPool2DOp : public AvgPoolOp {
public:
explicit AvgPool2DOp(OpKernelConstruction* ctx)
: AvgPoolOp(ctx, /*num_spatial_dims=*/2) {
}
: AvgPoolOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp);
@ -454,8 +452,7 @@ class AvgPoolGradOp : public XlaOpKernel {
class AvgPool2DGradOp : public AvgPoolGradOp {
public:
explicit AvgPool2DGradOp(OpKernelConstruction* ctx)
: AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {
}
: AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(
Name("AvgPoolGrad").CompileTimeConstantInput("orig_input_shape"),
@ -558,10 +555,13 @@ class MaxPoolGradGradOp : public XlaOpKernel {
auto b = ctx->builder();
auto sixteen = xla::ConstantR0<uint32>(b, 16);
// in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32
// in (f32) -> round to 7 mantissa bits (bf16)-> 16-high-bit u32.
//
// NOTE: Use a ReducePrecision operation instead of a cast to BF16 and back
// to F32 since the XLA compiler may ignore narrowing casts to floating
// point types if the debug option xla_allow_excess_precision is set.
auto in_hi = xla::BitcastConvertType(
xla::ConvertElementType(xla::ConvertElementType(input, xla::BF16),
xla::F32),
xla::ReducePrecision(input, /*exponent_bits=*/8, /*mantissa_bits=*/7),
xla::U32);
auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32);
auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen);

View File

@ -53,6 +53,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_cpu_enable_fast_math(true);
opts.set_xla_gpu_enable_fast_min_max(true);
opts.set_xla_allow_excess_precision(true);
opts.set_xla_force_host_platform_device_count(1);
return opts;
}
@ -381,6 +382,11 @@ static void AllocateFlags() {
flag_values->xla_hlo_graph_sharding_color(),
"Assign colors based on sharding assignments when generating the "
"HLO graphs."),
tensorflow::Flag(
"xla_allow_excess_precision",
bool_setter_for(&DebugOptions::set_xla_allow_excess_precision),
flag_values->xla_allow_excess_precision(),
"Allow xla to increase the output precision of an instruction."),
});
ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects);
}

View File

@ -176,6 +176,9 @@ message DebugOptions {
// this is true we don't propagate NaNs through Min and Max.
bool xla_gpu_enable_fast_min_max = 100;
// Allows xla to increase the output precision of floating point operations.
bool xla_allow_excess_precision = 122;
// Crashes the program when any kind of verification fails, instead of just
// logging the failures. One example is cross checking of convolution results
// among different algorithms.
@ -262,7 +265,7 @@ message DebugOptions {
// END flags controlling dumping HLO modules.
//
// Next id: 121
// Next id: 123
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.