diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index b708e91722e..59f6041a608 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -112,7 +112,17 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                                xla::XlaOp y, const BCast& broadcast_helper) {
   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
   if (DataTypeIsFloating(dtype)) {
-    return xla::Floor(xla::Div(x, y));
+    if (dtype == DataType::DT_BFLOAT16) {
+      // The result of a BF16 division may produce the Ceil of what was
+      // computed by F32 division, so avoid end user confusion by doing the
+      // intermediate divide in F32.
+      return xla::ConvertElementType(
+          xla::Floor(xla::Div(xla::ConvertElementType(x, xla::F32),
+                              xla::ConvertElementType(y, xla::F32))),
+          xla::BF16);
+    } else {
+      return xla::Floor(xla::Div(x, y));
+    }
   }
   if (DataTypeIsUnsigned(dtype)) {
     return xla::Div(x, y);
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index 85223795aa8..8716484a3c1 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -183,8 +183,7 @@ class MaxPoolOp : public PoolingOp {
 class MaxPool2DOp : public MaxPoolOp {
  public:
   explicit MaxPool2DOp(OpKernelConstruction* ctx)
-      : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {
-  }
+      : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {}
 };
 REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
 REGISTER_XLA_OP(Name("MaxPoolV2")
@@ -245,8 +244,7 @@ class AvgPoolOp : public PoolingOp {
 class AvgPool2DOp : public AvgPoolOp {
  public:
   explicit AvgPool2DOp(OpKernelConstruction* ctx)
-      : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {
-  }
+      : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {}
 };
 
 REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp);
@@ -454,8 +452,7 @@ class AvgPoolGradOp : public XlaOpKernel {
 class AvgPool2DGradOp : public AvgPoolGradOp {
  public:
   explicit AvgPool2DGradOp(OpKernelConstruction* ctx)
-      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {
-  }
+      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {}
 };
 REGISTER_XLA_OP(
     Name("AvgPoolGrad").CompileTimeConstantInput("orig_input_shape"),
@@ -558,10 +555,13 @@ class MaxPoolGradGradOp : public XlaOpKernel {
     auto b = ctx->builder();
 
     auto sixteen = xla::ConstantR0<uint32>(b, 16);
-    // in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32
+    // in (f32) -> round to 7 mantissa bits (bf16)-> 16-high-bit u32.
+    //
+    // NOTE: Use a ReducePrecision operation instead of a cast to BF16 and back
+    // to F32 since the XLA compiler may ignore narrowing casts to floating
+    // point types if the debug option xla_allow_excess_precision is set.
     auto in_hi = xla::BitcastConvertType(
-        xla::ConvertElementType(xla::ConvertElementType(input, xla::BF16),
-                                xla::F32),
+        xla::ReducePrecision(input, /*exponent_bits=*/8, /*mantissa_bits=*/7),
         xla::U32);
     auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32);
     auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen);
diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc
index 2437bf04b0f..e08058bfc28 100644
--- a/tensorflow/compiler/xla/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/debug_options_flags.cc
@@ -53,6 +53,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_cpu_enable_fast_math(true);
   opts.set_xla_gpu_enable_fast_min_max(true);
 
+  opts.set_xla_allow_excess_precision(true);
   opts.set_xla_force_host_platform_device_count(1);
   return opts;
 }
@@ -381,6 +382,11 @@ static void AllocateFlags() {
           flag_values->xla_hlo_graph_sharding_color(),
           "Assign colors based on sharding assignments when generating the "
           "HLO graphs."),
+      tensorflow::Flag(
+          "xla_allow_excess_precision",
+          bool_setter_for(&DebugOptions::set_xla_allow_excess_precision),
+          flag_values->xla_allow_excess_precision(),
+          "Allow xla to increase the output precision of an instruction."),
   });
   ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects);
 }
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 3c15679ffec..e92db21e517 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -176,6 +176,9 @@ message DebugOptions {
   // this is true we don't propagate NaNs through Min and Max.
   bool xla_gpu_enable_fast_min_max = 100;
 
+  // Allows xla to increase the output precision of floating point operations.
+  bool xla_allow_excess_precision = 122;
+
   // Crashes the program when any kind of verification fails, instead of just
   // logging the failures. One example is cross checking of convolution results
   // among different algorithms.
@@ -262,7 +265,7 @@ message DebugOptions {
 
   // END flags controlling dumping HLO modules.
   //
-  // Next id: 121
+  // Next id: 123
 
   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.
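
Note on the FloorDiv change above: the quotient is computed and floored in F32 and only then narrowed back to BF16, so the division result is never rounded to BF16 precision before the floor is taken. Below is a minimal stand-alone analogy of the same pattern (not part of the change), with float standing in for BF16 and double for F32:

    #include <cmath>

    // Floor-division with a wider intermediate type: divide and floor in the
    // wide type, narrow only once at the end. Dividing directly in the narrow
    // type can round the quotient up across an integer boundary, making the
    // floor land one integer too high.
    float FloorDivWiden(float x, float y) {
      double quotient = static_cast<double>(x) / static_cast<double>(y);
      return static_cast<float>(std::floor(quotient));
    }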
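
Note on the MaxPoolGradGrad change above: ReducePrecision(input, /*exponent_bits=*/8, /*mantissa_bits=*/7) rounds an F32 value to BF16 precision while keeping the F32 bit width, which is what the following BitcastConvertType to U32 needs; unlike the old BF16 round trip, it is not a narrowing cast that xla_allow_excess_precision would let the compiler drop. A rough bit-level sketch of the effect (illustrative only; NaN and overflow handling omitted):

    #include <cstdint>
    #include <cstring>

    // Round an F32 value to BF16 precision (7 stored mantissa bits,
    // round-to-nearest-even) without leaving F32, then keep the 16 high bits.
    uint32_t HighBitsAfterBf16Rounding(float x) {
      uint32_t bits;
      std::memcpy(&bits, &x, sizeof(bits));   // like BitcastConvertType(F32 -> U32)
      bits += 0x7FFFu + ((bits >> 16) & 1u);  // round away the low 16 bits, ties to even
      return bits & 0xFFFF0000u;              // only the 16 high bits survive
    }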
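
Note on the new option: it defaults to true in DefaultDebugOptionsIgnoringFlags and can be turned off through the XLA_FLAGS environment variable parsed above (--xla_allow_excess_precision=false) or programmatically. The sketch below is illustrative and assumes the usual GetDebugOptionsFromFlags() entry point declared in debug_options_flags.h:

    #include "tensorflow/compiler/xla/debug_options_flags.h"

    // Build a DebugOptions proto that forbids the compiler from widening (or
    // skipping the narrowing of) floating point results.
    xla::DebugOptions StrictPrecisionDebugOptions() {
      xla::DebugOptions opts = xla::GetDebugOptionsFromFlags();
      opts.set_xla_allow_excess_precision(false);
      return opts;
    }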