[XLA] Correctly expand scatters into scalars
This is a valid edge case, don't crash on it. PiperOrigin-RevId: 236802786
This commit is contained in:
parent
17a758f7bb
commit
383228b0d2
@ -134,6 +134,13 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
|
||||
int64 operand_rank) {
|
||||
HloComputation* computation = index_vector->parent();
|
||||
const Shape& index_shape = index_vector->shape();
|
||||
|
||||
// Scatter of a scalar. Return a zero-sized vector of indices.
|
||||
if (operand_rank == 0) {
|
||||
return computation->AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0})));
|
||||
}
|
||||
|
||||
HloInstruction* zero =
|
||||
computation->AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));
|
||||
|
@ -718,5 +718,32 @@ ENTRY main {
|
||||
RunTest(hlo_text, &operand, &scatter_indices, &updates);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ScatterTest, ScatterIntoScalar) {
|
||||
const char* hlo_text = R"(
|
||||
HloModule ScatterIntoScalar
|
||||
|
||||
update_s32 {
|
||||
lhs = s32[] parameter(0)
|
||||
ROOT rhs = s32[] parameter(1)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
parameter.1 = s32[] parameter(0)
|
||||
parameter.2 = s32[0]{0} parameter(1)
|
||||
parameter.3 = s32[] parameter(2)
|
||||
ROOT scatter = s32[] scatter(parameter.1, parameter.2, parameter.3),
|
||||
update_window_dims={},
|
||||
inserted_window_dims={},
|
||||
scatter_dims_to_operand_dims={},
|
||||
index_vector_dim=0,
|
||||
to_apply=update_s32
|
||||
}
|
||||
)";
|
||||
Literal operand = LiteralUtil::CreateR0<int32>(1);
|
||||
Literal scatter_indices = LiteralUtil::CreateR1<int32>({});
|
||||
Literal updates = LiteralUtil::CreateR0<int32>(2);
|
||||
RunTest(hlo_text, &operand, &scatter_indices, &updates);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user