[XLA] Add support for sinking broadcasts through ops with multiple broadcast operands.
Add support for sinking this kind of pattern:

  p0 = f32[4] parameter(0)
  p1 = f32[4] parameter(1)
  b0 = f32[4,2] broadcast(p0), dimensions={0}
  b1 = f32[4,2] broadcast(p1), dimensions={0}
  ROOT multiply = f32[4,2] multiply(b1, b0)

into:

  p0 = f32[4] parameter(0)
  p1 = f32[4] parameter(1)
  multiply = f32[4] multiply(p1, p0)
  ROOT out = f32[4,2] broadcast(multiply)

PiperOrigin-RevId: 313231737
Change-Id: Ic508b3cf30daaf1a2aa9246886ef63ad49be6a01
This commit is contained in:
parent
444ea7fa7f
commit
48296300d6
@ -3058,6 +3058,20 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
HloInstruction* operand = broadcast->mutable_operand(0);
|
HloInstruction* operand = broadcast->mutable_operand(0);
|
||||||
|
auto is_scalar_broadcast = [](const HloInstruction* instruction) {
|
||||||
|
return instruction->opcode() == HloOpcode::kBroadcast &&
|
||||||
|
ShapeUtil::IsScalar(instruction->operand(0)->shape());
|
||||||
|
};
|
||||||
|
auto is_equal_broadcast = [operand,
|
||||||
|
broadcast](const HloInstruction* instruction) {
|
||||||
|
return instruction->opcode() == HloOpcode::kBroadcast &&
|
||||||
|
ShapeUtil::Equal(operand->shape(),
|
||||||
|
instruction->operand(0)->shape()) &&
|
||||||
|
broadcast->dimensions() == instruction->dimensions();
|
||||||
|
};
|
||||||
|
auto is_compatible_broadcast = [&](const HloInstruction* instruction) {
|
||||||
|
return is_scalar_broadcast(instruction) || is_equal_broadcast(instruction);
|
||||||
|
};
|
||||||
for (HloInstruction* user : broadcast->users()) {
|
for (HloInstruction* user : broadcast->users()) {
|
||||||
if (user->user_count() == 0 && user != computation_->root_instruction()) {
|
if (user->user_count() == 0 && user != computation_->root_instruction()) {
|
||||||
continue;
|
continue;
|
||||||
@ -3076,18 +3090,20 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find the unique non-scalar operand or continue if there isn't one.
|
// Check if all the operands of the user are compatible broadcasts for
|
||||||
int64 scalar_broadcast_count = 0;
|
// sinking. (They are either scalar broadcasts or broadcasts casting
|
||||||
|
// from/to the same shape/dimensions)
|
||||||
|
int64 compatible_broadcast_count = 0;
|
||||||
int64 broadcast_use_count = 0;
|
int64 broadcast_use_count = 0;
|
||||||
for (HloInstruction* user_operand : user->operands()) {
|
for (HloInstruction* user_operand : user->operands()) {
|
||||||
if (user_operand->opcode() == HloOpcode::kBroadcast &&
|
if (is_compatible_broadcast(user_operand)) {
|
||||||
ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
|
++compatible_broadcast_count;
|
||||||
++scalar_broadcast_count;
|
|
||||||
} else if (broadcast == user_operand) {
|
} else if (broadcast == user_operand) {
|
||||||
++broadcast_use_count;
|
++broadcast_use_count;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) {
|
if (compatible_broadcast_count + broadcast_use_count !=
|
||||||
|
user->operand_count()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
std::vector<HloInstruction*> new_operands;
|
std::vector<HloInstruction*> new_operands;
|
||||||
@ -3095,14 +3111,24 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
|
|||||||
|
|
||||||
Shape changed_shape;
|
Shape changed_shape;
|
||||||
for (HloInstruction* user_operand : user->operands()) {
|
for (HloInstruction* user_operand : user->operands()) {
|
||||||
if (user_operand->opcode() == HloOpcode::kBroadcast &&
|
// If this is a broadcast operand that is not our original broadcast input
|
||||||
ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
|
// to this function then we might need to change the input.
|
||||||
changed_shape = ShapeUtil::ChangeElementType(
|
if (is_compatible_broadcast(user_operand)) {
|
||||||
operand->shape(), user_operand->shape().element_type());
|
// If this is a broadcast from a scalar value rewrite a broadcast from
|
||||||
simplifier_->UpdateLayout(&changed_shape);
|
// the scalar to the new shape enforced from the other broadcast
|
||||||
new_operands.push_back(
|
// operands.
|
||||||
computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
if (is_scalar_broadcast(user_operand)) {
|
||||||
changed_shape, user_operand->mutable_operand(0), {})));
|
changed_shape = ShapeUtil::ChangeElementType(
|
||||||
|
operand->shape(), user_operand->shape().element_type());
|
||||||
|
simplifier_->UpdateLayout(&changed_shape);
|
||||||
|
new_operands.push_back(
|
||||||
|
computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||||
|
changed_shape, user_operand->mutable_operand(0), {})));
|
||||||
|
} else {
|
||||||
|
// For the non-scalar broadcasts we guarantee that the shape of the
|
||||||
|
// operand of the broadcast needs to be already a compatible shape.
|
||||||
|
new_operands.push_back(user_operand->mutable_operand(0));
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
CHECK_EQ(broadcast, user_operand);
|
CHECK_EQ(broadcast, user_operand);
|
||||||
new_operands.push_back(operand);
|
new_operands.push_back(operand);
|
||||||
|
@ -338,6 +338,79 @@ TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeBroadcastedConstants) {
|
|||||||
m::ConstantScalar(3.0))))));
|
m::ConstantScalar(3.0))))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsScalar) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
p0 = f32[] parameter(0)
|
||||||
|
p1 = f32[] parameter(1)
|
||||||
|
b0 = f32[4] broadcast(p0), dimensions={}
|
||||||
|
b1 = f32[4] broadcast(p1), dimensions={}
|
||||||
|
ROOT multiply = f32[4] multiply(b1, b0)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(
|
||||||
|
m->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::Broadcast(m::Multiply(m::Broadcast(m::Parameter(1)),
|
||||||
|
m::Broadcast(m::Parameter(0))))));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsConstantMix) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
p0 = f32[4] parameter(0)
|
||||||
|
c0 = f32[] constant(2.0)
|
||||||
|
b0 = f32[4,2] broadcast(c0), dimensions={}
|
||||||
|
b1 = f32[4,2] broadcast(p0), dimensions={0}
|
||||||
|
ROOT multiply = f32[4,2] multiply(b1, b0)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::Broadcast(m::Multiply(
|
||||||
|
m::Parameter(0), m::Broadcast(m::ConstantScalar(2.0))))));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsNonScalar) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
p0 = f32[4] parameter(0)
|
||||||
|
p1 = f32[4] parameter(1)
|
||||||
|
b0 = f32[4,2] broadcast(p0), dimensions={0}
|
||||||
|
b1 = f32[4,2] broadcast(p1), dimensions={0}
|
||||||
|
ROOT multiply = f32[4,2] multiply(b1, b0)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(
|
||||||
|
m->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::Broadcast(m::Multiply(m::Parameter(1), m::Parameter(0)))));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, ElementwiseNoSinkBroadcastsDifferentDims) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
p0 = f32[4] parameter(0)
|
||||||
|
p1 = f32[8] parameter(1)
|
||||||
|
b0 = f32[4,8] broadcast(p0), dimensions={0}
|
||||||
|
b1 = f32[4,8] broadcast(p1), dimensions={1}
|
||||||
|
ROOT multiply = f32[4,8] multiply(b1, b0)
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::Multiply(m::Broadcast(m::Parameter(1)),
|
||||||
|
m::Broadcast(m::Parameter(0)))));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(AlgebraicSimplifierTest,
|
TEST_F(AlgebraicSimplifierTest,
|
||||||
MultiplyReassociateMultiplyOfConstantAndBroadcast) {
|
MultiplyReassociateMultiplyOfConstantAndBroadcast) {
|
||||||
const char* kModuleStr = R"(
|
const char* kModuleStr = R"(
|
||||||
|
Loading…
Reference in New Issue
Block a user