[XLA] Handle scalar operand in GatherExpander.
PiperOrigin-RevId: 237134843
This commit is contained in:
parent
3850ab8a0d
commit
31395d336e
@ -112,6 +112,14 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
|
||||
int64 operand_rank) {
|
||||
HloComputation* computation = index_vector->parent();
|
||||
const Shape& index_shape = index_vector->shape();
|
||||
|
||||
if (operand_rank == 0) {
|
||||
// This is Gather from a scalar. So, the index vector in operand space must
|
||||
// be a zero-sized vector.
|
||||
return computation->AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0})));
|
||||
}
|
||||
|
||||
HloInstruction* zero =
|
||||
computation->AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));
|
||||
|
@ -598,6 +598,26 @@ ENTRY main {
|
||||
RunTest(hlo_text, &operand, &start_indices);
|
||||
}
|
||||
|
||||
XLA_TEST_F(GatherOperationTest, GatherFromScalar) {
|
||||
const string hlo_text = R"(
|
||||
HloModule GatherFromScalar
|
||||
|
||||
ENTRY main {
|
||||
operand = f32[] parameter(0)
|
||||
indices = s32[0]{0} parameter(1)
|
||||
ROOT gather = f32[] gather(operand, indices),
|
||||
offset_dims={},
|
||||
collapsed_slice_dims={},
|
||||
start_index_map={},
|
||||
index_vector_dim=0,
|
||||
slice_sizes={}
|
||||
}
|
||||
)";
|
||||
Literal operand = LiteralUtil::CreateR0<float>(1);
|
||||
Literal start_indices = LiteralUtil::CreateR1<int32>({});
|
||||
RunTest(hlo_text, &operand, &start_indices);
|
||||
}
|
||||
|
||||
class GatherClientLibraryTest : public ClientLibraryTestBase {};
|
||||
|
||||
// Disabled on interpreter since ExectuteAsyncOnStream is not supported.
|
||||
|
Loading…
Reference in New Issue
Block a user