From 955bc76f9ee7c08b81afa1cf764ac2a68ccc924a Mon Sep 17 00:00:00 2001
From: David Majnemer
Date: Mon, 18 Mar 2019 14:27:07 -0700
Subject: [PATCH] [TF:XLA] Implement FloorMod with fewer calls to xla::Rem

This implementation is faster as xla::Rem is rather expensive.
It is also more numerically sound as it has fewer rounding steps.

For example, given the single precision values:
x = -1.46146206e-09
y = 0.562811792

old FloorMod(x, y) would compute 0
new FloorMod(x, y) computes 0.56281179

This agrees with numpy.mod which returns the element-wise remainder of
the quotient floor_divide(x, y).

PiperOrigin-RevId: 239061236
---
 tensorflow/compiler/tf2xla/kernels/binary_ops.cc | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index f69b5dc0222..b708e91722e 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -148,14 +148,17 @@ XLA_MAKE_BINARY(Xdivy, XdivyImpl(lhs, rhs, broadcast_helper));
 
 // Implementation of FloorMod. Pseudo-code:
 //   T trunc_mod = std::fmod(x, y);
-//   return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
+//   return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
+//                                                     : trunc_mod;
 static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                                xla::XlaOp y, const BCast& broadcast_helper) {
   std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
   auto zero = XlaHelpers::Zero(b, dtype);
-  auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero));
   auto trunc_mod = xla::Rem(x, y);
-  return xla::Select(same_sign, trunc_mod, xla::Rem(xla::Add(trunc_mod, y), y));
+  auto trunc_mod_not_zero = xla::Ne(trunc_mod, zero);
+  auto do_plus = xla::And(xla::Ne(xla::Lt(trunc_mod, zero), xla::Lt(y, zero)),
+                          trunc_mod_not_zero);
+  return xla::Select(do_plus, xla::Add(trunc_mod, y), trunc_mod);
 }
 XLA_MAKE_BINARY(FloorMod,
                 FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper));