[XLA] Fix crash in algebraic simplifier when sinking a reshape-to-scalar after an elementwise operation.
The test case in the change previously failed with: algebraic_simplifier.cc:1091] Check failed: user->operand(reshape_or_broadcast_operand_index) == reshape_or_broadcast PiperOrigin-RevId: 161121259
This commit is contained in:
parent
b87774c9d8
commit
87d86dbbf4
@ -1055,6 +1055,9 @@ StatusOr<bool> AlgebraicSimplifierVisitor::
|
||||
TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(
|
||||
HloInstruction* reshape_or_broadcast) {
|
||||
bool changed = false;
|
||||
if (ShapeUtil::IsScalar(reshape_or_broadcast->shape())) {
|
||||
return false;
|
||||
}
|
||||
HloInstruction* operand = reshape_or_broadcast->mutable_operand(0);
|
||||
for (HloInstruction* user : reshape_or_broadcast->users()) {
|
||||
if (user->user_count() == 0 && user != computation_->root_instruction()) {
|
||||
|
@ -983,6 +983,34 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) {
|
||||
op::Reshape(op::Maximum(param, zero)));
|
||||
}
|
||||
|
||||
// Regression test for a bug in the reshape sinking transformation, where
|
||||
// moving a reshape to a scalar led to a crash.
|
||||
TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
HloInstruction* param =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {1, 1}), "param"));
|
||||
HloInstruction* reshape = builder.AddInstruction(
|
||||
HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), param));
|
||||
HloInstruction* zero = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR1<float>({1., 2., 3.})));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, reshape, zero));
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
op::Maximum(op::Reshape(param), zero));
|
||||
|
||||
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
|
||||
bitcasting_callback());
|
||||
|
||||
simplifier.Run(module.get()).ValueOrDie();
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
op::Maximum(op::Reshape(param), zero));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
HloInstruction* param =
|
||||
|
Loading…
Reference in New Issue
Block a user