diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index a58ac39dffa..1838f65e6ea 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -112,6 +112,14 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( int64 operand_rank) { HloComputation* computation = index_vector->parent(); const Shape& index_shape = index_vector->shape(); + + if (operand_rank == 0) { + // This is Gather from a scalar. So, the index vector in operand space must + // be a zero-sized vector. + return computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0}))); + } + HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index d65b67a535d..16a1371ec8d 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -598,6 +598,26 @@ ENTRY main { RunTest(hlo_text, &operand, &start_indices); } +XLA_TEST_F(GatherOperationTest, GatherFromScalar) { + const string hlo_text = R"( +HloModule GatherFromScalar + +ENTRY main { + operand = f32[] parameter(0) + indices = s32[0]{0} parameter(1) + ROOT gather = f32[] gather(operand, indices), + offset_dims={}, + collapsed_slice_dims={}, + start_index_map={}, + index_vector_dim=0, + slice_sizes={} +} +)"; + Literal operand = LiteralUtil::CreateR0(1); + Literal start_indices = LiteralUtil::CreateR1({}); + RunTest(hlo_text, &operand, &start_indices); +} + class GatherClientLibraryTest : public ClientLibraryTestBase {}; // Disabled on interpreter since ExectuteAsyncOnStream is not supported.