diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md
index 355e5fec73a..a1772a0d1e1 100644
--- a/tensorflow/compiler/xla/g3doc/operation_semantics.md
+++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md
@@ -1384,7 +1384,7 @@ For a more intuitive description, see the "Informal Description" section below.
 :                      :        : the caller.                :
 | `unique_indices`     | `bool` | Whether the indices are    |
 :                      :        : guaranteed to be unique by :
-:                      :        : the caller                 :
+:                      :        : the caller.                :
 
 For convenience, we label dimensions in the output array not in `offset_dims`
 as `batch_dims`.
@@ -1455,7 +1455,8 @@ then the semantics is implementation defined.
 
 If `unique_indices` is set to true then XLA can assume that all
 element scattered to are unique. So XLA could use non-atomic
-operation. If they are not, then the semantics is implementation
+operations. If `unique_indices` is set to true and the indices being
+scattered to are not unique then the semantics is implementation
 defined.
 
 ### Informal Description and Examples
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 9c69545cb9f..0f52aecda23 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -230,7 +230,7 @@ class IrEmitterUnnested : public IrEmitter,
   // is false, we will use an atomic update. Using false for unique_indices
   // is safe only when it is guaranteed that there are no duplicate
   // indices.
-  // When using unique_indices=true, it is the caller responsibility to
+  // When using unique_indices=true, it is the caller's responsibility to
   // ensure there is no overlap.
   Status EmitScatter(Thunk* thunk, HloInstruction* scatter,
                      const llvm_ir::ElementGenerator& scatter_indices_gen,
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc
index 6b18c4c6371..a54c0e5ae44 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc
@@ -53,6 +53,33 @@ CHECK: store atomic{{.*}}unordered, align 4
 )");
 }
 
+TEST_F(GpuAtomicTest, TestStoreNoAtomic) {
+  const char* hlo_string = R"(
+    HloModule TensorFlowScatterV1
+
+    update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+      lhs = s32[] parameter(0)
+      ROOT rhs = s32[] parameter(1)
+    }
+
+    ENTRY main {
+      operand = s32[3,3] parameter(0)
+      indices = s32[2] parameter(1)
+      updates = s32[2,3] parameter(2)
+      ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+          to_apply=update_s32,
+          update_window_dims={1},
+          inserted_window_dims={0},
+          scatter_dims_to_operand_dims={0},
+          index_vector_dim=1, unique_indices=true
+    }
+)";
+
+  CompileAndVerifyIr(hlo_string, R"(
+CHECK-NOT: store atomic{{.*}}unordered, align 4
+)");
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
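
For context, a minimal sketch of how a caller might assert index uniqueness when constructing the same scatter through the C++ builder API, assuming the `xla::Scatter` free function in `xla_builder.h` takes trailing `indices_are_sorted`/`unique_indices` flags (check the header for the exact signature in your revision); the wrapper function name is hypothetical.

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Builds a scatter equivalent to the HLO in the new test, promising that
// the scattered-to indices are unique so the GPU backend is free to emit
// plain (non-atomic) stores.
xla::XlaOp BuildUniqueScatter(const xla::XlaComputation& update_s32,
                              xla::XlaOp operand, xla::XlaOp indices,
                              xla::XlaOp updates) {
  xla::ScatterDimensionNumbers dnums;
  dnums.add_update_window_dims(1);
  dnums.add_inserted_window_dims(0);
  dnums.add_scatter_dims_to_operand_dims(0);
  dnums.set_index_vector_dim(1);
  // If this promise is broken (duplicate indices with unique_indices=true),
  // the result is implementation defined, per the doc change above.
  return xla::Scatter(operand, indices, updates, update_s32, dnums,
                      /*indices_are_sorted=*/false,
                      /*unique_indices=*/true);
}
```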