[XLA] Fix crash in algebraic simplifier when sinking a reshape-to-scalar after an elementwise operation.

The test case in the change previously failed with:
algebraic_simplifier.cc:1091] Check failed: user->operand(reshape_or_broadcast_operand_index) == reshape_or_broadcast

PiperOrigin-RevId: 161121259
This commit is contained in:
Peter Hawkins 2017-07-06 13:20:29 -07:00 committed by TensorFlower Gardener
parent b87774c9d8
commit 87d86dbbf4
2 changed files with 31 additions and 0 deletions

View File

@ -1055,6 +1055,9 @@ StatusOr<bool> AlgebraicSimplifierVisitor::
TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(
HloInstruction* reshape_or_broadcast) {
bool changed = false;
if (ShapeUtil::IsScalar(reshape_or_broadcast->shape())) {
return false;
}
HloInstruction* operand = reshape_or_broadcast->mutable_operand(0);
for (HloInstruction* user : reshape_or_broadcast->users()) {
if (user->user_count() == 0 && user != computation_->root_instruction()) {

View File

@ -983,6 +983,34 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) {
op::Reshape(op::Maximum(param, zero)));
}
// Regression test for a bug in the reshape sinking transformation, where
// moving a reshape to a scalar led to a crash.
TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) {
HloComputation::Builder builder(TestName());
HloInstruction* param =
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {1, 1}), "param"));
HloInstruction* reshape = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), param));
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR1<float>({1., 2., 3.})));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, reshape, zero));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Maximum(op::Reshape(param), zero));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
bitcasting_callback());
simplifier.Run(module.get()).ValueOrDie();
EXPECT_THAT(computation->root_instruction(),
op::Maximum(op::Reshape(param), zero));
}
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
HloComputation::Builder builder(TestName());
HloInstruction* param =