[XLA] Convert Abs(a)*Abs(a) to a*a and add an option to allow for numerically unsafe algebraic simplifications
PiperOrigin-RevId: 325084126 Change-Id: Id8bf89ba6601d7bb1efc2b167e6e9accf5913114
This commit is contained in:
parent
8846105326
commit
b2f5d100d1
tensorflow/compiler/xla/service
@ -665,7 +665,7 @@ Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction(
|
||||
HloInstruction* inst;
|
||||
HloInstruction* user;
|
||||
int64 index;
|
||||
std::tie(inst, user, index) = operands.back();
|
||||
std::tie (inst, user, index) = operands.back();
|
||||
operands.pop_back();
|
||||
|
||||
// Skip the op types that are not commutative with multiply.
|
||||
@ -913,7 +913,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
|
||||
(Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) &&
|
||||
Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) &&
|
||||
(ShapeUtil::ElementIsIntegral(add->shape()) ||
|
||||
options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) {
|
||||
IsAllFpConstantPowerOf2(c))) {
|
||||
return ReplaceWithNewInstruction(
|
||||
add, HloInstruction::CreateBinary(
|
||||
add->shape(), HloOpcode::kMultiply,
|
||||
@ -2667,17 +2667,6 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
{
|
||||
HloInstruction* abs_operand;
|
||||
if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) &&
|
||||
!ShapeUtil::ElementIsComplex(abs_operand->shape())) {
|
||||
TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand));
|
||||
TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand));
|
||||
changed_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
HloInstruction *convert_operand, *operand;
|
||||
// Mul(Convert(Pred), operand) => select(pred, operand, 0)
|
||||
@ -3048,8 +3037,8 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
|
||||
HloInstruction* new_broadcast = computation_->AddInstruction(
|
||||
HloInstruction::CreateBroadcast(user->shape(), operand, {}));
|
||||
// Use HloInstruction::ReplaceAllUsesWith instead of
|
||||
// HloComputation::ReplaceWithNewInstruction because we are replacing
|
||||
// an instruction other than the visited instruction.
|
||||
// HloComputation::ReplaceWithNewInstruction because we are replacing an
|
||||
// instruction other than the visited instruction.
|
||||
changed_ = true;
|
||||
return user->ReplaceAllUsesWith(new_broadcast);
|
||||
}
|
||||
@ -3166,11 +3155,9 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
|
||||
|
||||
// Eliminate a convert pair if it is a no-op. The following are a few
|
||||
// example cases that are being handled:
|
||||
// 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of
|
||||
// $TYPE2
|
||||
// 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of $TYPE2
|
||||
// and convert(A, $TYPE1) is an upcast
|
||||
// 2. convert(convert(A, $TYPE1),$TYPE2) is simplified to A if A is of
|
||||
// $TYPE2
|
||||
// 2. convert(convert(A, $TYPE1),$TYPE2) is simplified to A if A is of $TYPE2
|
||||
// and convert(A, $TYPE1) is an upcast and is an integral conversion from
|
||||
// unsigned to signed (only signed to unsigned conversion is NOT allowed)
|
||||
// 3. Tuple(convert(A, $TYPE1) , floor(convert(convert(A, $TYPE1), $TYPE2)),
|
||||
@ -3306,8 +3293,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
|
||||
pad->shape(), nonzero_pad->mutable_shape()));
|
||||
simplifier_->UpdateLayout(nonzero_pad->mutable_shape());
|
||||
|
||||
// Second, construct the slice instruction to perform the negative
|
||||
// padding.
|
||||
// Second, construct the slice instruction to perform the negative padding.
|
||||
std::vector<int64> start_indices;
|
||||
std::vector<int64> end_indices;
|
||||
std::vector<int64> strides;
|
||||
@ -3460,8 +3446,8 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
|
||||
|
||||
Shape changed_shape;
|
||||
for (HloInstruction* user_operand : user->operands()) {
|
||||
// If this is a broadcast operand that is not our original broadcast
|
||||
// input to this function then we might need to change the input.
|
||||
// If this is a broadcast operand that is not our original broadcast input
|
||||
// to this function then we might need to change the input.
|
||||
if (is_compatible_broadcast(user_operand)) {
|
||||
// If this is a broadcast from a scalar value rewrite a broadcast from
|
||||
// the scalar to the new shape enforced from the other broadcast
|
||||
@ -3632,16 +3618,16 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
|
||||
// If M < N, then {0, ..., M} % N ==> {0, ..., M}.
|
||||
//
|
||||
// Currently this only covers the case when N is a broadcasted constant
|
||||
// scalar. We could also cover the case when N is a non-broadcasted
|
||||
// constant with the same value repeated.
|
||||
// scalar. We could also cover the case when N is a non-broadcasted constant
|
||||
// with the same value repeated.
|
||||
HloInstruction* iota;
|
||||
HloInstruction* divisor;
|
||||
if (Match(remainder,
|
||||
m::Remainder(m::Iota(&iota),
|
||||
m::Broadcast(m::ConstantEffectiveScalar(&divisor))))) {
|
||||
// The iota counts {0, ..., iota_upper_bound - 1}. (Actually this is
|
||||
// conservative; the iota may overflow and count up to a smaller value
|
||||
// than this. But that's OK for our purposes here.)
|
||||
// conservative; the iota may overflow and count up to a smaller value than
|
||||
// this. But that's OK for our purposes here.)
|
||||
int64 iota_upper_bound = iota->shape().dimensions(
|
||||
Cast<HloIotaInstruction>(iota)->iota_dimension());
|
||||
absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
|
||||
@ -3654,8 +3640,8 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
|
||||
// (X + N) % N = X % N, so long as X + N does not overflow.
|
||||
//
|
||||
// We don't have range tracking in XLA that would let us know whether X + N
|
||||
// overflows, so for now we only do this simplification when X is an iota.
|
||||
// We could add other operations where it's easy to see a range, such as
|
||||
// overflows, so for now we only do this simplification when X is an iota. We
|
||||
// could add other operations where it's easy to see a range, such as
|
||||
// remainder, convert, etc., though at some point we'd probably want a
|
||||
// range-tracking analysis.
|
||||
HloInstruction* bcast;
|
||||
@ -3667,9 +3653,9 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
|
||||
m::Broadcast(m::ConstantEffectiveScalar(&addend))),
|
||||
m::Broadcast(&bcast, m::ConstantEffectiveScalar(&divisor)))) &&
|
||||
addend == divisor) {
|
||||
// The iota counts {0, ...iota_upper_bound - 1}, with the same caveat
|
||||
// above that iota_upper_bound is conservative, and the true upper bound
|
||||
// may be smaller.
|
||||
// The iota counts {0, ...iota_upper_bound - 1}, with the same caveat above
|
||||
// that iota_upper_bound is conservative, and the true upper bound may be
|
||||
// smaller.
|
||||
int64 iota_upper_bound = iota->shape().dimensions(
|
||||
Cast<HloIotaInstruction>(iota)->iota_dimension());
|
||||
absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
|
||||
@ -3774,9 +3760,9 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
|
||||
|
||||
StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
|
||||
HloInstruction* slice) {
|
||||
// Only try to do this for effective scalars. We could do the same for
|
||||
// slicing out larger pieces of padding (replacing with a broadcast of the
|
||||
// padding value), but this is probably not worth it.
|
||||
// Only try to do this for effective scalars. We could do the same for slicing
|
||||
// out larger pieces of padding (replacing with a broadcast of the padding
|
||||
// value), but this is probably not worth it.
|
||||
if (!ShapeUtil::IsEffectiveScalar(slice->shape())) {
|
||||
return false;
|
||||
}
|
||||
@ -3877,8 +3863,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
|
||||
return false;
|
||||
}
|
||||
|
||||
// Allowing a slice to move through a reverse with any necessary updates to
|
||||
// the slice config.
|
||||
// Allowing a slice to move through a reverse with any necessary updates to the
|
||||
// slice config.
|
||||
StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse(
|
||||
HloInstruction* slice) {
|
||||
VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:"
|
||||
@ -3906,8 +3892,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse(
|
||||
<< new_limits[rdim];
|
||||
}
|
||||
// New slice formed from the reverse_operand, but strides and shape of the
|
||||
// slice output remains the same. New slice's starts and limits are
|
||||
// updated for ONLY the reversed dimensions as indicated above.
|
||||
// slice output remains the same. New slice's starts and limits are updated
|
||||
// for ONLY the reversed dimensions as indicated above.
|
||||
HloInstruction* new_slice = computation_->AddInstruction(
|
||||
HloInstruction::CreateSlice(slice->shape(), reverse_operand, new_starts,
|
||||
new_limits, new_strides));
|
||||
@ -3934,8 +3920,7 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
|
||||
if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) {
|
||||
// Is the result of the slice the pad operand.
|
||||
bool slice_undoes_pad = true;
|
||||
// Can the slice be moved to the pad_operand without any padding being
|
||||
// read.
|
||||
// Can the slice be moved to the pad_operand without any padding being read.
|
||||
bool slice_inside_pad = true;
|
||||
// Does this slice slice out pading only.
|
||||
bool slice_in_padding = false;
|
||||
@ -4070,8 +4055,8 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
|
||||
}
|
||||
}
|
||||
|
||||
// Do not try to reorder slices and reshapes after layout assignment as it
|
||||
// may be invalid.
|
||||
// Do not try to reorder slices and reshapes after layout assignment as it may
|
||||
// be invalid.
|
||||
if (!options_.is_layout_sensitive()) {
|
||||
TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice));
|
||||
}
|
||||
@ -4121,8 +4106,8 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
|
||||
if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
|
||||
return ReplaceInstruction(dynamic_slice, operand);
|
||||
}
|
||||
// DynamicSlice where operand has the same size as the output is simply
|
||||
// equal to operand.
|
||||
// DynamicSlice where operand has the same size as the output is simply equal
|
||||
// to operand.
|
||||
if (SameShape(operand, dynamic_slice)) {
|
||||
return ReplaceInstruction(dynamic_slice, operand);
|
||||
}
|
||||
@ -4453,8 +4438,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
|
||||
// Convert Reduce(concat({a,b,...})) to
|
||||
// map(reduce(a),map(reduce(b),...,))
|
||||
//
|
||||
// This should make fusion easier or use less memory bandwidth in the
|
||||
// unfused case.
|
||||
// This should make fusion easier or use less memory bandwidth in the unfused
|
||||
// case.
|
||||
if (arg->opcode() == HloOpcode::kConcatenate &&
|
||||
absl::c_linear_search(reduce->dimensions(),
|
||||
arg->concatenate_dimension())) {
|
||||
@ -4473,9 +4458,9 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
|
||||
}
|
||||
|
||||
HloInstruction *dot, *lhs, *rhs;
|
||||
// Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced
|
||||
// were batch dimensions of the dot. The transformation supports reducing
|
||||
// other dimensions as well.
|
||||
// Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were
|
||||
// batch dimensions of the dot. The transformation supports reducing other
|
||||
// dimensions as well.
|
||||
if (options_.enable_dot_strength_reduction() &&
|
||||
Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
|
||||
Match(reduce->to_apply()->root_instruction(),
|
||||
@ -4547,13 +4532,13 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
|
||||
if (options_.enable_window_reduce_to_reduce_replacement()) {
|
||||
// A reduce window can be expressed as a reduce and a reshape if all
|
||||
// dimensions either have a window size of one or the entire dimension. If
|
||||
// there is no stride, dilation, or padding, this is as easy as checking
|
||||
// the size of the output shape and window dimension.
|
||||
// there is no stride, dilation, or padding, this is as easy as checking the
|
||||
// size of the output shape and window dimension.
|
||||
//
|
||||
// The reshape is a bitcast since it adds one-sized dimensions. Often
|
||||
// these ones are immediately removed as well with another reshape. The
|
||||
// implementation of reduce tends to be slightly more efficient at
|
||||
// reducing entire dimensions compared to reduce window.
|
||||
// The reshape is a bitcast since it adds one-sized dimensions. Often these
|
||||
// ones are immediately removed as well with another reshape. The
|
||||
// implementation of reduce tends to be slightly more efficient at reducing
|
||||
// entire dimensions compared to reduce window.
|
||||
auto effective_reduce_dims = [&] {
|
||||
if (window_util::HasStride(window) || window_util::HasDilation(window) ||
|
||||
window_util::HasPadding(window)) {
|
||||
@ -5068,8 +5053,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
|
||||
|
||||
auto new_dim = swapped_window.add_dimensions();
|
||||
new_dim->set_size(input_size);
|
||||
// If the kernel is not reversed, the activations must be manually
|
||||
// reversed.
|
||||
// If the kernel is not reversed, the activations must be manually reversed.
|
||||
if (!window_dims[spatial_dim].window_reversal()) {
|
||||
reverse_dimensions.push_back(
|
||||
dnums.kernel_spatial_dimensions(spatial_dim));
|
||||
@ -5089,8 +5073,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
|
||||
dilated_kernel_size);
|
||||
}
|
||||
|
||||
// Don't transform if a naive convolution implementation would not have
|
||||
// fewer flops.
|
||||
// Don't transform if a naive convolution implementation would not have fewer
|
||||
// flops.
|
||||
if (kernel_product <= swapped_kernel_product) {
|
||||
return false;
|
||||
}
|
||||
@ -5168,11 +5152,11 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
|
||||
}
|
||||
}
|
||||
|
||||
// Stride ignores part of the output, which matrix multiplication does not
|
||||
// do, so require no stride. Padding and base (lhs) dilation both implicitly
|
||||
// Stride ignores part of the output, which matrix multiplication does not do,
|
||||
// so require no stride. Padding and base (lhs) dilation both implicitly
|
||||
// extend the data, which matrix multiplication also does not do, so require
|
||||
// no padding and no base (lhs) dilation. Window (rhs) dilation has no
|
||||
// effect for a 1x1 window, so window dilation is no problem.
|
||||
// no padding and no base (lhs) dilation. Window (rhs) dilation has no effect
|
||||
// for a 1x1 window, so window dilation is no problem.
|
||||
if (window_util::HasStride(window) || window_util::HasPadding(window) ||
|
||||
window_util::HasBaseDilation(window)) {
|
||||
return false;
|
||||
@ -5225,9 +5209,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
|
||||
}
|
||||
}
|
||||
|
||||
// We already checked feature_dimension is most minor, so data in
|
||||
// input_shape and row-major {conv_width,input_channels} are bitwise
|
||||
// identical.
|
||||
// We already checked feature_dimension is most minor, so data in input_shape
|
||||
// and row-major {conv_width,input_channels} are bitwise identical.
|
||||
Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
input_shape.element_type(), {conv_width, input_channels});
|
||||
simplifier_->UpdateLayout(&new_input_shape);
|
||||
|
@ -97,14 +97,6 @@ class AlgebraicSimplifierOptions {
|
||||
return enable_scalar_multiply_reduction_;
|
||||
}
|
||||
|
||||
// Also the algebraic simplifer to treat floating point values like real
|
||||
// numbers.
|
||||
void set_enable_floats_are_real(bool enable_floats_are_real) {
|
||||
enable_floats_are_real_ = enable_floats_are_real;
|
||||
}
|
||||
|
||||
bool enable_floats_are_real() const { return enable_floats_are_real_; }
|
||||
|
||||
// If enable_window_reduce_replacement is true, the kReduceWindow instruction
|
||||
// can be optimized by replacement with simpler operations.
|
||||
void set_enable_window_reduce_to_reduce_replacement(
|
||||
@ -166,7 +158,6 @@ class AlgebraicSimplifierOptions {
|
||||
bool enable_conv_simplification_{true};
|
||||
bool enable_conv_operand_swap_{true};
|
||||
bool enable_scalar_multiply_reduction_{false};
|
||||
bool enable_floats_are_real_{false};
|
||||
bool enable_window_reduce_to_reduce_replacement_{true};
|
||||
bool enable_reduce_of_reshape_{true};
|
||||
bool replace_transpose_with_bitcast_{true};
|
||||
|
@ -117,22 +117,6 @@ TEST_F(AlgebraicSimplifierTest, FactorFpAddition) {
|
||||
m::ConstantScalar(0.125))));
|
||||
}
|
||||
|
||||
// (Abs(A)) * (Abs(A)) => (A*A)
|
||||
TEST_F(AlgebraicSimplifierTest, SquareOfAbs) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p = f32[] parameter(0)
|
||||
a = f32[] abs(p)
|
||||
ROOT z = f32[] multiply(a, a)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
|
||||
}
|
||||
|
||||
// (A*C1) * (B*C2) => (A*B)*(C1*C2)
|
||||
TEST_F(AlgebraicSimplifierTest, MultiplyChain) {
|
||||
const char* kModuleStr = R"(
|
||||
|
Loading…
Reference in New Issue
Block a user