diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 6bbde42bad9..4e7bd85e557 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -913,7 +913,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) && Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) && (ShapeUtil::ElementIsIntegral(add->shape()) || - IsAllFpConstantPowerOf2(c))) { + options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) { return ReplaceWithNewInstruction( add, HloInstruction::CreateBinary( add->shape(), HloOpcode::kMultiply, @@ -2710,6 +2710,17 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { return Status::OK(); } + { + HloInstruction* abs_operand; + if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) && + !ShapeUtil::ElementIsComplex(abs_operand->shape())) { + TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand)); + TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand)); + changed_ = true; + return Status::OK(); + } + } + { HloInstruction *convert_operand, *operand; // Mul(Convert(Pred), operand) => select(pred, operand, 0) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 9f2a3404116..cabecec4eb8 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -97,6 +97,14 @@ class AlgebraicSimplifierOptions { return enable_scalar_multiply_reduction_; } + // Also the algebraic simplifer to treat floating point values like real + // numbers. + void set_enable_floats_are_real(bool enable_floats_are_real) { + enable_floats_are_real_ = enable_floats_are_real; + } + + bool enable_floats_are_real() const { return enable_floats_are_real_; } + // If enable_window_reduce_replacement is true, the kReduceWindow instruction // can be optimized by replacement with simpler operations. void set_enable_window_reduce_to_reduce_replacement( @@ -158,6 +166,7 @@ class AlgebraicSimplifierOptions { bool enable_conv_simplification_{true}; bool enable_conv_operand_swap_{true}; bool enable_scalar_multiply_reduction_{false}; + bool enable_floats_are_real_{false}; bool enable_window_reduce_to_reduce_replacement_{true}; bool enable_reduce_of_reshape_{true}; bool replace_transpose_with_bitcast_{true}; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index c08fbd13d9d..c4f3ea4087b 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -117,6 +117,22 @@ TEST_F(AlgebraicSimplifierTest, FactorFpAddition) { m::ConstantScalar(0.125)))); } +// (Abs(A)) * (Abs(A)) => (A*A) +TEST_F(AlgebraicSimplifierTest, SquareOfAbs) { + const char* kModuleStr = R"( + HloModule m + test { + p = f32[] parameter(0) + a = f32[] abs(p) + ROOT z = f32[] multiply(a, a) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0)))); +} + // (A*C1) * (B*C2) => (A*B)*(C1*C2) TEST_F(AlgebraicSimplifierTest, MultiplyChain) { const char* kModuleStr = R"(