[XLA] Add a variety of reassociation transformations
New transforms: (A * C0) * C1 => A * (C0 * C1) Extended to support broadcasts: A - C => A + -C (A + C0) + C1 => A + (C0 + C1) PiperOrigin-RevId: 244284014
This commit is contained in:
parent
1062b2b253
commit
00d9ab673c
@ -531,9 +531,17 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
|
||||
VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]";
|
||||
HloInstruction *a, *c1, *c2;
|
||||
if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)),
|
||||
m::Constant(&c2)))) {
|
||||
m::Constant(&c2))) ||
|
||||
Match(add, m::Add(m::Add(m::NonConstant(&a),
|
||||
m::Broadcast(m::ConstantScalar(&c1))),
|
||||
m::Broadcast(m::ConstantScalar(&c2))))) {
|
||||
TF_ASSIGN_OR_RETURN(auto* sum_of_constants,
|
||||
MakeBinaryHlo(HloOpcode::kAdd, c1, c2));
|
||||
if (ShapeUtil::IsScalar(sum_of_constants->shape()) &&
|
||||
!ShapeUtil::IsScalar(add->shape())) {
|
||||
sum_of_constants = computation_->AddInstruction(
|
||||
HloInstruction::CreateBroadcast(add->shape(), sum_of_constants, {}));
|
||||
}
|
||||
return ReplaceWithNewInstruction(
|
||||
add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a,
|
||||
sum_of_constants));
|
||||
@ -861,9 +869,17 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
|
||||
|
||||
// Canonicalize subtraction of a constant to addition.
|
||||
VLOG(10) << "trying transform [A - Const => A + (-Const)]";
|
||||
if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs)))) {
|
||||
if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs))) ||
|
||||
Match(sub, m::Subtract(m::NonConstant(&lhs),
|
||||
m::Broadcast(m::Constant(&rhs))))) {
|
||||
HloInstruction* negative_const = computation_->AddInstruction(
|
||||
HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs));
|
||||
if (const HloInstruction* broadcast =
|
||||
DynCast<HloBroadcastInstruction>(sub->operand(1))) {
|
||||
negative_const =
|
||||
computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
broadcast->shape(), negative_const, broadcast->dimensions()));
|
||||
}
|
||||
return ReplaceWithNewInstruction(
|
||||
sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs,
|
||||
negative_const));
|
||||
@ -1883,6 +1899,29 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
|
||||
HloInstruction *a, *c1, *c2;
|
||||
if (Match(multiply,
|
||||
m::Multiply(m::Multiply(m::NonConstant(&a), m::Constant(&c1)),
|
||||
m::Constant(&c2))) ||
|
||||
Match(multiply,
|
||||
m::Multiply(m::Multiply(m::NonConstant(&a),
|
||||
m::Broadcast(m::ConstantScalar(&c1))),
|
||||
m::Broadcast(m::ConstantScalar(&c2))))) {
|
||||
TF_ASSIGN_OR_RETURN(auto* product_of_constants,
|
||||
MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
|
||||
if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
|
||||
!ShapeUtil::IsScalar(multiply->shape())) {
|
||||
product_of_constants =
|
||||
computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
multiply->shape(), product_of_constants, {}));
|
||||
}
|
||||
return ReplaceWithNewInstruction(
|
||||
multiply,
|
||||
HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply, a,
|
||||
product_of_constants));
|
||||
}
|
||||
|
||||
// exp(A) * exp(B) => exp(A+B)
|
||||
if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) {
|
||||
auto add = computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||
|
@ -295,6 +295,47 @@ TEST_F(AlgebraicSimplifierTest, MulZero) {
|
||||
EXPECT_EQ(computation->root_instruction(), zero);
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeConstants) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = f32[] parameter(0)
|
||||
c0 = f32[] constant(2.0)
|
||||
c1 = f32[] constant(3.0)
|
||||
multiply0 = f32[] multiply(p0, c0)
|
||||
ROOT multiply1 = f32[] multiply(multiply0, c1)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Multiply(m::Parameter(0),
|
||||
m::Multiply(m::ConstantScalar(2.0),
|
||||
m::ConstantScalar(3.0)))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeBroadcastedConstants) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = f32[4] parameter(0)
|
||||
c0 = f32[] constant(2.0)
|
||||
c1 = f32[] constant(3.0)
|
||||
b0 = f32[4] broadcast(c0), dimensions={}
|
||||
b1 = f32[4] broadcast(c1), dimensions={}
|
||||
multiply0 = f32[4] multiply(p0, b0)
|
||||
ROOT multiply1 = f32[4] multiply(multiply0, b1)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(
|
||||
m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Multiply(
|
||||
m::Parameter(0), m::Broadcast(m::Multiply(m::ConstantScalar(2.0),
|
||||
m::ConstantScalar(3.0))))));
|
||||
}
|
||||
|
||||
// Test that select(true, a, b) is simplified to a
|
||||
TEST_F(AlgebraicSimplifierTest, SelectTrue) {
|
||||
Shape r0s32 = ShapeUtil::MakeShape(S32, {});
|
||||
@ -446,6 +487,27 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) {
|
||||
m::Add(m::Op().Is(constant1), m::Op().Is(constant2)))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, AddReassociateMergeBroadcastedConstants) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = f32[4] parameter(0)
|
||||
c0 = f32[] constant(1.0)
|
||||
c1 = f32[] constant(2.0)
|
||||
b0 = f32[4] broadcast(c0), dimensions={}
|
||||
b1 = f32[4] broadcast(c1), dimensions={}
|
||||
add0 = f32[4] add(p0, b0)
|
||||
ROOT add1 = f32[4] add(add0, b1)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Add(m::Parameter(0),
|
||||
m::Broadcast(m::Add(m::ConstantScalar(1.0),
|
||||
m::ConstantScalar(2.0))))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
|
||||
auto m = CreateNewVerifiedModule();
|
||||
Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
|
||||
@ -640,6 +702,25 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) {
|
||||
m::Negate(m::Op().Is(constant)))));
|
||||
}
|
||||
|
||||
// Test that A - Broadcast(Const) is canonicalized to A + Broadcast(-Const).
|
||||
TEST_F(AlgebraicSimplifierTest, SubBroadcastConstCanonicalization) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = f32[4] parameter(0)
|
||||
c = f32[] constant(0.125)
|
||||
b = f32[4] broadcast(c), dimensions={}
|
||||
ROOT sub = f32[4] subtract(p0, b)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(
|
||||
m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Add(m::Parameter(0),
|
||||
m::Broadcast(m::Negate(m::ConstantScalar(0.125))))));
|
||||
}
|
||||
|
||||
// Test that (A/B)/C is simplified to A/(B*C).
|
||||
TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) {
|
||||
auto m = CreateNewVerifiedModule();
|
||||
|
Loading…
Reference in New Issue
Block a user