Enable XLA:CPU fast math for min/max by default to be similar to TF's behavior.

Another big change here is that this flag's value is now taken from the HloModule rather than from the global environment variable, which was bad temporary behavior.

PiperOrigin-RevId: 316844057
Change-Id: I995715ccc9009e9845fca77060b835fdc50fb4d2
This commit is contained in:
Tres Popp 2020-06-17 01:20:35 -07:00 committed by TensorFlower Gardener
parent 02aad07710
commit de907d8746
12 changed files with 38 additions and 21 deletions

View File

@ -64,7 +64,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_cpu_fast_math_honor_division(true); opts.set_xla_cpu_fast_math_honor_division(true);
// By default, copy TF's Eigen style min_max behavior with nans. // By default, copy TF's Eigen style min_max behavior with nans.
opts.set_xla_cpu_enable_fast_min_max(false); opts.set_xla_cpu_enable_fast_min_max(true);
opts.set_xla_gpu_enable_fast_min_max(true); opts.set_xla_gpu_enable_fast_min_max(true);

View File

@ -50,6 +50,10 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
return ir_emitter_->EmitThreadLocalCall(callee, parameters, name); return ir_emitter_->EmitThreadLocalCall(callee, parameters, name);
} }
bool fast_min_max() override {
return hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max();
}
IrEmitter* ir_emitter_; IrEmitter* ir_emitter_;
}; };

View File

@ -318,7 +318,9 @@ llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input,
llvm::Value* is_pos_inf_mask = vsl.FCmpEQMask(input, pos_inf); llvm::Value* is_pos_inf_mask = vsl.FCmpEQMask(input, pos_inf);
// Cut off denormalized stuff. // Cut off denormalized stuff.
llvm::Value* tmp0 = vsl.Max(min_norm_pos, input); // Always allow fast max because we are checking for the nan above.
llvm::Value* tmp0 =
vsl.Max(min_norm_pos, input, /*enable_fast_min_max=*/true);
// VectorSupportLibrary (intentionally) can't juggle more than one type at a // VectorSupportLibrary (intentionally) can't juggle more than one type at a
// time so drop down to IRBuilder for this bit. // time so drop down to IRBuilder for this bit.

View File

@ -80,10 +80,11 @@ llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) {
return b()->CreateFSub(lhs, rhs); return b()->CreateFSub(lhs, rhs);
} }
llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) { llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs,
bool enable_fast_min_max) {
AssertCorrectTypes({lhs, rhs}); AssertCorrectTypes({lhs, rhs});
if (scalar_type_->isFloatingPointTy()) { if (scalar_type_->isFloatingPointTy()) {
return llvm_ir::EmitFloatMax(lhs, rhs, b_); return llvm_ir::EmitFloatMax(lhs, rhs, b_, enable_fast_min_max);
} else { } else {
LOG(FATAL) << "Max for integers is unimplemented"; LOG(FATAL) << "Max for integers is unimplemented";
} }

View File

@ -78,9 +78,11 @@ class VectorSupportLibrary {
llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) { llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) {
return Sub(lhs, GetConstantFloat(lhs->getType(), rhs)); return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
} }
llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs,
llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs) { bool enable_fast_min_max);
return Max(GetConstantFloat(rhs->getType(), lhs), rhs); llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs,
bool enable_fast_min_max) {
return Max(GetConstantFloat(rhs->getType(), lhs), rhs, enable_fast_min_max);
} }
llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);

View File

@ -1313,12 +1313,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value, llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
llvm::Value* rhs_value) { llvm::Value* rhs_value) {
return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_); return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max());
} }
llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
llvm::Value* rhs_value) { llvm::Value* rhs_value) {
return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_); return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max());
} }
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type, StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,

View File

@ -245,6 +245,8 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
std::vector<llvm_ir::ElementGenerator> initial_value_generators, std::vector<llvm_ir::ElementGenerator> initial_value_generators,
const llvm_ir::IrArray::Index& index); const llvm_ir::IrArray::Index& index);
virtual bool fast_min_max() = 0;
llvm::IRBuilder<>* const b_; llvm::IRBuilder<>* const b_;
llvm::Module* module_; llvm::Module* module_;

