[XLA] Fix crash in algebraic simplifier when sinking a reshape-to-scalar after an elementwise operation.
The test case in the change previously failed with: algebraic_simplifier.cc:1091] Check failed: user->operand(reshape_or_broadcast_operand_index) == reshape_or_broadcast PiperOrigin-RevId: 161121259
This commit is contained in:
parent
b87774c9d8
commit
87d86dbbf4
@ -1055,6 +1055,9 @@ StatusOr<bool> AlgebraicSimplifierVisitor::
|
|||||||
TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(
|
TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(
|
||||||
HloInstruction* reshape_or_broadcast) {
|
HloInstruction* reshape_or_broadcast) {
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
|
if (ShapeUtil::IsScalar(reshape_or_broadcast->shape())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
HloInstruction* operand = reshape_or_broadcast->mutable_operand(0);
|
HloInstruction* operand = reshape_or_broadcast->mutable_operand(0);
|
||||||
for (HloInstruction* user : reshape_or_broadcast->users()) {
|
for (HloInstruction* user : reshape_or_broadcast->users()) {
|
||||||
if (user->user_count() == 0 && user != computation_->root_instruction()) {
|
if (user->user_count() == 0 && user != computation_->root_instruction()) {
|
||||||
|
@ -983,6 +983,34 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) {
|
|||||||
op::Reshape(op::Maximum(param, zero)));
|
op::Reshape(op::Maximum(param, zero)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Regression test for a bug in the reshape sinking transformation, where
|
||||||
|
// moving a reshape to a scalar led to a crash.
|
||||||
|
TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) {
|
||||||
|
HloComputation::Builder builder(TestName());
|
||||||
|
HloInstruction* param =
|
||||||
|
builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
|
0, ShapeUtil::MakeShape(F32, {1, 1}), "param"));
|
||||||
|
HloInstruction* reshape = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), param));
|
||||||
|
HloInstruction* zero = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(Literal::CreateR1<float>({1., 2., 3.})));
|
||||||
|
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||||
|
ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, reshape, zero));
|
||||||
|
auto module = CreateNewModule();
|
||||||
|
auto computation = module->AddEntryComputation(builder.Build());
|
||||||
|
|
||||||
|
EXPECT_THAT(computation->root_instruction(),
|
||||||
|
op::Maximum(op::Reshape(param), zero));
|
||||||
|
|
||||||
|
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
|
||||||
|
bitcasting_callback());
|
||||||
|
|
||||||
|
simplifier.Run(module.get()).ValueOrDie();
|
||||||
|
|
||||||
|
EXPECT_THAT(computation->root_instruction(),
|
||||||
|
op::Maximum(op::Reshape(param), zero));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
|
TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
|
||||||
HloComputation::Builder builder(TestName());
|
HloComputation::Builder builder(TestName());
|
||||||
HloInstruction* param =
|
HloInstruction* param =
|
||||||
|
Loading…
Reference in New Issue
Block a user