[XLA] Add an environment variable that allows xla to increase the precision inside of fusion instructions.
PiperOrigin-RevId: 240923609
parent 4dd52fbdc6
commit d11a7f8d20
@@ -112,7 +112,17 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                                xla::XlaOp y, const BCast& broadcast_helper) {
   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
   if (DataTypeIsFloating(dtype)) {
-    return xla::Floor(xla::Div(x, y));
+    if (dtype == DataType::DT_BFLOAT16) {
+      // The result of a BF16 division may produce the Ceil of what was
+      // computed by F32 division, so avoid end user confusion by doing the
+      // intermediate divide in F32.
+      return xla::ConvertElementType(
+          xla::Floor(xla::Div(xla::ConvertElementType(x, xla::F32),
+                              xla::ConvertElementType(y, xla::F32))),
+          xla::BF16);
+    } else {
+      return xla::Floor(xla::Div(x, y));
+    }
   }
   if (DataTypeIsUnsigned(dtype)) {
     return xla::Div(x, y);
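Editor's note (not part of the commit): the new comment can be made concrete with a rounding argument. If the exact quotient x / y lies just below an integer, say n - eps, a BF16 divide keeps only about 8 significant bits, so a relative rounding error of up to roughly 2^-8 can round the computed quotient up to exactly n; Floor then returns n, i.e. the Ceil of the true result. Dividing in F32 rounds with a relative error near 2^-24, so the quotient stays below n and Floor yields n - 1 as intended, before the final cast back to BF16.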
@@ -183,8 +183,7 @@ class MaxPoolOp : public PoolingOp {
 class MaxPool2DOp : public MaxPoolOp {
  public:
   explicit MaxPool2DOp(OpKernelConstruction* ctx)
-      : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {
-  }
+      : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {}
 };
 REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
 REGISTER_XLA_OP(Name("MaxPoolV2")
@@ -245,8 +244,7 @@ class AvgPoolOp : public PoolingOp {
 class AvgPool2DOp : public AvgPoolOp {
  public:
   explicit AvgPool2DOp(OpKernelConstruction* ctx)
-      : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {
-  }
+      : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {}
 };
 REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp);

@@ -454,8 +452,7 @@ class AvgPoolGradOp : public XlaOpKernel {
 class AvgPool2DGradOp : public AvgPoolGradOp {
  public:
   explicit AvgPool2DGradOp(OpKernelConstruction* ctx)
-      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {
-  }
+      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {}
 };
 REGISTER_XLA_OP(
     Name("AvgPoolGrad").CompileTimeConstantInput("orig_input_shape"),
@@ -558,10 +555,13 @@ class MaxPoolGradGradOp : public XlaOpKernel {
     auto b = ctx->builder();

     auto sixteen = xla::ConstantR0<uint32>(b, 16);
-    // in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32
+    // in (f32) -> round to 7 mantissa bits (bf16)-> 16-high-bit u32.
+    //
+    // NOTE: Use a ReducePrecision operation instead of a cast to BF16 and back
+    // to F32 since the XLA compiler may ignore narrowing casts to floating
+    // point types if the debug option xla_allow_excess_precision is set.
     auto in_hi = xla::BitcastConvertType(
-        xla::ConvertElementType(xla::ConvertElementType(input, xla::BF16),
-                                xla::F32),
+        xla::ReducePrecision(input, /*exponent_bits=*/8, /*mantissa_bits=*/7),
         xla::U32);
     auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32);
     auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen);
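Editor's sketch (not part of the commit; assumes the standard XLA client builder API and the include path of that era): ReducePrecision with 8 exponent bits and 7 mantissa bits rounds an F32 value to bfloat16 precision while keeping the F32 element type, so the following BitcastConvertType still sees 32 bits and the high 16 can be extracted as before. Because it is an explicit rounding op rather than a narrowing cast, the compiler cannot drop it the way it may drop the BF16 round trip when xla_allow_excess_precision is on.

#include "tensorflow/compiler/xla/client/xla_builder.h"

// Round an F32 operand to bfloat16 precision without changing its element
// type. Same intent as ConvertElementType(v, BF16) followed by
// ConvertElementType(..., F32), but expressed as an op that survives
// excess-precision optimizations.
xla::XlaOp RoundToBf16Precision(xla::XlaOp f32_value) {
  return xla::ReducePrecision(f32_value, /*exponent_bits=*/8,
                              /*mantissa_bits=*/7);
}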
@@ -53,6 +53,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_cpu_enable_fast_math(true);
   opts.set_xla_gpu_enable_fast_min_max(true);

+  opts.set_xla_allow_excess_precision(true);
   opts.set_xla_force_host_platform_device_count(1);
   return opts;
 }
@@ -381,6 +382,11 @@ static void AllocateFlags() {
           flag_values->xla_hlo_graph_sharding_color(),
           "Assign colors based on sharding assignments when generating the "
           "HLO graphs."),
+      tensorflow::Flag(
+          "xla_allow_excess_precision",
+          bool_setter_for(&DebugOptions::set_xla_allow_excess_precision),
+          flag_values->xla_allow_excess_precision(),
+          "Allow xla to increase the output precision of an instruction."),
   });
   ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects);
 }
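Usage note (editor's addition): ParseFlagsFromEnvAndDieIfUnknown above reads these flags from the XLA_FLAGS environment variable, so the new option can be flipped without recompiling, for example by running a binary with XLA_FLAGS=--xla_allow_excess_precision=false in the environment (illustrative; the usual --name=value flag syntax is assumed).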
@@ -176,6 +176,9 @@ message DebugOptions {
   // this is true we don't propagate NaNs through Min and Max.
   bool xla_gpu_enable_fast_min_max = 100;

+  // Allows xla to increase the output precision of floating point operations.
+  bool xla_allow_excess_precision = 122;
+
   // Crashes the program when any kind of verification fails, instead of just
   // logging the failures. One example is cross checking of convolution results
   // among different algorithms.
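Editor's sketch (not part of the commit): because the flag is backed by an ordinary proto field, it can also be set programmatically on a DebugOptions message. The setter name comes from the flag binding above; the include path and helper function are illustrative assumptions.

#include "tensorflow/compiler/xla/xla.pb.h"

// Build a DebugOptions proto with excess precision disabled, for example to
// force bit-exact intermediate results while chasing a numeric difference.
xla::DebugOptions MakeExactPrecisionOptions() {
  xla::DebugOptions opts;
  opts.set_xla_allow_excess_precision(false);
  return opts;
}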
@@ -262,7 +265,7 @@ message DebugOptions {
   // END flags controlling dumping HLO modules.
   //

-  // Next id: 121
+  // Next id: 123

   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.