From b2f5d100d1e1d9422fca8656e64c39fdc287e6b1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Aug 2020 13:02:31 -0700 Subject: [PATCH] [XLA] Convert Abs(a)*Abs(a) to a*a and add an option to allow for numerically unsafe algebraic simplifications PiperOrigin-RevId: 325084126 Change-Id: Id8bf89ba6601d7bb1efc2b167e6e9accf5913114 --- .../xla/service/algebraic_simplifier.cc | 117 ++++++++---------- .../xla/service/algebraic_simplifier.h | 9 -- .../xla/service/algebraic_simplifier_test.cc | 16 --- 3 files changed, 50 insertions(+), 92 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index d04a428d349..0b588048e4a 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -665,7 +665,7 @@ Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction( HloInstruction* inst; HloInstruction* user; int64 index; - std::tie(inst, user, index) = operands.back(); + std::tie (inst, user, index) = operands.back(); operands.pop_back(); // Skip the op types that are not commutative with multiply. @@ -913,7 +913,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) && Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) && (ShapeUtil::ElementIsIntegral(add->shape()) || - options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) { + IsAllFpConstantPowerOf2(c))) { return ReplaceWithNewInstruction( add, HloInstruction::CreateBinary( add->shape(), HloOpcode::kMultiply, @@ -2667,17 +2667,6 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { return Status::OK(); } - { - HloInstruction* abs_operand; - if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) && - !ShapeUtil::ElementIsComplex(abs_operand->shape())) { - TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand)); - TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand)); - changed_ = true; - return Status::OK(); - } - } - { HloInstruction *convert_operand, *operand; // Mul(Convert(Pred), operand) => select(pred, operand, 0) @@ -3048,8 +3037,8 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { HloInstruction* new_broadcast = computation_->AddInstruction( HloInstruction::CreateBroadcast(user->shape(), operand, {})); // Use HloInstruction::ReplaceAllUsesWith instead of - // HloComputation::ReplaceWithNewInstruction because we are replacing - // an instruction other than the visited instruction. + // HloComputation::ReplaceWithNewInstruction because we are replacing an + // instruction other than the visited instruction. changed_ = true; return user->ReplaceAllUsesWith(new_broadcast); } @@ -3166,11 +3155,9 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { // Eliminate a convert pair if it is a no-op. The following are a few // example cases that are being handled: - // 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of - // $TYPE2 + // 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of $TYPE2 // and convert(A, $TYPE1) is an upcast - // 2. convert(convert(A, $TYPE1),$TYPE2) is simplified to A if A is of - // $TYPE2 + // 2. convert(convert(A, $TYPE1),$TYPE2) is simplified to A if A is of $TYPE2 // and convert(A, $TYPE1) is an upcast and is an integral conversion from // unsigned to signed (only signed to unsigned conversion is NOT allowed) // 3. Tuple(convert(A, $TYPE1) , floor(convert(convert(A, $TYPE1), $TYPE2)), @@ -3306,8 +3293,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { pad->shape(), nonzero_pad->mutable_shape())); simplifier_->UpdateLayout(nonzero_pad->mutable_shape()); - // Second, construct the slice instruction to perform the negative - // padding. + // Second, construct the slice instruction to perform the negative padding. std::vector start_indices; std::vector end_indices; std::vector strides; @@ -3460,8 +3446,8 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( Shape changed_shape; for (HloInstruction* user_operand : user->operands()) { - // If this is a broadcast operand that is not our original broadcast - // input to this function then we might need to change the input. + // If this is a broadcast operand that is not our original broadcast input + // to this function then we might need to change the input. if (is_compatible_broadcast(user_operand)) { // If this is a broadcast from a scalar value rewrite a broadcast from // the scalar to the new shape enforced from the other broadcast @@ -3632,16 +3618,16 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) { // If M < N, then {0, ..., M} % N ==> {0, ..., M}. // // Currently this only covers the case when N is a broadcasted constant - // scalar. We could also cover the case when N is a non-broadcasted - // constant with the same value repeated. + // scalar. We could also cover the case when N is a non-broadcasted constant + // with the same value repeated. HloInstruction* iota; HloInstruction* divisor; if (Match(remainder, m::Remainder(m::Iota(&iota), m::Broadcast(m::ConstantEffectiveScalar(&divisor))))) { // The iota counts {0, ..., iota_upper_bound - 1}. (Actually this is - // conservative; the iota may overflow and count up to a smaller value - // than this. But that's OK for our purposes here.) + // conservative; the iota may overflow and count up to a smaller value than + // this. But that's OK for our purposes here.) int64 iota_upper_bound = iota->shape().dimensions( Cast(iota)->iota_dimension()); absl::optional divisor_val = divisor->literal().GetIntegralAsS64( @@ -3654,8 +3640,8 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) { // (X + N) % N = X % N, so long as X + N does not overflow. // // We don't have range tracking in XLA that would let us know whether X + N - // overflows, so for now we only do this simplification when X is an iota. - // We could add other operations where it's easy to see a range, such as + // overflows, so for now we only do this simplification when X is an iota. We + // could add other operations where it's easy to see a range, such as // remainder, convert, etc., though at some point we'd probably want a // range-tracking analysis. HloInstruction* bcast; @@ -3667,9 +3653,9 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) { m::Broadcast(m::ConstantEffectiveScalar(&addend))), m::Broadcast(&bcast, m::ConstantEffectiveScalar(&divisor)))) && addend == divisor) { - // The iota counts {0, ...iota_upper_bound - 1}, with the same caveat - // above that iota_upper_bound is conservative, and the true upper bound - // may be smaller. + // The iota counts {0, ...iota_upper_bound - 1}, with the same caveat above + // that iota_upper_bound is conservative, and the true upper bound may be + // smaller. int64 iota_upper_bound = iota->shape().dimensions( Cast(iota)->iota_dimension()); absl::optional divisor_val = divisor->literal().GetIntegralAsS64( @@ -3774,9 +3760,9 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) { StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( HloInstruction* slice) { - // Only try to do this for effective scalars. We could do the same for - // slicing out larger pieces of padding (replacing with a broadcast of the - // padding value), but this is probably not worth it. + // Only try to do this for effective scalars. We could do the same for slicing + // out larger pieces of padding (replacing with a broadcast of the padding + // value), but this is probably not worth it. if (!ShapeUtil::IsEffectiveScalar(slice->shape())) { return false; } @@ -3877,8 +3863,8 @@ StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( return false; } -// Allowing a slice to move through a reverse with any necessary updates to -// the slice config. +// Allowing a slice to move through a reverse with any necessary updates to the +// slice config. StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse( HloInstruction* slice) { VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:" @@ -3906,8 +3892,8 @@ StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse( << new_limits[rdim]; } // New slice formed from the reverse_operand, but strides and shape of the - // slice output remains the same. New slice's starts and limits are - // updated for ONLY the reversed dimensions as indicated above. + // slice output remains the same. New slice's starts and limits are updated + // for ONLY the reversed dimensions as indicated above. HloInstruction* new_slice = computation_->AddInstruction( HloInstruction::CreateSlice(slice->shape(), reverse_operand, new_starts, new_limits, new_strides)); @@ -3934,8 +3920,7 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) { // Is the result of the slice the pad operand. bool slice_undoes_pad = true; - // Can the slice be moved to the pad_operand without any padding being - // read. + // Can the slice be moved to the pad_operand without any padding being read. bool slice_inside_pad = true; // Does this slice slice out pading only. bool slice_in_padding = false; @@ -4070,8 +4055,8 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { } } - // Do not try to reorder slices and reshapes after layout assignment as it - // may be invalid. + // Do not try to reorder slices and reshapes after layout assignment as it may + // be invalid. if (!options_.is_layout_sensitive()) { TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice)); } @@ -4121,8 +4106,8 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice( if (ShapeUtil::IsScalar(dynamic_slice->shape())) { return ReplaceInstruction(dynamic_slice, operand); } - // DynamicSlice where operand has the same size as the output is simply - // equal to operand. + // DynamicSlice where operand has the same size as the output is simply equal + // to operand. if (SameShape(operand, dynamic_slice)) { return ReplaceInstruction(dynamic_slice, operand); } @@ -4453,8 +4438,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { // Convert Reduce(concat({a,b,...})) to // map(reduce(a),map(reduce(b),...,)) // - // This should make fusion easier or use less memory bandwidth in the - // unfused case. + // This should make fusion easier or use less memory bandwidth in the unfused + // case. if (arg->opcode() == HloOpcode::kConcatenate && absl::c_linear_search(reduce->dimensions(), arg->concatenate_dimension())) { @@ -4473,9 +4458,9 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { } HloInstruction *dot, *lhs, *rhs; - // Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced - // were batch dimensions of the dot. The transformation supports reducing - // other dimensions as well. + // Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were + // batch dimensions of the dot. The transformation supports reducing other + // dimensions as well. if (options_.enable_dot_strength_reduction() && Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) && Match(reduce->to_apply()->root_instruction(), @@ -4547,13 +4532,13 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( if (options_.enable_window_reduce_to_reduce_replacement()) { // A reduce window can be expressed as a reduce and a reshape if all // dimensions either have a window size of one or the entire dimension. If - // there is no stride, dilation, or padding, this is as easy as checking - // the size of the output shape and window dimension. + // there is no stride, dilation, or padding, this is as easy as checking the + // size of the output shape and window dimension. // - // The reshape is a bitcast since it adds one-sized dimensions. Often - // these ones are immediately removed as well with another reshape. The - // implementation of reduce tends to be slightly more efficient at - // reducing entire dimensions compared to reduce window. + // The reshape is a bitcast since it adds one-sized dimensions. Often these + // ones are immediately removed as well with another reshape. The + // implementation of reduce tends to be slightly more efficient at reducing + // entire dimensions compared to reduce window. auto effective_reduce_dims = [&] { if (window_util::HasStride(window) || window_util::HasDilation(window) || window_util::HasPadding(window)) { @@ -5068,8 +5053,7 @@ StatusOr AlgebraicSimplifierVisitor::SwapConvOperands( auto new_dim = swapped_window.add_dimensions(); new_dim->set_size(input_size); - // If the kernel is not reversed, the activations must be manually - // reversed. + // If the kernel is not reversed, the activations must be manually reversed. if (!window_dims[spatial_dim].window_reversal()) { reverse_dimensions.push_back( dnums.kernel_spatial_dimensions(spatial_dim)); @@ -5089,8 +5073,8 @@ StatusOr AlgebraicSimplifierVisitor::SwapConvOperands( dilated_kernel_size); } - // Don't transform if a naive convolution implementation would not have - // fewer flops. + // Don't transform if a naive convolution implementation would not have fewer + // flops. if (kernel_product <= swapped_kernel_product) { return false; } @@ -5168,11 +5152,11 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( } } - // Stride ignores part of the output, which matrix multiplication does not - // do, so require no stride. Padding and base (lhs) dilation both implicitly + // Stride ignores part of the output, which matrix multiplication does not do, + // so require no stride. Padding and base (lhs) dilation both implicitly // extend the data, which matrix multiplication also does not do, so require - // no padding and no base (lhs) dilation. Window (rhs) dilation has no - // effect for a 1x1 window, so window dilation is no problem. + // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect + // for a 1x1 window, so window dilation is no problem. if (window_util::HasStride(window) || window_util::HasPadding(window) || window_util::HasBaseDilation(window)) { return false; @@ -5225,9 +5209,8 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( } } - // We already checked feature_dimension is most minor, so data in - // input_shape and row-major {conv_width,input_channels} are bitwise - // identical. + // We already checked feature_dimension is most minor, so data in input_shape + // and row-major {conv_width,input_channels} are bitwise identical. Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout( input_shape.element_type(), {conv_width, input_channels}); simplifier_->UpdateLayout(&new_input_shape); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index cabecec4eb8..9f2a3404116 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -97,14 +97,6 @@ class AlgebraicSimplifierOptions { return enable_scalar_multiply_reduction_; } - // Also the algebraic simplifer to treat floating point values like real - // numbers. - void set_enable_floats_are_real(bool enable_floats_are_real) { - enable_floats_are_real_ = enable_floats_are_real; - } - - bool enable_floats_are_real() const { return enable_floats_are_real_; } - // If enable_window_reduce_replacement is true, the kReduceWindow instruction // can be optimized by replacement with simpler operations. void set_enable_window_reduce_to_reduce_replacement( @@ -166,7 +158,6 @@ class AlgebraicSimplifierOptions { bool enable_conv_simplification_{true}; bool enable_conv_operand_swap_{true}; bool enable_scalar_multiply_reduction_{false}; - bool enable_floats_are_real_{false}; bool enable_window_reduce_to_reduce_replacement_{true}; bool enable_reduce_of_reshape_{true}; bool replace_transpose_with_bitcast_{true}; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index fdd9fb04941..90ca44714f7 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -117,22 +117,6 @@ TEST_F(AlgebraicSimplifierTest, FactorFpAddition) { m::ConstantScalar(0.125)))); } -// (Abs(A)) * (Abs(A)) => (A*A) -TEST_F(AlgebraicSimplifierTest, SquareOfAbs) { - const char* kModuleStr = R"( - HloModule m - test { - p = f32[] parameter(0) - a = f32[] abs(p) - ROOT z = f32[] multiply(a, a) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); - ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); - EXPECT_THAT(m->entry_computation()->root_instruction(), - GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0)))); -} - // (A*C1) * (B*C2) => (A*B)*(C1*C2) TEST_F(AlgebraicSimplifierTest, MultiplyChain) { const char* kModuleStr = R"(