Enable XLA:CPU fast math for min/max by default to be similar to TF's behavior.

Another big change here is that this flag is now read from the value stored in the HloModule's config rather than from the global environment variable, which was bad temporary behavior.

PiperOrigin-RevId: 316844057
Change-Id: I995715ccc9009e9845fca77060b835fdc50fb4d2
This commit is contained in:
Tres Popp 2020-06-17 01:20:35 -07:00 committed by TensorFlower Gardener
parent 02aad07710
commit de907d8746
12 changed files with 38 additions and 21 deletions

View File

@ -64,7 +64,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_cpu_fast_math_honor_division(true);
// By default, copy TF's Eigen style min_max behavior with nans.
opts.set_xla_cpu_enable_fast_min_max(false);
opts.set_xla_cpu_enable_fast_min_max(true);
opts.set_xla_gpu_enable_fast_min_max(true);

View File

@ -50,6 +50,10 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
return ir_emitter_->EmitThreadLocalCall(callee, parameters, name);
}
bool fast_min_max() override {
return hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max();
}
IrEmitter* ir_emitter_;
};

View File

@ -318,7 +318,9 @@ llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input,
llvm::Value* is_pos_inf_mask = vsl.FCmpEQMask(input, pos_inf);
// Cut off denormalized stuff.
llvm::Value* tmp0 = vsl.Max(min_norm_pos, input);
// Always allow fast max because we are checking for the nan above.
llvm::Value* tmp0 =
vsl.Max(min_norm_pos, input, /*enable_fast_min_max=*/true);
// VectorSupportLibrary (intentionally) can't juggle more than one type at a
// time so drop down to IRBuilder for this bit.

View File

@ -80,10 +80,11 @@ llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) {
return b()->CreateFSub(lhs, rhs);
}
llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) {
llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs,
bool enable_fast_min_max) {
AssertCorrectTypes({lhs, rhs});
if (scalar_type_->isFloatingPointTy()) {
return llvm_ir::EmitFloatMax(lhs, rhs, b_);
return llvm_ir::EmitFloatMax(lhs, rhs, b_, enable_fast_min_max);
} else {
LOG(FATAL) << "Max for integers is unimplemented";
}

View File

@ -78,9 +78,11 @@ class VectorSupportLibrary {
llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) {
return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
}
llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs) {
return Max(GetConstantFloat(rhs->getType(), lhs), rhs);
llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs,
bool enable_fast_min_max);
llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs,
bool enable_fast_min_max) {
return Max(GetConstantFloat(rhs->getType(), lhs), rhs, enable_fast_min_max);
}
llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);

View File

@ -1313,12 +1313,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
llvm::Value* rhs_value) {
return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_);
return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max());
}
llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
llvm::Value* rhs_value) {
return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_);
return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max());
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,

View File

@ -245,6 +245,8 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
std::vector<llvm_ir::ElementGenerator> initial_value_generators,
const llvm_ir::IrArray::Index& index);
virtual bool fast_min_max() = 0;
llvm::IRBuilder<>* const b_;
llvm::Module* module_;

View File

@ -96,6 +96,10 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
llvm::Value* EmitThreadId() override;
bool fast_min_max() override {
return hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max();
}
private:
// Emits IR for op, which must have opcode kPower.
StatusOr<llvm::Value*> EmitPowerOp(const HloInstruction* op,

View File

@ -91,10 +91,8 @@ llvm::CallInst* EmitCallToIntrinsic(
}
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
llvm::IRBuilder<>* b) {
// TODO(tpopp): Pass this information down from the HLO's ModuleConfig.
if (b->getFastMathFlags().noNaNs() ||
GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) {
llvm::IRBuilder<>* b, bool enable_fast_min_max) {
if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
return b->CreateSelect(cmp, lhs_value, rhs_value);
} else {
@ -106,10 +104,8 @@ llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
}
llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
llvm::IRBuilder<>* b) {
// TODO(tpopp): Pass this information down from the HLO's ModuleConfig.
if (b->getFastMathFlags().noNaNs() ||
GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) {
llvm::IRBuilder<>* b, bool enable_fast_min_max) {
if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
return b->CreateSelect(cmp, lhs_value, rhs_value);
} else {

View File

@ -108,12 +108,12 @@ llvm::CallInst* EmitCallToIntrinsic(
// Emit float max. Emit the maxnum intrinsic if fast math is disabled, or
// fcmp+select otherwise
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
llvm::IRBuilder<>* b);
llvm::IRBuilder<>* b, bool enable_fast_min_max);
// Emit float min. Emit the minnum intrinsic if fast math is disabled, or
// fcmp+select otherwise
llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
llvm::IRBuilder<>* b);
llvm::IRBuilder<>* b, bool enable_fast_min_max);
// Convenience methods for emitting a GEP instruction that indexes into a buffer
// (1-dimensional array), equivalent to array[index]. The type is automatically

View File

@ -31,9 +31,13 @@ llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input) {
b->CreateFCmpOLT(abs_x, llvm::ConstantFP::get(type, kCanUseApprox));
// Clamp the input to [-9, 9].
//
// To simplify the code base until it's an issue, don't have a slow min/max in
// this approximation.
llvm::Value* input_clamped = llvm_ir::EmitFloatMin(
llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(type, -9.0), b),
llvm::ConstantFP::get(type, 9.0), b);
llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(type, -9.0), b,
/*enable_fast_min_max=*/true),
llvm::ConstantFP::get(type, 9.0), b, /*enable_fast_min_max=*/true);
static constexpr std::array<float, 7> numerator_coeffs{
-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,

View File

@ -1004,6 +1004,8 @@ class ReluTest(test_lib.TestCase):
z = self.evaluate(nn_ops.relu(constant_op.constant(x)))
self.assertAllEqual(y, z)
@test_util.disable_xla(
"This test relies on undefined behavior that XLA does not replicate")
@test_util.run_deprecated_v1
def testNaNs(self):
# Test that relu(nan) = nan for various sizes.