[XLA] Convert Abs(a)*Abs(a) to a*a and add an option to allow for numerically unsafe algebraic simplifications

PiperOrigin-RevId: 325084126
Change-Id: Id8bf89ba6601d7bb1efc2b167e6e9accf5913114
Authored by A. Unique TensorFlower on 2020-08-05 13:02:31 -07:00; committed by TensorFlower Gardener
parent 8846105326
commit b2f5d100d1
3 changed files with 50 additions and 92 deletions
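For context, a minimal standalone check (not part of the commit) of the arithmetic identity behind the Abs(a)*Abs(a) => a*a rewrite exercised by the SquareOfAbs test in the diff below. For real (non-complex) values the signs cancel in a*a, so it equals abs(a)*abs(a) exactly; the simplifier block in the diff accordingly guards on !ShapeUtil::ElementIsComplex, since for complex z the two products differ. The sample values here are arbitrary illustration choices.

// Standalone C++ sketch (illustrative only): abs(a) * abs(a) == a * a for
// real floating-point inputs, including signed zero and overflow to infinity.
#include <cassert>
#include <cmath>

int main() {
  const float samples[] = {-3.5f, -0.0f, 0.0f, 1.25f, 1e30f};
  for (float a : samples) {
    assert(std::abs(a) * std::abs(a) == a * a);
  }
  return 0;
}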

View File

@@ -665,7 +665,7 @@ Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction(
HloInstruction* inst;
HloInstruction* user;
int64 index;
std::tie(inst, user, index) = operands.back();
std::tie (inst, user, index) = operands.back();
operands.pop_back();
// Skip the op types that are not commutative with multiply.
@@ -913,7 +913,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
(Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) &&
Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) &&
(ShapeUtil::ElementIsIntegral(add->shape()) ||
options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) {
IsAllFpConstantPowerOf2(c))) {
return ReplaceWithNewInstruction(
add, HloInstruction::CreateBinary(
add->shape(), HloOpcode::kMultiply,
@@ -2667,17 +2667,6 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
return Status::OK();
}
{
HloInstruction* abs_operand;
if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) &&
!ShapeUtil::ElementIsComplex(abs_operand->shape())) {
TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand));
TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand));
changed_ = true;
return Status::OK();
}
}
{
HloInstruction *convert_operand, *operand;
// Mul(Convert(Pred), operand) => select(pred, operand, 0)
@@ -3048,8 +3037,8 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
HloInstruction* new_broadcast = computation_->AddInstruction(
HloInstruction::CreateBroadcast(user->shape(), operand, {}));
// Use HloInstruction::ReplaceAllUsesWith instead of
// HloComputation::ReplaceWithNewInstruction because we are replacing
// an instruction other than the visited instruction.
// HloComputation::ReplaceWithNewInstruction because we are replacing an
// instruction other than the visited instruction.
changed_ = true;
return user->ReplaceAllUsesWith(new_broadcast);
}
@@ -3166,11 +3155,9 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
// Eliminate a convert pair if it is a no-op. The following are a few
// example cases that are being handled:
// 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of
// $TYPE2
// 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of $TYPE2
// and convert(A, $TYPE1) is an upcast
// 2. convert(convert(A, $TYPE1),$TYPE2) is simplified to A if A is of
// $TYPE2
// 2. convert(convert(A, $TYPE1),$TYPE2) is simplified to A if A is of $TYPE2
// and convert(A, $TYPE1) is an upcast and is an integral conversion from
// unsigned to signed (only signed to unsigned conversion is NOT allowed)
// 3. Tuple(convert(A, $TYPE1) , floor(convert(convert(A, $TYPE1), $TYPE2)),
@@ -3306,8 +3293,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
pad->shape(), nonzero_pad->mutable_shape()));
simplifier_->UpdateLayout(nonzero_pad->mutable_shape());
// Second, construct the slice instruction to perform the negative
// padding.
// Second, construct the slice instruction to perform the negative padding.
std::vector<int64> start_indices;
std::vector<int64> end_indices;
std::vector<int64> strides;
@@ -3460,8 +3446,8 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
Shape changed_shape;
for (HloInstruction* user_operand : user->operands()) {
// If this is a broadcast operand that is not our original broadcast
// input to this function then we might need to change the input.
// If this is a broadcast operand that is not our original broadcast input
// to this function then we might need to change the input.
if (is_compatible_broadcast(user_operand)) {
// If this is a broadcast from a scalar value rewrite a broadcast from
// the scalar to the new shape enforced from the other broadcast
@@ -3632,16 +3618,16 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
// If M < N, then {0, ..., M} % N ==> {0, ..., M}.
//
// Currently this only covers the case when N is a broadcasted constant
// scalar. We could also cover the case when N is a non-broadcasted
// constant with the same value repeated.
// scalar. We could also cover the case when N is a non-broadcasted constant
// with the same value repeated.
HloInstruction* iota;
HloInstruction* divisor;
if (Match(remainder,
m::Remainder(m::Iota(&iota),
m::Broadcast(m::ConstantEffectiveScalar(&divisor))))) {
// The iota counts {0, ..., iota_upper_bound - 1}. (Actually this is
// conservative; the iota may overflow and count up to a smaller value
// than this. But that's OK for our purposes here.)
// conservative; the iota may overflow and count up to a smaller value than
// this. But that's OK for our purposes here.)
int64 iota_upper_bound = iota->shape().dimensions(
Cast<HloIotaInstruction>(iota)->iota_dimension());
absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
@@ -3654,8 +3640,8 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
// (X + N) % N = X % N, so long as X + N does not overflow.
//
// We don't have range tracking in XLA that would let us know whether X + N
// overflows, so for now we only do this simplification when X is an iota.
// We could add other operations where it's easy to see a range, such as
// overflows, so for now we only do this simplification when X is an iota. We
// could add other operations where it's easy to see a range, such as
// remainder, convert, etc., though at some point we'd probably want a
// range-tracking analysis.
HloInstruction* bcast;
@@ -3667,9 +3653,9 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
m::Broadcast(m::ConstantEffectiveScalar(&addend))),
m::Broadcast(&bcast, m::ConstantEffectiveScalar(&divisor)))) &&
addend == divisor) {
// The iota counts {0, ...iota_upper_bound - 1}, with the same caveat
// above that iota_upper_bound is conservative, and the true upper bound
// may be smaller.
// The iota counts {0, ...iota_upper_bound - 1}, with the same caveat above
// that iota_upper_bound is conservative, and the true upper bound may be
// smaller.
int64 iota_upper_bound = iota->shape().dimensions(
Cast<HloIotaInstruction>(iota)->iota_dimension());
absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
@@ -3774,9 +3760,9 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
HloInstruction* slice) {
// Only try to do this for effective scalars. We could do the same for
// slicing out larger pieces of padding (replacing with a broadcast of the
// padding value), but this is probably not worth it.
// Only try to do this for effective scalars. We could do the same for slicing
// out larger pieces of padding (replacing with a broadcast of the padding
// value), but this is probably not worth it.
if (!ShapeUtil::IsEffectiveScalar(slice->shape())) {
return false;
}
@@ -3877,8 +3863,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
return false;
}
// Allowing a slice to move through a reverse with any necessary updates to
// the slice config.
// Allowing a slice to move through a reverse with any necessary updates to the
// slice config.
StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse(
HloInstruction* slice) {
VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:"
@@ -3906,8 +3892,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse(
<< new_limits[rdim];
}
// New slice formed from the reverse_operand, but strides and shape of the
// slice output remains the same. New slice's starts and limits are
// updated for ONLY the reversed dimensions as indicated above.
// slice output remains the same. New slice's starts and limits are updated
// for ONLY the reversed dimensions as indicated above.
HloInstruction* new_slice = computation_->AddInstruction(
HloInstruction::CreateSlice(slice->shape(), reverse_operand, new_starts,
new_limits, new_strides));
@@ -3934,8 +3920,7 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) {
// Is the result of the slice the pad operand.
bool slice_undoes_pad = true;
// Can the slice be moved to the pad_operand without any padding being
// read.
// Can the slice be moved to the pad_operand without any padding being read.
bool slice_inside_pad = true;
// Does this slice slice out pading only.
bool slice_in_padding = false;
@@ -4070,8 +4055,8 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
}
}
// Do not try to reorder slices and reshapes after layout assignment as it
// may be invalid.
// Do not try to reorder slices and reshapes after layout assignment as it may
// be invalid.
if (!options_.is_layout_sensitive()) {
TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice));
}
@@ -4121,8 +4106,8 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
return ReplaceInstruction(dynamic_slice, operand);
}
// DynamicSlice where operand has the same size as the output is simply
// equal to operand.
// DynamicSlice where operand has the same size as the output is simply equal
// to operand.
if (SameShape(operand, dynamic_slice)) {
return ReplaceInstruction(dynamic_slice, operand);
}
@@ -4453,8 +4438,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
// Convert Reduce(concat({a,b,...})) to
// map(reduce(a),map(reduce(b),...,))
//
// This should make fusion easier or use less memory bandwidth in the
// unfused case.
// This should make fusion easier or use less memory bandwidth in the unfused
// case.
if (arg->opcode() == HloOpcode::kConcatenate &&
absl::c_linear_search(reduce->dimensions(),
arg->concatenate_dimension())) {
@@ -4473,9 +4458,9 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
}
HloInstruction *dot, *lhs, *rhs;
// Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced
// were batch dimensions of the dot. The transformation supports reducing
// other dimensions as well.
// Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were
// batch dimensions of the dot. The transformation supports reducing other
// dimensions as well.
if (options_.enable_dot_strength_reduction() &&
Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
Match(reduce->to_apply()->root_instruction(),
@@ -4547,13 +4532,13 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
if (options_.enable_window_reduce_to_reduce_replacement()) {
// A reduce window can be expressed as a reduce and a reshape if all
// dimensions either have a window size of one or the entire dimension. If
// there is no stride, dilation, or padding, this is as easy as checking
// the size of the output shape and window dimension.
// there is no stride, dilation, or padding, this is as easy as checking the
// size of the output shape and window dimension.
//
// The reshape is a bitcast since it adds one-sized dimensions. Often
// these ones are immediately removed as well with another reshape. The
// implementation of reduce tends to be slightly more efficient at
// reducing entire dimensions compared to reduce window.
// The reshape is a bitcast since it adds one-sized dimensions. Often these
// ones are immediately removed as well with another reshape. The
// implementation of reduce tends to be slightly more efficient at reducing
// entire dimensions compared to reduce window.
auto effective_reduce_dims = [&] {
if (window_util::HasStride(window) || window_util::HasDilation(window) ||
window_util::HasPadding(window)) {
@@ -5068,8 +5053,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
auto new_dim = swapped_window.add_dimensions();
new_dim->set_size(input_size);
// If the kernel is not reversed, the activations must be manually
// reversed.
// If the kernel is not reversed, the activations must be manually reversed.
if (!window_dims[spatial_dim].window_reversal()) {
reverse_dimensions.push_back(
dnums.kernel_spatial_dimensions(spatial_dim));
@@ -5089,8 +5073,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
dilated_kernel_size);
}
// Don't transform if a naive convolution implementation would not have
// fewer flops.
// Don't transform if a naive convolution implementation would not have fewer
// flops.
if (kernel_product <= swapped_kernel_product) {
return false;
}
@@ -5168,11 +5152,11 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
}
}
// Stride ignores part of the output, which matrix multiplication does not
// do, so require no stride. Padding and base (lhs) dilation both implicitly
// Stride ignores part of the output, which matrix multiplication does not do,
// so require no stride. Padding and base (lhs) dilation both implicitly
// extend the data, which matrix multiplication also does not do, so require
// no padding and no base (lhs) dilation. Window (rhs) dilation has no
// effect for a 1x1 window, so window dilation is no problem.
// no padding and no base (lhs) dilation. Window (rhs) dilation has no effect
// for a 1x1 window, so window dilation is no problem.
if (window_util::HasStride(window) || window_util::HasPadding(window) ||
window_util::HasBaseDilation(window)) {
return false;
@@ -5225,9 +5209,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
}
}
// We already checked feature_dimension is most minor, so data in
// input_shape and row-major {conv_width,input_channels} are bitwise
// identical.
// We already checked feature_dimension is most minor, so data in input_shape
// and row-major {conv_width,input_channels} are bitwise identical.
Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
input_shape.element_type(), {conv_width, input_channels});
simplifier_->UpdateLayout(&new_input_shape);

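As a side note, a standalone sketch (not part of the commit) of the integer identities the HandleRemainder hunks above rely on: for a positive constant divisor N, (X + N) % N == X % N as long as X + N does not overflow, and values in {0, ..., M} are unchanged by % N when M < N. The divisor 7 and bound 100 below are arbitrary illustration choices.

// Standalone C++ sketch (illustrative only) of the remainder identities.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t n = 7;
  for (int64_t x = 0; x < 100; ++x) {
    assert((x + n) % n == x % n);  // (X + N) % N == X % N; no overflow here.
    if (x < n) {
      assert(x % n == x);          // {0, ..., M} % N == {0, ..., M} when M < N.
    }
  }
  return 0;
}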
View File

@@ -97,14 +97,6 @@ class AlgebraicSimplifierOptions {
return enable_scalar_multiply_reduction_;
}
// Also the algebraic simplifer to treat floating point values like real
// numbers.
void set_enable_floats_are_real(bool enable_floats_are_real) {
enable_floats_are_real_ = enable_floats_are_real;
}
bool enable_floats_are_real() const { return enable_floats_are_real_; }
// If enable_window_reduce_replacement is true, the kReduceWindow instruction
// can be optimized by replacement with simpler operations.
void set_enable_window_reduce_to_reduce_replacement(
@@ -166,7 +158,6 @@ class AlgebraicSimplifierOptions {
bool enable_conv_simplification_{true};
bool enable_conv_operand_swap_{true};
bool enable_scalar_multiply_reduction_{false};
bool enable_floats_are_real_{false};
bool enable_window_reduce_to_reduce_replacement_{true};
bool enable_reduce_of_reshape_{true};
bool replace_transpose_with_bitcast_{true};

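For reference, a hypothetical sketch of how the option removed above would have been configured. The AlgebraicSimplifierOptions setter and the AlgebraicSimplifier pass appear in the diffs, but this particular call sequence and the header path are assumptions, not something shown in the commit.

// Hypothetical usage of the removed knob (sketch only).
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"  // assumed path

void ConfigureSimplifier() {
  xla::AlgebraicSimplifierOptions options;
  options.set_enable_floats_are_real(true);  // setter removed in the diff above
  xla::AlgebraicSimplifier simplifier(options);
  // simplifier.Run(module) would then treat floats like real numbers when
  // deciding whether the affected rewrites (e.g. in HandleAdd) are safe.
}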
View File

@@ -117,22 +117,6 @@ TEST_F(AlgebraicSimplifierTest, FactorFpAddition) {
m::ConstantScalar(0.125))));
}
// (Abs(A)) * (Abs(A)) => (A*A)
TEST_F(AlgebraicSimplifierTest, SquareOfAbs) {
const char* kModuleStr = R"(
HloModule m
test {
p = f32[] parameter(0)
a = f32[] abs(p)
ROOT z = f32[] multiply(a, a)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
// (A*C1) * (B*C2) => (A*B)*(C1*C2)
TEST_F(AlgebraicSimplifierTest, MultiplyChain) {
const char* kModuleStr = R"(