View File

@ -96,6 +96,10 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
llvm::Value* EmitThreadId() override; llvm::Value* EmitThreadId() override;
bool fast_min_max() override {
return hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max();
}
private: private:
// Emits IR for op, which must have opcode kPower. // Emits IR for op, which must have opcode kPower.
StatusOr<llvm::Value*> EmitPowerOp(const HloInstruction* op, StatusOr<llvm::Value*> EmitPowerOp(const HloInstruction* op,

View File

@ -91,10 +91,8 @@ llvm::CallInst* EmitCallToIntrinsic(
} }
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
llvm::IRBuilder<>* b) { llvm::IRBuilder<>* b, bool enable_fast_min_max) {
// TODO(tpopp): Pass this information down from the HLO's ModuleConfig. if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
if (b->getFastMathFlags().noNaNs() ||
GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) {
auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value); auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
return b->CreateSelect(cmp, lhs_value, rhs_value); return b->CreateSelect(cmp, lhs_value, rhs_value);
} else { } else {
@ -106,10 +104,8 @@ llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
} }
llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value, llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
llvm::IRBuilder<>* b) { llvm::IRBuilder<>* b, bool enable_fast_min_max) {
// TODO(tpopp): Pass this information down from the HLO's ModuleConfig. if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
if (b->getFastMathFlags().noNaNs() ||
GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) {
auto cmp = b->CreateFCmpULE(lhs_value, rhs_value); auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
return b->CreateSelect(cmp, lhs_value, rhs_value); return b->CreateSelect(cmp, lhs_value, rhs_value);
} else { } else {

View File

@ -108,12 +108,12 @@ llvm::CallInst* EmitCallToIntrinsic(
// Emit float max. Emit maxnum intrinsic if fast math is disabled, or // Emit float max. Emit maxnum intrinsic if fast math is disabled, or
// fcmp+select otherwise // fcmp+select otherwise
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
llvm::IRBuilder<>* b); llvm::IRBuilder<>* b, bool enable_fast_min_max);
// Emit float min. Emit minnum intrinsic if fast math is disabled, or // Emit float min. Emit minnum intrinsic if fast math is disabled, or
// fcmp+select otherwise // fcmp+select otherwise
llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value, llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
llvm::IRBuilder<>* b); llvm::IRBuilder<>* b, bool enable_fast_min_max);
// Convenience methods for emitting a GEP instruction that indexes into a buffer // Convenience methods for emitting a GEP instruction that indexes into a buffer
// (1-dimensional array), equivalent to array[index]. The type is automatically // (1-dimensional array), equivalent to array[index]. The type is automatically

View File

@ -31,9 +31,13 @@ llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input) {
b->CreateFCmpOLT(abs_x, llvm::ConstantFP::get(type, kCanUseApprox)); b->CreateFCmpOLT(abs_x, llvm::ConstantFP::get(type, kCanUseApprox));
// Clamp the input to [-9, 9]. // Clamp the input to [-9, 9].
//
// To simplify the code base until it's an issue, don't have a slow min/max in
// this approximation.
llvm::Value* input_clamped = llvm_ir::EmitFloatMin( llvm::Value* input_clamped = llvm_ir::EmitFloatMin(
llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(type, -9.0), b), llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(type, -9.0), b,
llvm::ConstantFP::get(type, 9.0), b); /*enable_fast_min_max=*/true),
llvm::ConstantFP::get(type, 9.0), b, /*enable_fast_min_max=*/true);
static constexpr std::array<float, 7> numerator_coeffs{ static constexpr std::array<float, 7> numerator_coeffs{
-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,

View File

@ -1004,6 +1004,8 @@ class ReluTest(test_lib.TestCase):
z = self.evaluate(nn_ops.relu(constant_op.constant(x))) z = self.evaluate(nn_ops.relu(constant_op.constant(x)))
self.assertAllEqual(y, z) self.assertAllEqual(y, z)
@test_util.disable_xla(
"This test relies on undefined behavior that XLA does not replicate")
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testNaNs(self): def testNaNs(self):
# Test that relu(nan) = nan for various sizes. # Test that relu(nan) = nan for various sizes.