diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 250398ca3eb..7fea245f69a 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1746,11 +1746,13 @@ XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices, const XlaOp& updates, const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted) { + bool indices_are_sorted, bool unique_indices) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; instr.set_indices_are_sorted(indices_are_sorted); + instr.set_unique_indices(unique_indices); + TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); TF_ASSIGN_OR_RETURN(const Shape& scatter_indices_shape, GetShape(scatter_indices)); @@ -3390,10 +3392,10 @@ XlaOp Gather(const XlaOp input, const XlaOp start_indices, XlaOp Scatter(const XlaOp input, const XlaOp scatter_indices, const XlaOp updates, const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted) { + bool indices_are_sorted, bool unique_indices) { return input.builder()->Scatter(input, scatter_indices, updates, update_computation, dimension_numbers, - indices_are_sorted); + indices_are_sorted, unique_indices); } void Send(const XlaOp operand, const ChannelHandle& handle) { diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 5c28e8b5150..187cd261833 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -592,7 +592,7 @@ class XlaBuilder { XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, const XlaOp& updates, const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted = false); + bool indices_are_sorted = false, bool unique_indices = false); void 
Send(const XlaOp& operand, const ChannelHandle& handle); XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, @@ -1010,7 +1010,7 @@ class XlaBuilder { friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted); + bool indices_are_sorted, bool unique_indices); friend void Send(XlaOp operand, const ChannelHandle& handle); friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, const ChannelHandle& handle); @@ -1869,7 +1869,7 @@ XlaOp Gather(XlaOp input, XlaOp start_indices, XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted = false); + bool indices_are_sorted = false, bool unique_indices = false); // Enqueues a Send node onto the computation for device-to-device // communication. This operation sends the given operand to diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 1f2790e98bb..bc8d4d5a1cb 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -1382,6 +1382,9 @@ For a more intuitive description, see the "Informal Description" section below. | `indices_are_sorted` | `bool` | Whether the indices are | : : : guaranteed to be sorted by : : : : the caller. : +| `unique_indices` | `bool` | Whether the indices are | +: : : guaranteed to be unique by : +: : : the caller. : For convenience, we label dimensions in the output array not in `offset_dims` as `batch_dims`. @@ -1450,6 +1453,11 @@ If `indices_are_sorted` is set to true then XLA can assume that `start_indices` are sorted (in ascending `start_index_map` order) by the user. If they are not then the semantics is implementation defined. 
+If `unique_indices` is set to true then XLA can assume that all elements +scattered to are unique. So XLA could use non-atomic operations. If +`unique_indices` is set to true and the indices being scattered to are not +unique then the semantics is implementation defined. + ### Informal Description and Examples Informally, every index `Out` in the output array corresponds to an element `E` diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 63a9ea37692..7abd2f7429d 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -1544,11 +1544,12 @@ class ComputationBuilder(object): updates, update_computation, dimension_numbers, - indices_are_sorted=False): + indices_are_sorted=False, + unique_indices=False): """Enqueues a Scatter operation onto the computation.""" return ops.Scatter(a, scatter_indices, updates, update_computation.computation, dimension_numbers, - indices_are_sorted) + indices_are_sorted, unique_indices) def Fft(self, operand, fft_type, fft_lengths): """Enqueues a FFT operation onto the computation.""" diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 3295a6caff8..1c617a07372 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -965,8 +965,15 @@ Status IrEmitterUnnested::EmitScatter( updates->shape().element_type(), module_)); TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index)); Store(input_ir_value, input_address); - return EmitAtomicOperationForNestedComputation( - *scatter->to_apply(), output_address, input_address); + + if (!scatter->unique_indices()) { + return EmitAtomicOperationForNestedComputation( + *scatter->to_apply(), output_address, input_address); + } else { + return EmitCallToNestedComputation(*scatter->to_apply(), + {output_address, 
input_address}, + output_address); + } }; // Launch a kernel that reads every element in the updates tensor. We could diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 3db707b7d24..dbd6ce78bb3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -188,7 +188,12 @@ class IrEmitterUnnested : public IrEmitter, // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in // the process. `scatter` may be fused, scatter indices are taken from // `scatter_indices_gen`, updates from`updates_gen`. The output buffer is - // expected to have the operand values in it already. + // expected to have the operand values in it already. If unique_indices + // is false, we will use an atomic update. Using true for unique_indices + // is safe only when it is guaranteed that there are no duplicate + // indices. + // When using unique_indices=true, it is the caller's responsibility to + // ensure there is no overlap. 
Status EmitScatter(Thunk* thunk, HloInstruction* scatter, const llvm_ir::ElementGenerator& scatter_indices_gen, const llvm_ir::ElementGenerator& updates_gen); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc index 6b18c4c6371..a54c0e5ae44 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc @@ -53,6 +53,33 @@ CHECK: store atomic{{.*}}unordered, align 4 )"); } +TEST_F(GpuAtomicTest, TestStoreNoAtomic) { + const char* hlo_string = R"( + HloModule TensorFlowScatterV1 + + update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) + } + + ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, unique_indices=true + } +)"; + + CompileAndVerifyIr(hlo_string, R"( +CHECK-NOT: store atomic{{.*}}unordered, align 4 +)"); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 61e562c7eda..286562d0226 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 69 +// Next ID: 70 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -237,6 +237,10 @@ message HloInstructionProto { // Frontend attributes to pass to the XLA backend. xla.FrontendAttributes frontend_attributes = 68; + + // Specifies if all elements updated are guaranteed to be unique by + // the caller. 
+ bool unique_indices = 69; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 668beadfc72..54e35fbeaa0 100755 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -563,9 +563,10 @@ StatusOr> HloInstruction::CreateFromProto( auto scatter_dimension_numbers = absl::make_unique( proto.scatter_dimension_numbers()); - instruction = CreateScatter(shape, operands(0), operands(1), operands(2), - computations(0), *scatter_dimension_numbers, - proto.indices_are_sorted()); + instruction = + CreateScatter(shape, operands(0), operands(1), operands(2), + computations(0), *scatter_dimension_numbers, + proto.indices_are_sorted(), proto.unique_indices()); break; } case HloOpcode::kIota: @@ -1392,11 +1393,11 @@ bool HloInstruction::HasSideEffect() const { const Shape& shape, HloInstruction* operand, HloInstruction* scatter_indices, HloInstruction* updates, HloComputation* update_computation, - const ScatterDimensionNumbers& scatter_dim_numbers, - bool indices_are_sorted) { + const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted, + bool unique_indices) { return absl::make_unique( shape, operand, scatter_indices, updates, update_computation, - scatter_dim_numbers, indices_are_sorted); + scatter_dim_numbers, indices_are_sorted, unique_indices); } /* static */ std::unique_ptr HloInstruction::CreateDomain( diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 32b9e82ee6e..bbd7f232bb6 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -801,7 +801,7 @@ class HloInstruction { HloInstruction* scatter_indices, HloInstruction* updates, HloComputation* update_computation, const ScatterDimensionNumbers& scatter_dim_numbers, - bool indices_are_sorted); + bool 
indices_are_sorted, bool unique_indices); // Creates a kDomain instruction which delimits an HLO domain which have // the provided user and operand side metadata. @@ -1629,6 +1629,9 @@ class HloInstruction { LOG(FATAL) << "Unimplemented method."; } + // Returns the unique_indices field. + virtual bool unique_indices() const { LOG(FATAL) << "Unimplemented method."; } + // Returns data on the dimension numbers used for a convolution operation, // which may be a kConvolution instruction or a kCustomCall that implements a // convolution. diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 0a50ed04af7..633ce875de0 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1529,7 +1529,8 @@ TEST_F(HloInstructionTest, StringifyScatter) { /*inserted_window_dims=*/{}, /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/2), - /*indices_are_sorted=*/false)); + /*indices_are_sorted=*/false, + /*unique_indices=*/false)); module->AddEntryComputation(builder.Build()); EXPECT_EQ( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 183967941bf..82f3b245590 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -2493,9 +2493,11 @@ HloScatterInstruction::HloScatterInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* scatter_indices, HloInstruction* updates, HloComputation* update_computation, - const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted) + const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted, + bool unique_indices) : HloInstruction(HloOpcode::kScatter, shape), - indices_are_sorted_(indices_are_sorted) { + indices_are_sorted_(indices_are_sorted), + unique_indices_(unique_indices) { 
AppendOperand(operand); AppendOperand(scatter_indices); AppendOperand(updates); @@ -2550,6 +2552,7 @@ HloInstructionProto HloScatterInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers(); proto.set_indices_are_sorted(indices_are_sorted()); + proto.set_unique_indices(unique_indices()); return proto; } @@ -2560,6 +2563,9 @@ std::vector HloScatterInstruction::ExtraAttributesToStringImpl( if (indices_are_sorted()) { attrs.push_back("indices_are_sorted=true"); } + if (unique_indices()) { + attrs.push_back("unique_indices=true"); + } return attrs; } @@ -2572,7 +2578,8 @@ bool HloScatterInstruction::IdenticalSlowPath( scatter_dimension_numbers(), casted_other.scatter_dimension_numbers()) && eq_computations(to_apply(), casted_other.to_apply()) && - indices_are_sorted() == casted_other.indices_are_sorted(); + indices_are_sorted() == casted_other.indices_are_sorted() && + unique_indices() == casted_other.unique_indices(); } std::unique_ptr HloScatterInstruction::CloneWithNewOperandsImpl( @@ -2581,7 +2588,7 @@ std::unique_ptr HloScatterInstruction::CloneWithNewOperandsImpl( CHECK_EQ(new_operands.size(), 3); return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], to_apply(), - scatter_dimension_numbers(), indices_are_sorted()); + scatter_dimension_numbers(), indices_are_sorted(), unique_indices()); } HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension) diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 0de050108b7..59f2392866b 100755 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1453,7 +1453,7 @@ class HloScatterInstruction : public HloInstruction { HloInstruction* scatter_indices, HloInstruction* updates, HloComputation* update_computation, const ScatterDimensionNumbers& 
scatter_dim_numbers, - bool indices_are_sorted); + bool indices_are_sorted, bool unique_indices); const ScatterDimensionNumbers& scatter_dimension_numbers() const { CHECK(scatter_dimension_numbers_ != nullptr); return *scatter_dimension_numbers_; @@ -1462,6 +1462,7 @@ class HloScatterInstruction : public HloInstruction { void set_indices_are_sorted(bool indices_are_sorted) { indices_are_sorted_ = indices_are_sorted; } + bool unique_indices() const override { return unique_indices_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1489,6 +1490,7 @@ class HloScatterInstruction : public HloInstruction { std::unique_ptr scatter_dimension_numbers_; bool indices_are_sorted_; + bool unique_indices_; }; class HloIotaInstruction : public HloInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index c96bfb15187..6716c4f1744 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -1726,6 +1726,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional indices_are_sorted = false; attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool, &indices_are_sorted}; + optional unique_indices = false; + attrs["unique_indices"] = {/*required=*/false, AttrTy::kBool, + &unique_indices}; if (!ParseOperands(&operands, /*expected_size=*/3) || !ParseAttributes(attrs)) { @@ -1742,7 +1745,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, instruction = builder->AddInstruction(HloInstruction::CreateScatter( shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1], /*updates=*/operands[2], *update_computation, dim_numbers, - indices_are_sorted.value())); + indices_are_sorted.value(), unique_indices.value())); break; } case HloOpcode::kDomain: { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc 
b/tensorflow/compiler/xla/service/hlo_parser_test.cc index c913784cd13..a2c8c61bee5 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -934,6 +934,25 @@ ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7 ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, indices_are_sorted=true, to_apply=%add_F32.v3 } +)" +}, +{ +"UniqueIndicesScatter", +R"(HloModule StringifyUniqueIndicesScatter + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] { + %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2) + ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, unique_indices=true, to_apply=%add_F32.v3 +} + )" }, { diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 05073de9c90..0fdd176e8ef 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -225,6 +225,36 @@ ENTRY main { RunTest(hlo_text, &operand, &scatter_indices, &updates); } 
+XLA_TEST_F(ScatterTest, TensorFlowScatter_Add_UniqueIndices) { + const string hlo_text = R"( +HloModule TensorFlowScatter_Add + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + unique_indices=true +} +)"; + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) { const string hlo_text = R"( HloModule TensorFlowScatter_Mul