Merge pull request #41631 from nouiz:upstream_master_tf2_pow3
PiperOrigin-RevId: 322951984 Change-Id: Ia0080a732101f46c43252e80c7864fdb535267e1
This commit is contained in:
commit
55c1276013
@ -3115,6 +3115,17 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
|
|||||||
HloOpcode::kMultiply, lhs, lhs));
|
HloOpcode::kMultiply, lhs, lhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pow(A, 3) is used in GELU.
|
||||||
|
VLOG(10) << "trying transform [pow(A, 3) => A*A*A]: " << power->ToString();
|
||||||
|
if (IsAll(rhs, 3)) {
|
||||||
|
HloInstruction* tmp =
|
||||||
|
computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||||
|
power->shape(), HloOpcode::kMultiply, lhs, lhs));
|
||||||
|
return ReplaceWithNewInstruction(
|
||||||
|
power, HloInstruction::CreateBinary(power->shape(),
|
||||||
|
HloOpcode::kMultiply, lhs, tmp));
|
||||||
|
}
|
||||||
|
|
||||||
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
|
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
|
||||||
if (IsAll(rhs, -1)) {
|
if (IsAll(rhs, -1)) {
|
||||||
return ReplaceWithNewInstruction(
|
return ReplaceWithNewInstruction(
|
||||||
|
@ -1591,6 +1591,32 @@ TEST_F(AlgebraicSimplifierTest, Pow2) {
|
|||||||
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
|
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that pow(A, 3) is simplified to A*A*A.
|
||||||
|
TEST_F(AlgebraicSimplifierTest, Pow3) {
|
||||||
|
auto m = CreateNewVerifiedModule();
|
||||||
|
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
||||||
|
HloComputation::Builder builder(TestName());
|
||||||
|
HloInstruction* param0 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, r0f32, "param0"));
|
||||||
|
HloInstruction* three = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3)));
|
||||||
|
builder.AddInstruction(
|
||||||
|
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, three));
|
||||||
|
|
||||||
|
auto computation = m->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
|
EXPECT_THAT(computation->root_instruction(),
|
||||||
|
GmockMatch(m::Power(m::Parameter(0), m::Op().Is(three))));
|
||||||
|
|
||||||
|
AlgebraicSimplifier simplifier(default_options_);
|
||||||
|
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||||
|
|
||||||
|
EXPECT_THAT(
|
||||||
|
computation->root_instruction(),
|
||||||
|
GmockMatch(m::Multiply(m::Parameter(0),
|
||||||
|
m::Multiply(m::Parameter(0), m::Parameter(0)))));
|
||||||
|
}
|
||||||
|
|
||||||
// Test that pow(A, -1) is simplified to 1/A.
|
// Test that pow(A, -1) is simplified to 1/A.
|
||||||
TEST_F(AlgebraicSimplifierTest, PowNegative1) {
|
TEST_F(AlgebraicSimplifierTest, PowNegative1) {
|
||||||
auto m = CreateNewVerifiedModule();
|
auto m = CreateNewVerifiedModule();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user