[TF:XLA] Implement FloorMod with fewer calls to xla::Rem
This implementation is faster, as xla::Rem is rather expensive. It is also more numerically sound, as it has fewer rounding steps. For example, given the single-precision values x = -1.46146206e-09 and y = 0.562811792, the old FloorMod(x, y) would compute 0, whereas the new FloorMod(x, y) computes 0.56281179. This agrees with numpy.mod, which returns the element-wise remainder of the quotient floor_divide(x, y). PiperOrigin-RevId: 239061236
This commit is contained in:
parent
e1e6ec9c2c
commit
955bc76f9e
@@ -148,14 +148,17 @@ XLA_MAKE_BINARY(Xdivy, XdivyImpl(lhs, rhs, broadcast_helper));
|
|||||||
|
|
||||||
// Implementation of FloorMod. Pseudo-code:
|
// Implementation of FloorMod. Pseudo-code:
|
||||||
// T trunc_mod = std::fmod(x, y);
|
// T trunc_mod = std::fmod(x, y);
|
||||||
// return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
|
// return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
|
||||||
|
// : trunc_mod;
|
||||||
static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
|
static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
|
||||||
xla::XlaOp y, const BCast& broadcast_helper) {
|
xla::XlaOp y, const BCast& broadcast_helper) {
|
||||||
std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
|
std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
|
||||||
auto zero = XlaHelpers::Zero(b, dtype);
|
auto zero = XlaHelpers::Zero(b, dtype);
|
||||||
auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero));
|
|
||||||
auto trunc_mod = xla::Rem(x, y);
|
auto trunc_mod = xla::Rem(x, y);
|
||||||
return xla::Select(same_sign, trunc_mod, xla::Rem(xla::Add(trunc_mod, y), y));
|
auto trunc_mod_not_zero = xla::Ne(trunc_mod, zero);
|
||||||
|
auto do_plus = xla::And(xla::Ne(xla::Lt(trunc_mod, zero), xla::Lt(y, zero)),
|
||||||
|
trunc_mod_not_zero);
|
||||||
|
return xla::Select(do_plus, xla::Add(trunc_mod, y), trunc_mod);
|
||||||
}
|
}
|
||||||
XLA_MAKE_BINARY(FloorMod,
|
XLA_MAKE_BINARY(FloorMod,
|
||||||
FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper));
|
FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper));
|
||||||
|
Loading…
Reference in New Issue
Block a user