[XLA] Correctly expand scatters into scalars

This is a valid edge case, don't crash on it.

PiperOrigin-RevId: 236802786
This commit is contained in:
Benjamin Kramer 2019-03-05 01:00:24 -08:00 committed by TensorFlower Gardener
parent 17a758f7bb
commit 383228b0d2
2 changed files with 34 additions and 0 deletions

View File

@ -134,6 +134,13 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
int64 operand_rank) {
HloComputation* computation = index_vector->parent();
const Shape& index_shape = index_vector->shape();
// Scatter of a scalar. Return a zero-sized vector of indices.
if (operand_rank == 0) {
return computation->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0})));
}
HloInstruction* zero =
computation->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));

View File

@ -718,5 +718,32 @@ ENTRY main {
RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, ScatterIntoScalar) {
const char* hlo_text = R"(
HloModule ScatterIntoScalar
update_s32 {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
parameter.1 = s32[] parameter(0)
parameter.2 = s32[0]{0} parameter(1)
parameter.3 = s32[] parameter(2)
ROOT scatter = s32[] scatter(parameter.1, parameter.2, parameter.3),
update_window_dims={},
inserted_window_dims={},
scatter_dims_to_operand_dims={},
index_vector_dim=0,
to_apply=update_s32
}
)";
Literal operand = LiteralUtil::CreateR0<int32>(1);
Literal scatter_indices = LiteralUtil::CreateR1<int32>({});
Literal updates = LiteralUtil::CreateR0<int32>(2);
RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
} // namespace
} // namespace xla