[XLA] Convert Abs(a)*Abs(a) to a*a and add an option to allow for numerically unsafe algebraic simplifications

PiperOrigin-RevId: 325084126
Change-Id: Id8bf89ba6601d7bb1efc2b167e6e9accf5913114
This commit is contained in:
A. Unique TensorFlower 2020-08-05 13:02:31 -07:00 committed by TensorFlower Gardener
parent 8846105326
commit b2f5d100d1
3 changed files with 50 additions and 92 deletions

View File

@ -665,7 +665,7 @@ Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction(
HloInstruction* inst; HloInstruction* inst;
HloInstruction* user; HloInstruction* user;
int64 index; int64 index;
std::tie(inst, user, index) = operands.back(); std::tie (inst, user, index) = operands.back();
operands.pop_back(); operands.pop_back();
// Skip the op types that are not commutative with multiply. // Skip the op types that are not commutative with multiply.
@ -913,7 +913,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
(Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) && (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) &&
Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) && Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) &&
(ShapeUtil::ElementIsIntegral(add->shape()) || (ShapeUtil::ElementIsIntegral(add->shape()) ||
options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) { IsAllFpConstantPowerOf2(c))) {
return ReplaceWithNewInstruction( return ReplaceWithNewInstruction(
add, HloInstruction::CreateBinary( add, HloInstruction::CreateBinary(
add->shape(), HloOpcode::kMultiply, add->shape(), HloOpcode::kMultiply,
@ -2667,17 +2667,6 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
return Status::OK(); return Status::OK();
} }
{
HloInstruction* abs_operand;
if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) &&
!ShapeUtil::ElementIsComplex(abs_operand->shape())) {
TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand));
TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand));
changed_ = true;
return Status::OK();
}
}
{ {
HloInstruction *convert_operand, *operand; HloInstruction *convert_operand, *operand;
// Mul(Convert(Pred), operand) => select(pred, operand, 0) // Mul(Convert(Pred), operand) => select(pred, operand, 0)
@ -3048,8 +3037,8 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
HloInstruction* new_broadcast = computation_->AddInstruction( HloInstruction* new_broadcast = computation_->AddInstruction(
HloInstruction::CreateBroadcast(user->shape(), operand, {})); HloInstruction::CreateBroadcast(user->shape(), operand, {}));
// Use HloInstruction::ReplaceAllUsesWith instead of // Use HloInstruction::ReplaceAllUsesWith instead of
// HloComputation::ReplaceWithNewInstruction because we are replacing // HloComputation::ReplaceWithNewInstruction because we are replacing an
// an instruction other than the visited instruction. // instruction other than the visited instruction.
changed_ = true; changed_ = true;
return user->ReplaceAllUsesWith(new_broadcast); return user->ReplaceAllUsesWith(new_broadcast);
} }
@ -3166,11 +3155,9 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
// Eliminate a convert pair if it is a no-op. The following are a few // Eliminate a convert pair if it is a no-op. The following are a few
// example cases that are being handled: // example cases that are being handled:
// 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of // 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of $TYPE2
// $TYPE2
// and convert(A, $TYPE1) is an upcast // and convert(A, $TYPE1) is an upcast
// 2. convert(convert(A, $TYPE1),$TYPE2) is simplified to A if A is of // 2. convert(convert(A, $TYPE1),$TYPE2) is simplified to A if A is of $TYPE2
// $TYPE2
// and convert(A, $TYPE1) is an upcast and is an integral conversion from // and convert(A, $TYPE1) is an upcast and is an integral conversion from
// unsigned to signed (only signed to unsigned conversion is NOT allowed) // unsigned to signed (only signed to unsigned conversion is NOT allowed)
// 3. Tuple(convert(A, $TYPE1) , floor(convert(convert(A, $TYPE1), $TYPE2)), // 3. Tuple(convert(A, $TYPE1) , floor(convert(convert(A, $TYPE1), $TYPE2)),
@ -3306,8 +3293,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
pad->shape(), nonzero_pad->mutable_shape())); pad->shape(), nonzero_pad->mutable_shape()));
simplifier_->UpdateLayout(nonzero_pad->mutable_shape()); simplifier_->UpdateLayout(nonzero_pad->mutable_shape());
// Second, construct the slice instruction to perform the negative // Second, construct the slice instruction to perform the negative padding.
// padding.
std::vector<int64> start_indices; std::vector<int64> start_indices;
std::vector<int64> end_indices; std::vector<int64> end_indices;
std::vector<int64> strides; std::vector<int64> strides;
@ -3460,8 +3446,8 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
Shape changed_shape; Shape changed_shape;
for (HloInstruction* user_operand : user->operands()) { for (HloInstruction* user_operand : user->operands()) {
// If this is a broadcast operand that is not our original broadcast // If this is a broadcast operand that is not our original broadcast input
// input to this function then we might need to change the input. // to this function then we might need to change the input.
if (is_compatible_broadcast(user_operand)) { if (is_compatible_broadcast(user_operand)) {
// If this is a broadcast from a scalar value rewrite a broadcast from // If this is a broadcast from a scalar value rewrite a broadcast from
// the scalar to the new shape enforced from the other broadcast // the scalar to the new shape enforced from the other broadcast
@ -3632,16 +3618,16 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
// If M < N, then {0, ..., M} % N ==> {0, ..., M}. // If M < N, then {0, ..., M} % N ==> {0, ..., M}.
// //
// Currently this only covers the case when N is a broadcasted constant // Currently this only covers the case when N is a broadcasted constant
// scalar. We could also cover the case when N is a non-broadcasted // scalar. We could also cover the case when N is a non-broadcasted constant
// constant with the same value repeated. // with the same value repeated.
HloInstruction* iota; HloInstruction* iota;
HloInstruction* divisor; HloInstruction* divisor;
if (Match(remainder, if (Match(remainder,
m::Remainder(m::Iota(&iota), m::Remainder(m::Iota(&iota),
m::Broadcast(m::ConstantEffectiveScalar(&divisor))))) { m::Broadcast(m::ConstantEffectiveScalar(&divisor))))) {
// The iota counts {0, ..., iota_upper_bound - 1}. (Actually this is // The iota counts {0, ..., iota_upper_bound - 1}. (Actually this is
// conservative; the iota may overflow and count up to a smaller value // conservative; the iota may overflow and count up to a smaller value than
// than this. But that's OK for our purposes here.) // this. But that's OK for our purposes here.)
int64 iota_upper_bound = iota->shape().dimensions( int64 iota_upper_bound = iota->shape().dimensions(
Cast<HloIotaInstruction>(iota)->iota_dimension()); Cast<HloIotaInstruction>(iota)->iota_dimension());
absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64( absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
@ -3654,8 +3640,8 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
// (X + N) % N = X % N, so long as X + N does not overflow. // (X + N) % N = X % N, so long as X + N does not overflow.
// //
// We don't have range tracking in XLA that would let us know whether X + N // We don't have range tracking in XLA that would let us know whether X + N
// overflows, so for now we only do this simplification when X is an iota. // overflows, so for now we only do this simplification when X is an iota. We
// We could add other operations where it's easy to see a range, such as // could add other operations where it's easy to see a range, such as
// remainder, convert, etc., though at some point we'd probably want a // remainder, convert, etc., though at some point we'd probably want a
// range-tracking analysis. // range-tracking analysis.
HloInstruction* bcast; HloInstruction* bcast;
@ -3667,9 +3653,9 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
m::Broadcast(m::ConstantEffectiveScalar(&addend))), m::Broadcast(m::ConstantEffectiveScalar(&addend))),
m::Broadcast(&bcast, m::ConstantEffectiveScalar(&divisor)))) && m::Broadcast(&bcast, m::ConstantEffectiveScalar(&divisor)))) &&
addend == divisor) { addend == divisor) {
// The iota counts {0, ...iota_upper_bound - 1}, with the same caveat // The iota counts {0, ...iota_upper_bound - 1}, with the same caveat above
// above that iota_upper_bound is conservative, and the true upper bound // that iota_upper_bound is conservative, and the true upper bound may be
// may be smaller. // smaller.
int64 iota_upper_bound = iota->shape().dimensions( int64 iota_upper_bound = iota->shape().dimensions(
Cast<HloIotaInstruction>(iota)->iota_dimension()); Cast<HloIotaInstruction>(iota)->iota_dimension());
absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64( absl::optional<int64> divisor_val = divisor->literal().GetIntegralAsS64(
@ -3774,9 +3760,9 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
HloInstruction* slice) { HloInstruction* slice) {
// Only try to do this for effective scalars. We could do the same for // Only try to do this for effective scalars. We could do the same for slicing
// slicing out larger pieces of padding (replacing with a broadcast of the // out larger pieces of padding (replacing with a broadcast of the padding
// padding value), but this is probably not worth it. // value), but this is probably not worth it.
if (!ShapeUtil::IsEffectiveScalar(slice->shape())) { if (!ShapeUtil::IsEffectiveScalar(slice->shape())) {
return false; return false;
} }
@ -3877,8 +3863,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
return false; return false;
} }
// Allowing a slice to move through a reverse with any necessary updates to // Allowing a slice to move through a reverse with any necessary updates to the
// the slice config. // slice config.
StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse( StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse(
HloInstruction* slice) { HloInstruction* slice) {
VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:" VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:"
@ -3906,8 +3892,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse(
<< new_limits[rdim]; << new_limits[rdim];
} }
// New slice formed from the reverse_operand, but strides and shape of the // New slice formed from the reverse_operand, but strides and shape of the
// slice output remains the same. New slice's starts and limits are // slice output remains the same. New slice's starts and limits are updated
// updated for ONLY the reversed dimensions as indicated above. // for ONLY the reversed dimensions as indicated above.
HloInstruction* new_slice = computation_->AddInstruction( HloInstruction* new_slice = computation_->AddInstruction(
HloInstruction::CreateSlice(slice->shape(), reverse_operand, new_starts, HloInstruction::CreateSlice(slice->shape(), reverse_operand, new_starts,
new_limits, new_strides)); new_limits, new_strides));
@ -3934,8 +3920,7 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) { if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) {
// Is the result of the slice the pad operand. // Is the result of the slice the pad operand.
bool slice_undoes_pad = true; bool slice_undoes_pad = true;
// Can the slice be moved to the pad_operand without any padding being // Can the slice be moved to the pad_operand without any padding being read.
// read.
bool slice_inside_pad = true; bool slice_inside_pad = true;
// Does this slice slice out pading only. // Does this slice slice out pading only.
bool slice_in_padding = false; bool slice_in_padding = false;
@ -4070,8 +4055,8 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
} }
} }
// Do not try to reorder slices and reshapes after layout assignment as it // Do not try to reorder slices and reshapes after layout assignment as it may
// may be invalid. // be invalid.
if (!options_.is_layout_sensitive()) { if (!options_.is_layout_sensitive()) {
TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice)); TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice));
} }
@ -4121,8 +4106,8 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
if (ShapeUtil::IsScalar(dynamic_slice->shape())) { if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
return ReplaceInstruction(dynamic_slice, operand); return ReplaceInstruction(dynamic_slice, operand);
} }
// DynamicSlice where operand has the same size as the output is simply // DynamicSlice where operand has the same size as the output is simply equal
// equal to operand. // to operand.
if (SameShape(operand, dynamic_slice)) { if (SameShape(operand, dynamic_slice)) {
return ReplaceInstruction(dynamic_slice, operand); return ReplaceInstruction(dynamic_slice, operand);
} }
@ -4453,8 +4438,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
// Convert Reduce(concat({a,b,...})) to // Convert Reduce(concat({a,b,...})) to
// map(reduce(a),map(reduce(b),...,)) // map(reduce(a),map(reduce(b),...,))
// //
// This should make fusion easier or use less memory bandwidth in the // This should make fusion easier or use less memory bandwidth in the unfused
// unfused case. // case.
if (arg->opcode() == HloOpcode::kConcatenate && if (arg->opcode() == HloOpcode::kConcatenate &&
absl::c_linear_search(reduce->dimensions(), absl::c_linear_search(reduce->dimensions(),
arg->concatenate_dimension())) { arg->concatenate_dimension())) {
@ -4473,9 +4458,9 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
} }
HloInstruction *dot, *lhs, *rhs; HloInstruction *dot, *lhs, *rhs;
// Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced // Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were
// were batch dimensions of the dot. The transformation supports reducing // batch dimensions of the dot. The transformation supports reducing other
// other dimensions as well. // dimensions as well.
if (options_.enable_dot_strength_reduction() && if (options_.enable_dot_strength_reduction() &&
Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) && Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
Match(reduce->to_apply()->root_instruction(), Match(reduce->to_apply()->root_instruction(),
@ -4547,13 +4532,13 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
if (options_.enable_window_reduce_to_reduce_replacement()) { if (options_.enable_window_reduce_to_reduce_replacement()) {
// A reduce window can be expressed as a reduce and a reshape if all // A reduce window can be expressed as a reduce and a reshape if all
// dimensions either have a window size of one or the entire dimension. If // dimensions either have a window size of one or the entire dimension. If
// there is no stride, dilation, or padding, this is as easy as checking // there is no stride, dilation, or padding, this is as easy as checking the
// the size of the output shape and window dimension. // size of the output shape and window dimension.
// //
// The reshape is a bitcast since it adds one-sized dimensions. Often // The reshape is a bitcast since it adds one-sized dimensions. Often these
// these ones are immediately removed as well with another reshape. The // ones are immediately removed as well with another reshape. The
// implementation of reduce tends to be slightly more efficient at // implementation of reduce tends to be slightly more efficient at reducing
// reducing entire dimensions compared to reduce window. // entire dimensions compared to reduce window.
auto effective_reduce_dims = [&] { auto effective_reduce_dims = [&] {
if (window_util::HasStride(window) || window_util::HasDilation(window) || if (window_util::HasStride(window) || window_util::HasDilation(window) ||
window_util::HasPadding(window)) { window_util::HasPadding(window)) {
@ -5068,8 +5053,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
auto new_dim = swapped_window.add_dimensions(); auto new_dim = swapped_window.add_dimensions();
new_dim->set_size(input_size); new_dim->set_size(input_size);
// If the kernel is not reversed, the activations must be manually // If the kernel is not reversed, the activations must be manually reversed.
// reversed.
if (!window_dims[spatial_dim].window_reversal()) { if (!window_dims[spatial_dim].window_reversal()) {
reverse_dimensions.push_back( reverse_dimensions.push_back(
dnums.kernel_spatial_dimensions(spatial_dim)); dnums.kernel_spatial_dimensions(spatial_dim));
@ -5089,8 +5073,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
dilated_kernel_size); dilated_kernel_size);
} }
// Don't transform if a naive convolution implementation would not have // Don't transform if a naive convolution implementation would not have fewer
// fewer flops. // flops.
if (kernel_product <= swapped_kernel_product) { if (kernel_product <= swapped_kernel_product) {
return false; return false;
} }
@ -5168,11 +5152,11 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
} }
} }
// Stride ignores part of the output, which matrix multiplication does not // Stride ignores part of the output, which matrix multiplication does not do,
// do, so require no stride. Padding and base (lhs) dilation both implicitly // so require no stride. Padding and base (lhs) dilation both implicitly
// extend the data, which matrix multiplication also does not do, so require // extend the data, which matrix multiplication also does not do, so require
// no padding and no base (lhs) dilation. Window (rhs) dilation has no // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect
// effect for a 1x1 window, so window dilation is no problem. // for a 1x1 window, so window dilation is no problem.
if (window_util::HasStride(window) || window_util::HasPadding(window) || if (window_util::HasStride(window) || window_util::HasPadding(window) ||
window_util::HasBaseDilation(window)) { window_util::HasBaseDilation(window)) {
return false; return false;
@ -5225,9 +5209,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
} }
} }
// We already checked feature_dimension is most minor, so data in // We already checked feature_dimension is most minor, so data in input_shape
// input_shape and row-major {conv_width,input_channels} are bitwise // and row-major {conv_width,input_channels} are bitwise identical.
// identical.
Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout( Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
input_shape.element_type(), {conv_width, input_channels}); input_shape.element_type(), {conv_width, input_channels});
simplifier_->UpdateLayout(&new_input_shape); simplifier_->UpdateLayout(&new_input_shape);

