[XLA] Simplify ({0, ..., n-1} + n) % n to {0, ..., n-1}.

We break this down into two simplification rules:

 - (X + n) % n  ==> X % n if X + n does not overflow.
 - X % n        ==> X     if 0 <= X < n.

PiperOrigin-RevId: 256907449
This commit is contained in:
Justin Lebar 2019-07-07 22:13:14 -07:00 committed by TensorFlower Gardener
parent ad6e2c74ce
commit 780530fd75
4 changed files with 159 additions and 0 deletions

View File

@ -220,6 +220,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)

View File

@ -16,6 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_
#include <type_traits>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@ -45,6 +49,54 @@ inline int64 MultiplyWithoutOverflow(const int64 x, const int64 y) {
return static_cast<int64>(uxy);
}
// Returns x + y, or absl::nullopt when the sum is not representable in T.
//
// T must be a signed integral type; anything else is rejected at compile time.
template <typename T>
inline absl::optional<T> OverflowSafeAdd(T x, T y) {
  static_assert(std::is_signed<T>::value,
                "Only implemented for signed numbers T.");
  static_assert(std::is_integral<T>::value, "Only implemented for integers T.");

  // Add in the unsigned counterpart of T, where wraparound is well defined,
  // then convert the wrapped result back to T.
  using UnsignedT = typename std::make_unsigned<T>::type;
  const UnsignedT wrapped =
      static_cast<UnsignedT>(x) + static_cast<UnsignedT>(y);
  const T result = static_cast<T>(wrapped);

  // "Signed integer overflow occurs on integer addition iff the operands have
  // the same sign and the sum has a sign opposite to that of the operands."
  // Hacker's Delight 2nd ed, p 28.
  const bool x_nonnegative = x >= 0;
  const bool y_nonnegative = y >= 0;
  const bool result_nonnegative = result >= 0;
  const bool operands_share_sign = (x_nonnegative == y_nonnegative);
  const bool sign_flipped = (result_nonnegative != x_nonnegative);
  if (operands_share_sign && sign_flipped) {
    return absl::nullopt;
  }
  return result;
}
// Reports whether x is representable in the integral primitive type `ty`.
//
// Handles S8/S16/S32/S64 and U8/U16/U32/U64; any other primitive type is a
// fatal error.
inline bool FitsInIntegralType(int64 x, PrimitiveType ty) {
  // Inclusive range test; both bounds are widened to int64 before comparing.
  const auto within = [x](int64 lowest, int64 highest) {
    return lowest <= x && highest >= x;
  };
  switch (ty) {
    case S8:
      return within(std::numeric_limits<int8>::min(),
                    std::numeric_limits<int8>::max());
    case S16:
      return within(std::numeric_limits<int16>::min(),
                    std::numeric_limits<int16>::max());
    case S32:
      return within(std::numeric_limits<int32>::min(),
                    std::numeric_limits<int32>::max());
    case S64:
      // x is already an int64, so it always fits.
      return true;
    case U8:
      return within(0, std::numeric_limits<uint8>::max());
    case U16:
      return within(0, std::numeric_limits<uint16>::max());
    case U32:
      return within(0, std::numeric_limits<uint32>::max());
    case U64:
      // Only the lower bound is checked for the unsigned 64-bit type.
      return x >= 0;
    default:
      LOG(FATAL) << "Invalid primitive type " << PrimitiveType_Name(ty);
  }
}
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@ -2713,6 +2714,65 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
break;
}
// If M < N, then {0, ..., M} % N ==> {0, ..., M}.
//
// Currently this only covers the case when N is a broadcasted constant
// scalar. We could also cover the case when N is a non-broadcasted constant
// with the same value repeated.
HloInstruction* iota;
HloInstruction* divisor;
if (Match(remainder,
m::Remainder(m::Iota(&iota),
m::Broadcast(m::ConstantEffectiveScalar(&divisor))))) {
// The iota counts {0, ..., iota_upper_bound - 1}. (Actually this is
// conservative; the iota may overflow and count up to a smaller value than
// this. But that's OK for our purposes here.)
int64 iota_upper_bound = iota->shape().dimensions(
Cast<HloIotaInstruction>(iota)->iota_dimension());
StatusOr<int64> divisor_val = divisor->literal().GetIntegralAsS64(
std::vector<int64>(0, divisor->shape().dimensions_size()));
if (divisor_val.ok() && divisor_val.ValueOrDie() >= iota_upper_bound) {
return ReplaceInstruction(remainder, iota);
}
}
// (X + N) % N = X % N, so long as X + N does not overflow.
//
// We don't have range tracking in XLA that would let us know whether X + N
// overflows, so for now we only do this simplification when X is an iota. We
// could add other operations where it's easy to see a range, such as
// remainder, convert, etc., though at some point we'd probably want a
// range-tracking analysis.
HloInstruction* bcast;
HloInstruction* addend;
if (Match(
remainder,
m::Remainder(
m::AddAnyOrder(m::Iota(&iota),
m::Broadcast(m::ConstantEffectiveScalar(&addend))),
m::Broadcast(&bcast, m::ConstantEffectiveScalar(&divisor)))) &&
addend == divisor) {
// The iota counts {0, ...iota_upper_bound - 1}, with the same caveat above
// that iota_upper_bound is conservative, and the true upper bound may be
// smaller.
int64 iota_upper_bound = iota->shape().dimensions(
Cast<HloIotaInstruction>(iota)->iota_dimension());
StatusOr<int64> divisor_val = divisor->literal().GetIntegralAsS64(
std::vector<int64>(0, divisor->shape().dimensions_size()));
if (divisor_val.ok()) {
// Check whether divisor_val + iota_upper_bound - 1 overflows.
absl::optional<int64> max_val =
OverflowSafeAdd(divisor_val.ValueOrDie(), iota_upper_bound);
if (max_val.has_value() &&
FitsInIntegralType(*max_val, iota->shape().element_type())) {
return ReplaceWithNewInstruction(
remainder,
HloInstruction::CreateBinary(remainder->shape(),
HloOpcode::kRemainder, iota, bcast));
}
}
}
return Status::OK();
}

