[XLA:CPU] More accurate expm1 when x is small, take two
We approximate it with: expm1(x) = tanh(x/2)*(exp(x)+1) Additional care is taken to handle the case when x/2 underflows but x does not by simply approximating the result with x itself. Yet further care must be taken to handle the case when exp(x) would not be all that close to 1, in which case we simply use: expm1(x) = exp(x)-1 The pseudo-code for this is roughly: if x/2 == 0: return x exp_x = exp(x) if |x| > .5: return exp_x - 1 return tanh(x/2)*(exp_x+1) The actual code sequence emitted preserves vectorization in the case where different lanes observe inputs where the magnitudes are entirely different. This suffices to get us within a relative error of 4.76e-7 or about eight ULPs when compared against libm. PiperOrigin-RevId: 358861023 Change-Id: I4a51ec8e2a16a95b6cbaa2af3305ce3a16201c54
This commit is contained in:
parent
4856f23a49
commit
1e135c54c5
@ -2551,6 +2551,23 @@ llvm::Value* IrEmitter::EmitPrintf(absl::string_view fmt,
|
||||
call_args);
|
||||
}
|
||||
|
||||
llvm::Value* IrEmitter::EmitFprintf(absl::string_view fmt,
|
||||
absl::Span<llvm::Value* const> arguments) {
|
||||
llvm::Type* ptr_ty = b_.getInt8Ty()->getPointerTo();
|
||||
auto stderr_symbol =
|
||||
b_.GetInsertBlock()->getParent()->getParent()->getOrInsertGlobal("stderr",
|
||||
ptr_ty);
|
||||
std::vector<llvm::Value*> call_args;
|
||||
call_args.push_back(b_.CreateLoad(stderr_symbol));
|
||||
call_args.push_back(b_.CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)));
|
||||
absl::c_copy(arguments, std::back_inserter(call_args));
|
||||
return b_.CreateCall(
|
||||
b_.GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
|
||||
"fprintf", llvm::FunctionType::get(b_.getInt32Ty(), {ptr_ty, ptr_ty},
|
||||
/*isVarArg=*/true)),
|
||||
call_args);
|
||||
}
|
||||
|
||||
llvm::Value* IrEmitter::EmitCallToFunc(
|
||||
std::string func_name, const std::vector<llvm::Value*>& arguments,
|
||||
llvm::Type* return_type, bool does_not_throw, bool only_accesses_arg_memory,
|
||||
|
@ -418,6 +418,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||
// Emits printing during the execution.
|
||||
llvm::Value* EmitPrintf(absl::string_view fmt,
|
||||
absl::Span<llvm::Value* const> arguments);
|
||||
llvm::Value* EmitFprintf(absl::string_view fmt,
|
||||
absl::Span<llvm::Value* const> arguments);
|
||||
|
||||
// Emits a call to a non-variadic function `func_name` with arguments
|
||||
// `arguments` assuming C calling convention.
|
||||
|
@ -237,7 +237,10 @@ namespace {
|
||||
bool RegisterKnownJITSymbols() {
|
||||
xla::CustomCallTargetRegistry* registry =
|
||||
xla::CustomCallTargetRegistry::Global();
|
||||
registry->Register("fprintf", reinterpret_cast<void*>(&fprintf), "Host");
|
||||
registry->Register("printf", reinterpret_cast<void*>(&printf), "Host");
|
||||
registry->Register("stderr", reinterpret_cast<void*>(&stderr), "Host");
|
||||
registry->Register("puts", reinterpret_cast<void*>(&puts), "Host");
|
||||
|
||||
#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \
|
||||
do { \
|
||||
|
@ -1424,25 +1424,24 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
|
||||
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
|
||||
auto one = llvm::ConstantFP::get(type, 1.0);
|
||||
auto half = llvm::ConstantFP::get(type, 0.5);
|
||||
// When the exponent is large, the naive evaluation of e^(x) - 1 is more
|
||||
// accurate than the Taylor series.
|
||||
TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value, ""));
|
||||
auto for_large_x = FSub(exp_x, one);
|
||||
// The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + ….
|
||||
// We want exp(x)-1 which is x + x^2/2 + x^3/6 + ….
|
||||
// We use the second degree approximation of exp(x)-1 = x + x^2/2.
|
||||
auto x_squared = FMul(x, x);
|
||||
auto x_squared_over_two = FMul(x_squared, half);
|
||||
auto for_small_x = FAdd(x, x_squared_over_two);
|
||||
// At this point, the relative errors due to floating point precision loss of
|
||||
// calculating exp(x) - 1 and the polynomial exp(x)-1 = x + x^2/2 are about
|
||||
// equal, with a value of approximately 2^-16.
|
||||
const auto kExponentIsSmallThreshold = 0.009;
|
||||
auto zero = llvm::ConstantFP::get(type, 0.0);
|
||||
|
||||
// expm1(x) == tanh(x/2)*(exp(x)+1)
|
||||
// x/2 can underflow, if it does we approximate expm1 with x.
|
||||
auto x_over_two = FMul(x, half);
|
||||
auto x_over_two_is_zero = FCmpOEQ(x_over_two, zero);
|
||||
auto abs_x =
|
||||
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
|
||||
auto x_is_small =
|
||||
FCmpOLT(abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold));
|
||||
return Select(x_is_small, for_small_x, for_large_x);
|
||||
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {x}, {type}, b_);
|
||||
// Use a naive exp(x)-1 calculation if |x| is > 0.5
|
||||
auto x_magnitude_is_large = FCmpOGT(abs_x, half);
|
||||
TF_ASSIGN_OR_RETURN(auto tanh_of_x_over_two, EmitTanh(prim_type, x_over_two));
|
||||
TF_ASSIGN_OR_RETURN(auto exp_of_x, EmitExp(prim_type, x, ""));
|
||||
auto exp_of_x_plus_one = FAdd(exp_of_x, one);
|
||||
auto exp_of_x_minus_one = FSub(exp_of_x, one);
|
||||
auto expm1_of_x = FMul(tanh_of_x_over_two, exp_of_x_plus_one);
|
||||
expm1_of_x = Select(x_magnitude_is_large, exp_of_x_minus_one, expm1_of_x);
|
||||
expm1_of_x = Select(x_over_two_is_zero, x, expm1_of_x);
|
||||
return expm1_of_x;
|
||||
}
|
||||
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
|
||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
@ -300,7 +302,15 @@ UNARY_TEST_FLOAT_32_BITS_OR_LESS(Exp, {
|
||||
UNARY_TEST_FLOAT_32_BITS_OR_LESS(Expm1, {
|
||||
ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator();
|
||||
if (ty_ == F32) {
|
||||
error_spec_gen = +[](NativeT x) { return ErrorSpec{0, 0.00015}; };
|
||||
if (platform_ == "Host") {
|
||||
error_spec_gen = +[](NativeT x) {
|
||||
// We expect no worse than an error of 8 ULPs.
|
||||
return ErrorSpec{
|
||||
0.0, std::scalbn(8.0f, -std::numeric_limits<float>::digits)};
|
||||
};
|
||||
} else {
|
||||
error_spec_gen = +[](NativeT x) { return ErrorSpec{0, 0.00015}; };
|
||||
}
|
||||
}
|
||||
|
||||
// Our CPU implementation of expm1 returns one incorrect value: says
|
||||
|
Loading…
x
Reference in New Issue
Block a user