From 383228b0d27f227eb6cf27deb8e3d02a4f09b45b Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Tue, 5 Mar 2019 01:00:24 -0800 Subject: [PATCH] [XLA] Correctly expand scatters into scalars This is a valid edge case, don't crash on it. PiperOrigin-RevId: 236802786 --- .../compiler/xla/service/scatter_expander.cc | 7 +++++ tensorflow/compiler/xla/tests/scatter_test.cc | 27 +++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index acad871c4d4..fdb5cd91fd0 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -134,6 +134,13 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( int64 operand_rank) { HloComputation* computation = index_vector->parent(); const Shape& index_shape = index_vector->shape(); + + // Scatter of a scalar. Return a zero-sized vector of indices. + if (operand_rank == 0) { + return computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0}))); + } + HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 32de0fdf78f..86d9999b4a4 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -718,5 +718,32 @@ ENTRY main { RunTest(hlo_text, &operand, &scatter_indices, &updates); } +XLA_TEST_F(ScatterTest, ScatterIntoScalar) { + const char* hlo_text = R"( +HloModule ScatterIntoScalar + +update_s32 { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + parameter.1 = s32[] parameter(0) + parameter.2 = s32[0]{0} parameter(1) + parameter.3 = s32[] parameter(2) + ROOT scatter = s32[] scatter(parameter.1, parameter.2, parameter.3), + update_window_dims={}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={}, + index_vector_dim=0, + to_apply=update_s32 +} +)"; + Literal operand = LiteralUtil::CreateR0(1); + Literal scatter_indices = LiteralUtil::CreateR1({}); + Literal updates = LiteralUtil::CreateR0(2); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + } // namespace } // namespace xla