[XLA] Handle scalar operand in GatherExpander.

PiperOrigin-RevId: 237134843
This commit is contained in:
A. Unique TensorFlower 2019-03-06 15:41:19 -08:00 committed by TensorFlower Gardener
parent 3850ab8a0d
commit 31395d336e
2 changed files with 28 additions and 0 deletions

View File

@ -112,6 +112,14 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
int64 operand_rank) {
HloComputation* computation = index_vector->parent();
const Shape& index_shape = index_vector->shape();
if (operand_rank == 0) {
// This is Gather from a scalar. So, the index vector in operand space must
// be a zero-sized vector.
return computation->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0})));
}
HloInstruction* zero =
computation->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));

View File

@ -598,6 +598,26 @@ ENTRY main {
RunTest(hlo_text, &operand, &start_indices);
}
XLA_TEST_F(GatherOperationTest, GatherFromScalar) {
const string hlo_text = R"(
HloModule GatherFromScalar
ENTRY main {
operand = f32[] parameter(0)
indices = s32[0]{0} parameter(1)
ROOT gather = f32[] gather(operand, indices),
offset_dims={},
collapsed_slice_dims={},
start_index_map={},
index_vector_dim=0,
slice_sizes={}
}
)";
Literal operand = LiteralUtil::CreateR0<float>(1);
Literal start_indices = LiteralUtil::CreateR1<int32>({});
RunTest(hlo_text, &operand, &start_indices);
}
class GatherClientLibraryTest : public ClientLibraryTestBase {};
// Disabled on interpreter since ExectuteAsyncOnStream is not supported.