View File

@ -5457,5 +5457,51 @@ TEST_F(AlgebraicSimplifierTest, CompareSame) {
GmockMatch(m::Broadcast(m::ConstantScalar(true))));
}
// iota % broadcast(5), where the iota dimension has size 5, must be replaced
// by the iota itself.
TEST_F(AlgebraicSimplifierTest, RemainderOfIota) {
  const char* kHloText = R"(
HloModule m
test {
iota = s32[5,1000] iota(), iota_dimension=0
five = s32[] constant(5)
five_bcast = s32[5,1000] broadcast(s32[] five), dimensions={}
ROOT remainder = s32[5,1000] remainder(iota, s32[5,1000] five_bcast)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  // The pass must report that it changed the module.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Iota()));
}
// (iota + broadcast(5)) % broadcast(5) must be rewritten so that the root is
// a remainder of the iota and a broadcast, dropping the add.
TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIota) {
  const char* kHloText = R"(
HloModule m
test {
iota = s32[5,1000] iota(), iota_dimension=0
five = s32[] constant(5)
five_bcast = s32[5,1000] broadcast(five), dimensions={}
sum = s32[5,1000] add(iota, five_bcast)
ROOT remainder = s32[5,1000] remainder(sum, s32[5,1000] five_bcast)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  // The pass must report that it changed the module.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Remainder(m::Iota(), m::Broadcast())));
}
// The s8 iota counts up to 125, and 125 + 5 overflows S8, so the simplifier
// must leave the module untouched.
TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIotaOverflow) {
  const char* kHloText = R"(
HloModule m
test {
iota = s8[126] iota(), iota_dimension=0
five = s8[] constant(5)
five_bcast = s8[126] broadcast(five), dimensions={}
sum = s8[126] add(iota, five_bcast)
ROOT remainder = s8[126] remainder(sum, s8[126] five_bcast)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  // Run() returning false means the pass made no change.
  AlgebraicSimplifier simplifier(default_options_);
  ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
}
} // namespace
} // namespace xla