[XLA] Add a variety of reassociation transformations

New transforms:
(A * C0) * C1 => A * (C0 * C1)

Extended to support broadcasts:
A - C => A + -C
(A + C0) + C1 => A + (C0 + C1)

PiperOrigin-RevId: 244284014
This commit is contained in:
David Majnemer 2019-04-18 16:43:37 -07:00 committed by TensorFlower Gardener
parent 1062b2b253
commit 00d9ab673c
2 changed files with 122 additions and 2 deletions

View File

@ -531,9 +531,17 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]";
HloInstruction *a, *c1, *c2;
if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)),
m::Constant(&c2)))) {
m::Constant(&c2))) ||
Match(add, m::Add(m::Add(m::NonConstant(&a),
m::Broadcast(m::ConstantScalar(&c1))),
m::Broadcast(m::ConstantScalar(&c2))))) {
TF_ASSIGN_OR_RETURN(auto* sum_of_constants,
MakeBinaryHlo(HloOpcode::kAdd, c1, c2));
if (ShapeUtil::IsScalar(sum_of_constants->shape()) &&
!ShapeUtil::IsScalar(add->shape())) {
sum_of_constants = computation_->AddInstruction(
HloInstruction::CreateBroadcast(add->shape(), sum_of_constants, {}));
}
return ReplaceWithNewInstruction(
add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a,
sum_of_constants));
@ -861,9 +869,17 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
// Canonicalize subtraction of a constant to addition.
VLOG(10) << "trying transform [A - Const => A + (-Const)]";
if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs)))) {
if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs))) ||
Match(sub, m::Subtract(m::NonConstant(&lhs),
m::Broadcast(m::Constant(&rhs))))) {
HloInstruction* negative_const = computation_->AddInstruction(
HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs));
if (const HloInstruction* broadcast =
DynCast<HloBroadcastInstruction>(sub->operand(1))) {
negative_const =
computation_->AddInstruction(HloInstruction::CreateBroadcast(
broadcast->shape(), negative_const, broadcast->dimensions()));
}
return ReplaceWithNewInstruction(
sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs,
negative_const));
@ -1883,6 +1899,29 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
return Status::OK();
}
VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
HloInstruction *a, *c1, *c2;
if (Match(multiply,
m::Multiply(m::Multiply(m::NonConstant(&a), m::Constant(&c1)),
m::Constant(&c2))) ||
Match(multiply,
m::Multiply(m::Multiply(m::NonConstant(&a),
m::Broadcast(m::ConstantScalar(&c1))),
m::Broadcast(m::ConstantScalar(&c2))))) {
TF_ASSIGN_OR_RETURN(auto* product_of_constants,
MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
!ShapeUtil::IsScalar(multiply->shape())) {
product_of_constants =
computation_->AddInstruction(HloInstruction::CreateBroadcast(
multiply->shape(), product_of_constants, {}));
}
return ReplaceWithNewInstruction(
multiply,
HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply, a,
product_of_constants));
}
// exp(A) * exp(B) => exp(A+B)
if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) {
auto add = computation_->AddInstruction(HloInstruction::CreateBinary(

View File

@ -295,6 +295,47 @@ TEST_F(AlgebraicSimplifierTest, MulZero) {
EXPECT_EQ(computation->root_instruction(), zero);
}
TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeConstants) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[] parameter(0)
c0 = f32[] constant(2.0)
c1 = f32[] constant(3.0)
multiply0 = f32[] multiply(p0, c0)
ROOT multiply1 = f32[] multiply(multiply0, c1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Multiply(m::Parameter(0),
m::Multiply(m::ConstantScalar(2.0),
m::ConstantScalar(3.0)))));
}
TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeBroadcastedConstants) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[4] parameter(0)
c0 = f32[] constant(2.0)
c1 = f32[] constant(3.0)
b0 = f32[4] broadcast(c0), dimensions={}
b1 = f32[4] broadcast(c1), dimensions={}
multiply0 = f32[4] multiply(p0, b0)
ROOT multiply1 = f32[4] multiply(multiply0, b1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(
m->entry_computation()->root_instruction(),
GmockMatch(m::Multiply(
m::Parameter(0), m::Broadcast(m::Multiply(m::ConstantScalar(2.0),
m::ConstantScalar(3.0))))));
}
// Test that select(true, a, b) is simplified to a
TEST_F(AlgebraicSimplifierTest, SelectTrue) {
Shape r0s32 = ShapeUtil::MakeShape(S32, {});
@ -446,6 +487,27 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) {
m::Add(m::Op().Is(constant1), m::Op().Is(constant2)))));
}
TEST_F(AlgebraicSimplifierTest, AddReassociateMergeBroadcastedConstants) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[4] parameter(0)
c0 = f32[] constant(1.0)
c1 = f32[] constant(2.0)
b0 = f32[4] broadcast(c0), dimensions={}
b1 = f32[4] broadcast(c1), dimensions={}
add0 = f32[4] add(p0, b0)
ROOT add1 = f32[4] add(add0, b1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Add(m::Parameter(0),
m::Broadcast(m::Add(m::ConstantScalar(1.0),
m::ConstantScalar(2.0))))));
}
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
auto m = CreateNewVerifiedModule();
Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
@ -640,6 +702,25 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) {
m::Negate(m::Op().Is(constant)))));
}
// Test that A - Broadcast(Const) is canonicalized to A + Broadcast(-Const).
TEST_F(AlgebraicSimplifierTest, SubBroadcastConstCanonicalization) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[4] parameter(0)
c = f32[] constant(0.125)
b = f32[4] broadcast(c), dimensions={}
ROOT sub = f32[4] subtract(p0, b)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(
m->entry_computation()->root_instruction(),
GmockMatch(m::Add(m::Parameter(0),
m::Broadcast(m::Negate(m::ConstantScalar(0.125))))));
}
// Test that (A/B)/C is simplified to A/(B*C).
TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) {
auto m = CreateNewVerifiedModule();