Merge pull request #41631 from nouiz:upstream_master_tf2_pow3

PiperOrigin-RevId: 322951984
Change-Id: Ia0080a732101f46c43252e80c7864fdb535267e1
This commit is contained in:
TensorFlower Gardener 2020-07-24 01:04:50 -07:00
commit 55c1276013
2 changed files with 37 additions and 0 deletions

View File

@ -3115,6 +3115,17 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
HloOpcode::kMultiply, lhs, lhs));
}
// Pow(A, 3) is used in GELU.
VLOG(10) << "trying transform [pow(A, 3) => A*A*A]: " << power->ToString();
if (IsAll(rhs, 3)) {
HloInstruction* tmp =
computation_->AddInstruction(HloInstruction::CreateBinary(
power->shape(), HloOpcode::kMultiply, lhs, lhs));
return ReplaceWithNewInstruction(
power, HloInstruction::CreateBinary(power->shape(),
HloOpcode::kMultiply, lhs, tmp));
}
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
if (IsAll(rhs, -1)) {
return ReplaceWithNewInstruction(

View File

@ -1591,6 +1591,32 @@ TEST_F(AlgebraicSimplifierTest, Pow2) {
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
// Test that pow(A, 3) is simplified to A*A*A.
TEST_F(AlgebraicSimplifierTest, Pow3) {
auto m = CreateNewVerifiedModule();
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
HloComputation::Builder builder(TestName());
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* three = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, three));
auto computation = m->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
GmockMatch(m::Power(m::Parameter(0), m::Op().Is(three))));
AlgebraicSimplifier simplifier(default_options_);
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
EXPECT_THAT(
computation->root_instruction(),
GmockMatch(m::Multiply(m::Parameter(0),
m::Multiply(m::Parameter(0), m::Parameter(0)))));
}
// Test that pow(A, -1) is simplified to 1/A.
TEST_F(AlgebraicSimplifierTest, PowNegative1) {
auto m = CreateNewVerifiedModule();