Enable XLA:CPU fast math for min/max by default, matching TF's behavior.

The other significant change is that the flag is now read from the HloModule's config rather than from the global environment variable, which was a bad temporary arrangement.

PiperOrigin-RevId: 316844057
Change-Id: I995715ccc9009e9845fca77060b835fdc50fb4d2
parent 02aad07710
commit de907d8746
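For context, here is a minimal, standalone scalar sketch (not part of this change) of the NaN behavior the xla_cpu_enable_fast_min_max / xla_gpu_enable_fast_min_max flags toggle. The "fast" path lowers max to an unordered compare (fcmp uge) plus select, as in EmitFloatMax in the hunks below; the slow path uses the llvm.maxnum intrinsic, which returns the non-NaN operand. FastMax and MaxnumMax are hypothetical helper names used only for this illustration.

// Scalar model of the two lowerings EmitFloatMax chooses between.
// FastMax mirrors "fcmp uge + select"; MaxnumMax mirrors llvm.maxnum,
// which (like std::fmax) returns the non-NaN operand when one input is NaN.
#include <cmath>
#include <cstdio>

float FastMax(float lhs, float rhs) {
  // Unordered >=: true when lhs >= rhs or when either operand is NaN,
  // so the select keeps lhs whenever a NaN is involved.
  bool cmp = !(lhs < rhs);
  return cmp ? lhs : rhs;
}

float MaxnumMax(float lhs, float rhs) { return std::fmax(lhs, rhs); }

int main() {
  float nan = std::nanf("");
  std::printf("fast   max(nan, 1) = %f\n", FastMax(nan, 1.0f));    // nan
  std::printf("maxnum max(nan, 1) = %f\n", MaxnumMax(nan, 1.0f));  // 1
  std::printf("fast   max(1, nan) = %f\n", FastMax(1.0f, nan));    // 1
  std::printf("maxnum max(1, nan) = %f\n", MaxnumMax(1.0f, nan));  // 1
  return 0;
}

This NaN-dependent behavior is also why the relu NaN test in the final hunk is disabled under XLA: with the fast path, NaN propagation through min/max is no longer guaranteed.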
@@ -64,7 +64,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_cpu_fast_math_honor_division(true);
 
   // By default, copy TF's Eigen style min_max behavior with nans.
-  opts.set_xla_cpu_enable_fast_min_max(false);
+  opts.set_xla_cpu_enable_fast_min_max(true);
 
   opts.set_xla_gpu_enable_fast_min_max(true);
@@ -50,6 +50,10 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
     return ir_emitter_->EmitThreadLocalCall(callee, parameters, name);
   }
 
+  bool fast_min_max() override {
+    return hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max();
+  }
+
   IrEmitter* ir_emitter_;
 };
@@ -318,7 +318,9 @@ llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input,
   llvm::Value* is_pos_inf_mask = vsl.FCmpEQMask(input, pos_inf);
 
   // Cut off denormalized stuff.
-  llvm::Value* tmp0 = vsl.Max(min_norm_pos, input);
+  // Always allow fast max because we are checking for the nan above.
+  llvm::Value* tmp0 =
+      vsl.Max(min_norm_pos, input, /*enable_fast_min_max=*/true);
 
   // VectorSupportLibrary (intentionally) can't juggle more than one type at a
   // time so drop down to IRBuilder for this bit.
@@ -80,10 +80,11 @@ llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) {
   return b()->CreateFSub(lhs, rhs);
 }
 
-llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) {
+llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs,
+                                       bool enable_fast_min_max) {
   AssertCorrectTypes({lhs, rhs});
   if (scalar_type_->isFloatingPointTy()) {
-    return llvm_ir::EmitFloatMax(lhs, rhs, b_);
+    return llvm_ir::EmitFloatMax(lhs, rhs, b_, enable_fast_min_max);
   } else {
     LOG(FATAL) << "Max for integers is unimplemented";
   }
@@ -78,9 +78,11 @@ class VectorSupportLibrary {
   llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) {
     return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
   }
-  llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs);
-  llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs) {
-    return Max(GetConstantFloat(rhs->getType(), lhs), rhs);
+  llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs,
+                   bool enable_fast_min_max);
+  llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs,
+                   bool enable_fast_min_max) {
+    return Max(GetConstantFloat(rhs->getType(), lhs), rhs, enable_fast_min_max);
   }
   llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);
@@ -1313,12 +1313,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
 
 llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
                                               llvm::Value* rhs_value) {
-  return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_);
+  return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max());
 }
 
 llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
                                               llvm::Value* rhs_value) {
-  return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_);
+  return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max());
 }
 
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
@@ -245,6 +245,8 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
       std::vector<llvm_ir::ElementGenerator> initial_value_generators,
       const llvm_ir::IrArray::Index& index);
 
+  virtual bool fast_min_max() = 0;
+
   llvm::IRBuilder<>* const b_;
 
   llvm::Module* module_;
@@ -96,6 +96,10 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
 
   llvm::Value* EmitThreadId() override;
 
+  bool fast_min_max() override {
+    return hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max();
+  }
+
  private:
   // Emits IR for op, which must have opcode kPower.
   StatusOr<llvm::Value*> EmitPowerOp(const HloInstruction* op,
@@ -91,10 +91,8 @@ llvm::CallInst* EmitCallToIntrinsic(
 }
 
 llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
-                          llvm::IRBuilder<>* b) {
-  // TODO(tpopp): Pass this information down from the HLO's ModuleConfig.
-  if (b->getFastMathFlags().noNaNs() ||
-      GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) {
+                          llvm::IRBuilder<>* b, bool enable_fast_min_max) {
+  if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
     auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
     return b->CreateSelect(cmp, lhs_value, rhs_value);
   } else {
@@ -106,10 +104,8 @@ llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
 }
 
 llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
-                          llvm::IRBuilder<>* b) {
-  // TODO(tpopp): Pass this information down from the HLO's ModuleConfig.
-  if (b->getFastMathFlags().noNaNs() ||
-      GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) {
+                          llvm::IRBuilder<>* b, bool enable_fast_min_max) {
+  if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
     auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
     return b->CreateSelect(cmp, lhs_value, rhs_value);
   } else {
@@ -108,12 +108,12 @@ llvm::CallInst* EmitCallToIntrinsic(
 // Emit float max. Emit maxnum intrinsic is fast math is disabled, or
 // fcmp+select otherwise
 llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
-                          llvm::IRBuilder<>* b);
+                          llvm::IRBuilder<>* b, bool enable_fast_min_max);
 
 // Emit float min. Emit minnum intrinsic is fast math is disabled, or
 // fcmp+select otherwise
 llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
-                          llvm::IRBuilder<>* b);
+                          llvm::IRBuilder<>* b, bool enable_fast_min_max);
 
 // Convenience methods for emitting a GEP instruction that indexes into a buffer
 // (1-dimensional array), equivalent to array[index]. The type is automatically
@@ -31,9 +31,13 @@ llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input) {
       b->CreateFCmpOLT(abs_x, llvm::ConstantFP::get(type, kCanUseApprox));
 
   // Clamp the input to [-9, 9].
+  //
+  // To simplify the code base until it's an issue, don't have a slow min/max in
+  // this approximation.
   llvm::Value* input_clamped = llvm_ir::EmitFloatMin(
-      llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(type, -9.0), b),
-      llvm::ConstantFP::get(type, 9.0), b);
+      llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(type, -9.0), b,
+                            /*enable_fast_min_max=*/true),
+      llvm::ConstantFP::get(type, 9.0), b, /*enable_fast_min_max=*/true);
 
   static constexpr std::array<float, 7> numerator_coeffs{
       -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
@@ -1004,6 +1004,8 @@ class ReluTest(test_lib.TestCase):
     z = self.evaluate(nn_ops.relu(constant_op.constant(x)))
     self.assertAllEqual(y, z)
 
+  @test_util.disable_xla(
+      "This test relies on undefined behavior that XLA does not replicate")
   @test_util.run_deprecated_v1
   def testNaNs(self):
     # Test that relu(nan) = nan for various sizes.
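Taken together, the plumbing above makes fast min/max a per-module decision. As a rough usage sketch, assuming HloModuleConfig exposes set_debug_options() (not shown in this diff), a single module could opt back out without touching the process-wide environment variable:

// Hypothetical sketch: per-module opt-out after this change.
xla::DebugOptions opts = xla::DefaultDebugOptionsIgnoringFlags();
opts.set_xla_cpu_enable_fast_min_max(false);  // override the new default (true)

xla::HloModuleConfig config;
config.set_debug_options(opts);
// CpuElementalIrEmitter::fast_min_max() now reads
// config.debug_options().xla_cpu_enable_fast_min_max() for modules built with
// this config, instead of consulting GetDebugOptionsFromFlags().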