Revert of

[XLA] Add support for sinking broadcasts through ops with multiple broadcasts operands.

as it is causing some internal failures. Investigation in progress.

PiperOrigin-RevId: 315293975
Change-Id: If65d7aaf53f29cac52072bc14b06e3b5a8c5fc49
This commit is contained in:
A. Unique TensorFlower 2020-06-08 09:49:22 -07:00 committed by TensorFlower Gardener
parent bf1b3d7e70
commit fcfdbcf14a
2 changed files with 14 additions and 113 deletions

View File

@ -3058,20 +3058,6 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
return false; return false;
} }
HloInstruction* operand = broadcast->mutable_operand(0); HloInstruction* operand = broadcast->mutable_operand(0);
auto is_scalar_broadcast = [](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kBroadcast &&
ShapeUtil::IsScalar(instruction->operand(0)->shape());
};
auto is_equal_broadcast = [operand,
broadcast](const HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kBroadcast &&
ShapeUtil::Equal(operand->shape(),
instruction->operand(0)->shape()) &&
broadcast->dimensions() == instruction->dimensions();
};
auto is_compatible_broadcast = [&](const HloInstruction* instruction) {
return is_scalar_broadcast(instruction) || is_equal_broadcast(instruction);
};
for (HloInstruction* user : broadcast->users()) { for (HloInstruction* user : broadcast->users()) {
if (user->user_count() == 0 && user != computation_->root_instruction()) { if (user->user_count() == 0 && user != computation_->root_instruction()) {
continue; continue;
@ -3090,20 +3076,18 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
continue; continue;
} }
// Check if all the operands of the user are compatible broadcasts for // Find the unique non-scalar operand or continue if there isn't one.
// sinking. (They are either scalar broadcasts or broadcasts casting int64 scalar_broadcast_count = 0;
// from/to the same shape/dimensions)
int64 compatible_broadcast_count = 0;
int64 broadcast_use_count = 0; int64 broadcast_use_count = 0;
for (HloInstruction* user_operand : user->operands()) { for (HloInstruction* user_operand : user->operands()) {
if (is_compatible_broadcast(user_operand)) { if (user_operand->opcode() == HloOpcode::kBroadcast &&
++compatible_broadcast_count; ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
++scalar_broadcast_count;
} else if (broadcast == user_operand) { } else if (broadcast == user_operand) {
++broadcast_use_count; ++broadcast_use_count;
} }
} }
if (compatible_broadcast_count + broadcast_use_count != if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) {
user->operand_count()) {
continue; continue;
} }
std::vector<HloInstruction*> new_operands; std::vector<HloInstruction*> new_operands;
@ -3111,24 +3095,14 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
Shape changed_shape; Shape changed_shape;
for (HloInstruction* user_operand : user->operands()) { for (HloInstruction* user_operand : user->operands()) {
// If this is a broadcast operand that is not our original broadcast input if (user_operand->opcode() == HloOpcode::kBroadcast &&
// to this function then we might need to change the input. ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
if (is_compatible_broadcast(user_operand)) { changed_shape = ShapeUtil::ChangeElementType(
// If this is a broadcast from a scalar value rewrite a broadcast from operand->shape(), user_operand->shape().element_type());
// the scalar to the new shape enforced from the other broadcast simplifier_->UpdateLayout(&changed_shape);
// operands. new_operands.push_back(
if (is_scalar_broadcast(user_operand)) { computation_->AddInstruction(HloInstruction::CreateBroadcast(
changed_shape = ShapeUtil::ChangeElementType( changed_shape, user_operand->mutable_operand(0), {})));
operand->shape(), user_operand->shape().element_type());
simplifier_->UpdateLayout(&changed_shape);
new_operands.push_back(
computation_->AddInstruction(HloInstruction::CreateBroadcast(
changed_shape, user_operand->mutable_operand(0), {})));
} else {
// For the non-scalar broadcasts we guarantee that the shape of the
// operand of the broadcast needs to be already a compatible shape.
new_operands.push_back(user_operand->mutable_operand(0));
}
} else { } else {
CHECK_EQ(broadcast, user_operand); CHECK_EQ(broadcast, user_operand);
new_operands.push_back(operand); new_operands.push_back(operand);

View File

@ -338,79 +338,6 @@ TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeBroadcastedConstants) {
m::ConstantScalar(3.0)))))); m::ConstantScalar(3.0))))));
} }
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsScalar) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
b0 = f32[4] broadcast(p0), dimensions={}
b1 = f32[4] broadcast(p1), dimensions={}
ROOT multiply = f32[4] multiply(b1, b0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(
m->entry_computation()->root_instruction(),
GmockMatch(m::Broadcast(m::Multiply(m::Broadcast(m::Parameter(1)),
m::Broadcast(m::Parameter(0))))));
}
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsConstantMix) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[4] parameter(0)
c0 = f32[] constant(2.0)
b0 = f32[4,2] broadcast(c0), dimensions={}
b1 = f32[4,2] broadcast(p0), dimensions={0}
ROOT multiply = f32[4,2] multiply(b1, b0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Broadcast(m::Multiply(
m::Parameter(0), m::Broadcast(m::ConstantScalar(2.0))))));
}
TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsNonScalar) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[4] parameter(0)
p1 = f32[4] parameter(1)
b0 = f32[4,2] broadcast(p0), dimensions={0}
b1 = f32[4,2] broadcast(p1), dimensions={0}
ROOT multiply = f32[4,2] multiply(b1, b0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(
m->entry_computation()->root_instruction(),
GmockMatch(m::Broadcast(m::Multiply(m::Parameter(1), m::Parameter(0)))));
}
TEST_F(AlgebraicSimplifierTest, ElementwiseNoSinkBroadcastsDifferentDims) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[4] parameter(0)
p1 = f32[8] parameter(1)
b0 = f32[4,8] broadcast(p0), dimensions={0}
b1 = f32[4,8] broadcast(p1), dimensions={1}
ROOT multiply = f32[4,8] multiply(b1, b0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Multiply(m::Broadcast(m::Parameter(1)),
m::Broadcast(m::Parameter(0)))));
}
TEST_F(AlgebraicSimplifierTest, TEST_F(AlgebraicSimplifierTest,
MultiplyReassociateMultiplyOfConstantAndBroadcast) { MultiplyReassociateMultiplyOfConstantAndBroadcast) {
const char* kModuleStr = R"( const char* kModuleStr = R"(