[XLA] Handle scalar operand in GatherExpander.
PiperOrigin-RevId: 237134843
This commit is contained in:
parent
3850ab8a0d
commit
31395d336e
@ -112,6 +112,14 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
|
|||||||
int64 operand_rank) {
|
int64 operand_rank) {
|
||||||
HloComputation* computation = index_vector->parent();
|
HloComputation* computation = index_vector->parent();
|
||||||
const Shape& index_shape = index_vector->shape();
|
const Shape& index_shape = index_vector->shape();
|
||||||
|
|
||||||
|
if (operand_rank == 0) {
|
||||||
|
// This is Gather from a scalar. So, the index vector in operand space must
|
||||||
|
// be a zero-sized vector.
|
||||||
|
return computation->AddInstruction(HloInstruction::CreateConstant(
|
||||||
|
LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0})));
|
||||||
|
}
|
||||||
|
|
||||||
HloInstruction* zero =
|
HloInstruction* zero =
|
||||||
computation->AddInstruction(HloInstruction::CreateConstant(
|
computation->AddInstruction(HloInstruction::CreateConstant(
|
||||||
LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));
|
LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));
|
||||||
|
@ -598,6 +598,26 @@ ENTRY main {
|
|||||||
RunTest(hlo_text, &operand, &start_indices);
|
RunTest(hlo_text, &operand, &start_indices);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(GatherOperationTest, GatherFromScalar) {
|
||||||
|
const string hlo_text = R"(
|
||||||
|
HloModule GatherFromScalar
|
||||||
|
|
||||||
|
ENTRY main {
|
||||||
|
operand = f32[] parameter(0)
|
||||||
|
indices = s32[0]{0} parameter(1)
|
||||||
|
ROOT gather = f32[] gather(operand, indices),
|
||||||
|
offset_dims={},
|
||||||
|
collapsed_slice_dims={},
|
||||||
|
start_index_map={},
|
||||||
|
index_vector_dim=0,
|
||||||
|
slice_sizes={}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
Literal operand = LiteralUtil::CreateR0<float>(1);
|
||||||
|
Literal start_indices = LiteralUtil::CreateR1<int32>({});
|
||||||
|
RunTest(hlo_text, &operand, &start_indices);
|
||||||
|
}
|
||||||
|
|
||||||
class GatherClientLibraryTest : public ClientLibraryTestBase {};
|
class GatherClientLibraryTest : public ClientLibraryTestBase {};
|
||||||
|
|
||||||
// Disabled on interpreter since ExectuteAsyncOnStream is not supported.
|
// Disabled on interpreter since ExectuteAsyncOnStream is not supported.
|
||||||
|
Loading…
Reference in New Issue
Block a user