[XLA] Simplify ({0, ..., n-1} + n) % n to {0, ..., n-1}.
We break this down into two simplification rules: - (X + n) % n ==> X % N if X + N does not overflow. - X % n ==> X if X < n. PiperOrigin-RevId: 256907449
This commit is contained in:
parent
ad6e2c74ce
commit
780530fd75
@ -220,6 +220,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
@ -16,6 +16,10 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -45,6 +49,54 @@ inline int64 MultiplyWithoutOverflow(const int64 x, const int64 y) {
|
||||
return static_cast<int64>(uxy);
|
||||
}
|
||||
|
||||
// Computes x + y and returns nullopt if it overflows.
|
||||
//
|
||||
// x and y must be signed integers.
|
||||
template <typename T>
|
||||
inline absl::optional<T> OverflowSafeAdd(T x, T y) {
|
||||
static_assert(std::is_signed<T>::value,
|
||||
"Only implemented for signed numbers T.");
|
||||
static_assert(std::is_integral<T>::value, "Only implemented for integers T.");
|
||||
// "Signed integer overflow occurs on integer addition iff the operands have
|
||||
// the same sign and the sum has a sign opposite to that of the operands."
|
||||
// Hacker's Delight 2nd ed, p 28.
|
||||
using U = typename std::make_unsigned<T>::type;
|
||||
const U ux = x;
|
||||
const U uy = y;
|
||||
const U usum = ux + uy;
|
||||
const T sum = usum;
|
||||
if (x >= 0 == y >= 0 && sum >= 0 != x >= 0) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
inline bool FitsInIntegralType(int64 x, PrimitiveType ty) {
|
||||
switch (ty) {
|
||||
case S8:
|
||||
return std::numeric_limits<int8>::min() <= x &&
|
||||
std::numeric_limits<int8>::max() >= x;
|
||||
case S16:
|
||||
return std::numeric_limits<int16>::min() <= x &&
|
||||
std::numeric_limits<int16>::max() >= x;
|
||||
case S32:
|
||||
return std::numeric_limits<int32>::min() <= x &&
|
||||
std::numeric_limits<int32>::max() >= x;
|
||||
case S64:
|
||||
return true;
|
||||
case U8:
|
||||
return 0 <= x && std::numeric_limits<uint8>::max() >= x;
|
||||
case U16:
|
||||
return 0 <= x && std::numeric_limits<uint16>::max() >= x;
|
||||
case U32:
|
||||
return 0 <= x && std::numeric_limits<uint32>::max() >= x;
|
||||
case U64:
|
||||
return 0 <= x;
|
||||
default:
|
||||
LOG(FATAL) << "Invalid primitive type " << PrimitiveType_Name(ty);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_
|
||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/overflow_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
@ -2713,6 +2714,65 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
|
||||
break;
|
||||
}
|
||||
|
||||
// If M < N, then {0, ..., M} % N ==> {0, ..., M}.
|
||||
//
|
||||
// Currently this only covers the case when N is a broadcasted constant
|
||||
// scalar. We could also cover the case when N is a non-broadcasted constant
|
||||
// with the same value repeated.
|
||||
HloInstruction* iota;
|
||||
HloInstruction* divisor;
|
||||
if (Match(remainder,
|
||||
m::Remainder(m::Iota(&iota),
|
||||
m::Broadcast(m::ConstantEffectiveScalar(&divisor))))) {
|
||||
// The iota counts {0, ..., iota_upper_bound - 1}. (Actually this is
|
||||
// conservative; the iota may overflow and count up to a smaller value than
|
||||
// this. But that's OK for our purposes here.)
|
||||
int64 iota_upper_bound = iota->shape().dimensions(
|
||||
Cast<HloIotaInstruction>(iota)->iota_dimension());
|
||||
StatusOr<int64> divisor_val = divisor->literal().GetIntegralAsS64(
|
||||
std::vector<int64>(0, divisor->shape().dimensions_size()));
|
||||
if (divisor_val.ok() && divisor_val.ValueOrDie() >= iota_upper_bound) {
|
||||
return ReplaceInstruction(remainder, iota);
|
||||
}
|
||||
}
|
||||
|
||||
// (X + N) % N = X % N, so long as X + N does not overflow.
|
||||
//
|
||||
// We don't have range tracking in XLA that would let us know whether X + N
|
||||
// overflows, so for now we only do this simplification when X is an iota. We
|
||||
// could add other operations where it's easy to see a range, such as
|
||||
// remainder, convert, etc., though at some point we'd probably want a
|
||||
// range-tracking analysis.
|
||||
HloInstruction* bcast;
|
||||
HloInstruction* addend;
|
||||
if (Match(
|
||||
remainder,
|
||||
m::Remainder(
|
||||
m::AddAnyOrder(m::Iota(&iota),
|
||||
m::Broadcast(m::ConstantEffectiveScalar(&addend))),
|
||||
m::Broadcast(&bcast, m::ConstantEffectiveScalar(&divisor)))) &&
|
||||
addend == divisor) {
|
||||
// The iota counts {0, ...iota_upper_bound - 1}, with the same caveat above
|
||||
// that iota_upper_bound is conservative, and the true upper bound may be
|
||||
// smaller.
|
||||
int64 iota_upper_bound = iota->shape().dimensions(
|
||||
Cast<HloIotaInstruction>(iota)->iota_dimension());
|
||||
StatusOr<int64> divisor_val = divisor->literal().GetIntegralAsS64(
|
||||
std::vector<int64>(0, divisor->shape().dimensions_size()));
|
||||
if (divisor_val.ok()) {
|
||||
// Check whether divisor_val + iota_upper_bound - 1 overflows.
|
||||
absl::optional<int64> max_val =
|
||||
OverflowSafeAdd(divisor_val.ValueOrDie(), iota_upper_bound);
|
||||
if (max_val.has_value() &&
|
||||
FitsInIntegralType(*max_val, iota->shape().element_type())) {
|
||||
return ReplaceWithNewInstruction(
|
||||
remainder,
|
||||
HloInstruction::CreateBinary(remainder->shape(),
|
||||
HloOpcode::kRemainder, iota, bcast));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -5457,5 +5457,51 @@ TEST_F(AlgebraicSimplifierTest, CompareSame) {
|
||||
GmockMatch(m::Broadcast(m::ConstantScalar(true))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, RemainderOfIota) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
iota = s32[5,1000] iota(), iota_dimension=0
|
||||
five = s32[] constant(5)
|
||||
five_bcast = s32[5,1000] broadcast(s32[] five), dimensions={}
|
||||
ROOT remainder = s32[5,1000] remainder(iota, s32[5,1000] five_bcast)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Iota()));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIota) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
iota = s32[5,1000] iota(), iota_dimension=0
|
||||
five = s32[] constant(5)
|
||||
five_bcast = s32[5,1000] broadcast(five), dimensions={}
|
||||
sum = s32[5,1000] add(iota, five_bcast)
|
||||
ROOT remainder = s32[5,1000] remainder(sum, s32[5,1000] five_bcast)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Remainder(m::Iota(), m::Broadcast())));
|
||||
}
|
||||
|
||||
// No simplification because 125 + 5 overflows S8.
|
||||
TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIotaOverflow) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
iota = s8[126] iota(), iota_dimension=0
|
||||
five = s8[] constant(5)
|
||||
five_bcast = s8[126] broadcast(five), dimensions={}
|
||||
sum = s8[126] add(iota, five_bcast)
|
||||
ROOT remainder = s8[126] remainder(sum, s8[126] five_bcast)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user