Merge pull request #41631 from nouiz:upstream_master_tf2_pow3
PiperOrigin-RevId: 322951984 Change-Id: Ia0080a732101f46c43252e80c7864fdb535267e1
This commit is contained in:
commit
55c1276013
@ -3115,6 +3115,17 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
|
||||
HloOpcode::kMultiply, lhs, lhs));
|
||||
}
|
||||
|
||||
// Pow(A, 3) is used in GELU.
|
||||
VLOG(10) << "trying transform [pow(A, 3) => A*A*A]: " << power->ToString();
|
||||
if (IsAll(rhs, 3)) {
|
||||
HloInstruction* tmp =
|
||||
computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||
power->shape(), HloOpcode::kMultiply, lhs, lhs));
|
||||
return ReplaceWithNewInstruction(
|
||||
power, HloInstruction::CreateBinary(power->shape(),
|
||||
HloOpcode::kMultiply, lhs, tmp));
|
||||
}
|
||||
|
||||
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
|
||||
if (IsAll(rhs, -1)) {
|
||||
return ReplaceWithNewInstruction(
|
||||
|
@ -1591,6 +1591,32 @@ TEST_F(AlgebraicSimplifierTest, Pow2) {
|
||||
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
|
||||
}
|
||||
|
||||
// Test that pow(A, 3) is simplified to A*A*A.
|
||||
TEST_F(AlgebraicSimplifierTest, Pow3) {
|
||||
auto m = CreateNewVerifiedModule();
|
||||
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
||||
HloComputation::Builder builder(TestName());
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r0f32, "param0"));
|
||||
HloInstruction* three = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3)));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, three));
|
||||
|
||||
auto computation = m->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
GmockMatch(m::Power(m::Parameter(0), m::Op().Is(three))));
|
||||
|
||||
AlgebraicSimplifier simplifier(default_options_);
|
||||
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||
|
||||
EXPECT_THAT(
|
||||
computation->root_instruction(),
|
||||
GmockMatch(m::Multiply(m::Parameter(0),
|
||||
m::Multiply(m::Parameter(0), m::Parameter(0)))));
|
||||
}
|
||||
|
||||
// Test that pow(A, -1) is simplified to 1/A.
|
||||
TEST_F(AlgebraicSimplifierTest, PowNegative1) {
|
||||
auto m = CreateNewVerifiedModule();
|
||||
|
Loading…
Reference in New Issue
Block a user