From 87d86dbbf4eadf1f7be414d1192cd5ef72db6026 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 6 Jul 2017 13:20:29 -0700 Subject: [PATCH] [XLA] Fix crash in algebraic simplifier when sinking a reshape-to-scalar after an elementwise operation. The test case in the change previously failed with: algebraic_simplifier.cc:1091] Check failed: user->operand(reshape_or_broadcast_operand_index) == reshape_or_broadcast PiperOrigin-RevId: 161121259 --- .../xla/service/algebraic_simplifier.cc | 3 ++ .../xla/service/algebraic_simplifier_test.cc | 28 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index bc6a879b8a4..b351861425d 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1055,6 +1055,9 @@ StatusOr AlgebraicSimplifierVisitor:: TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( HloInstruction* reshape_or_broadcast) { bool changed = false; + if (ShapeUtil::IsScalar(reshape_or_broadcast->shape())) { + return false; + } HloInstruction* operand = reshape_or_broadcast->mutable_operand(0); for (HloInstruction* user : reshape_or_broadcast->users()) { if (user->user_count() == 0 && user != computation_->root_instruction()) { diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index aceef6cd5d0..ff119c00986 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -983,6 +983,34 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { op::Reshape(op::Maximum(param, zero))); } +// Regression test for a bug in the reshape sinking transformation, where +// moving a reshape to a scalar led to a crash. +TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1}), "param")); + HloInstruction* reshape = builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), param)); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1., 2., 3.}))); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, reshape, zero)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Maximum(op::Reshape(param), zero)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + + simplifier.Run(module.get()).ValueOrDie(); + + EXPECT_THAT(computation->root_instruction(), + op::Maximum(op::Reshape(param), zero)); +} + TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { HloComputation::Builder builder(TestName()); HloInstruction* param =