View File

@ -97,14 +97,6 @@ class AlgebraicSimplifierOptions {
return enable_scalar_multiply_reduction_; return enable_scalar_multiply_reduction_;
} }
// Also the algebraic simplifer to treat floating point values like real
// numbers.
void set_enable_floats_are_real(bool enable_floats_are_real) {
enable_floats_are_real_ = enable_floats_are_real;
}
bool enable_floats_are_real() const { return enable_floats_are_real_; }
// If enable_window_reduce_replacement is true, the kReduceWindow instruction // If enable_window_reduce_replacement is true, the kReduceWindow instruction
// can be optimized by replacement with simpler operations. // can be optimized by replacement with simpler operations.
void set_enable_window_reduce_to_reduce_replacement( void set_enable_window_reduce_to_reduce_replacement(
@ -166,7 +158,6 @@ class AlgebraicSimplifierOptions {
bool enable_conv_simplification_{true}; bool enable_conv_simplification_{true};
bool enable_conv_operand_swap_{true}; bool enable_conv_operand_swap_{true};
bool enable_scalar_multiply_reduction_{false}; bool enable_scalar_multiply_reduction_{false};
bool enable_floats_are_real_{false};
bool enable_window_reduce_to_reduce_replacement_{true}; bool enable_window_reduce_to_reduce_replacement_{true};
bool enable_reduce_of_reshape_{true}; bool enable_reduce_of_reshape_{true};
bool replace_transpose_with_bitcast_{true}; bool replace_transpose_with_bitcast_{true};

View File

@ -117,22 +117,6 @@ TEST_F(AlgebraicSimplifierTest, FactorFpAddition) {
m::ConstantScalar(0.125)))); m::ConstantScalar(0.125))));
} }
// (Abs(A)) * (Abs(A)) => (A*A)
TEST_F(AlgebraicSimplifierTest, SquareOfAbs) {
const char* kModuleStr = R"(
HloModule m
test {
p = f32[] parameter(0)
a = f32[] abs(p)
ROOT z = f32[] multiply(a, a)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
// (A*C1) * (B*C2) => (A*B)*(C1*C2) // (A*C1) * (B*C2) => (A*B)*(C1*C2)
TEST_F(AlgebraicSimplifierTest, MultiplyChain) { TEST_F(AlgebraicSimplifierTest, MultiplyChain) {
const char* kModuleStr = R"( const char* kModuleStr = R"(