diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index b0c221dc954..51bee21df4e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -202,6 +202,31 @@ StatusOr GetAllocationSliceForMlir( "StaticMemRefCastOp(ViewOp(arg))"); } +StatusOr> GetMlirBufferSlices( + mlir::Operation* op, mlir::OperandRange operands, + absl::Span allocations) { + const auto buffer_is_written = [op](mlir::Value operand) { + llvm::SmallVector effects; + mlir::cast(op).getEffectsOnValue(operand, + effects); + return absl::c_any_of( + effects, [](const mlir::MemoryEffects::EffectInstance& instance) { + return mlir::isa(instance.getEffect()); + }); + }; + + std::vector slices; + for (mlir::Value operand : operands) { + slices.emplace_back(); + auto& slice = slices.back(); + TF_ASSIGN_OR_RETURN(slice.buffer_slice, + GetAllocationSliceForMlir(operand, allocations)); + slice.written = buffer_is_written(operand); + slice.shape = TypeToShape(operand.getType()); + } + return slices; +} + } // namespace IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config, @@ -1371,47 +1396,30 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { return EmitSortFromMlir(result); } -Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) { +Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) { absl::Span allocations( ir_emitter_context_->buffer_assignment().Allocations()); - auto sort_op = mlir::cast(input.op); + auto sort_op = mlir::cast(mlir_input.op); std::string name = mlir::GetNameFromLoc(sort_op.getLoc()); - - int operand_count = sort_op.operands().size(); - std::vector operand_shapes(operand_count); - std::vector slices; - std::vector output_shapes(sort_op.output().size()); - - for (int i = 0; i < operand_count; i++) { - operand_shapes[i] = TypeToShape(sort_op.operands()[i].getType()); - } - - // Craft n + 1 slices, where the first n are output parameters, and the last - // is the on-device tuple storage. We don't need n operands because sorting - // kernels are always in-place. - for (int i = 0; i < operand_count; i++) { - output_shapes[i] = TypeToShape(sort_op.output()[i].getType()); - MlirBufferSlice slice; - TF_ASSIGN_OR_RETURN( - slice.buffer_slice, - GetAllocationSliceForMlir(sort_op.output()[i], allocations)); - slice.written = true; - slice.shape = operand_shapes[i]; - slices.push_back(slice); - } - slices.push_back(input.extra_slice); + TF_ASSIGN_OR_RETURN( + std::vector operands, + GetMlirBufferSlices(sort_op, sort_op.operands(), allocations)); + TF_ASSIGN_OR_RETURN( + std::vector outputs, + GetMlirBufferSlices(sort_op, sort_op.output(), allocations)); + outputs.push_back(mlir_input.extra_slice); std::vector> thunks; - Shape keys_shape = operand_shapes[0]; + Shape keys_shape = operands[0].shape; int64 dimension_to_sort = sort_op.dimension(); - for (int64 i = 0; i < operand_count; ++i) { + for (int64 i = 0; i < operands.size(); ++i) { // We assume that the layout of all involved operands and outputs is the // same. TF_RET_CHECK( - LayoutUtil::LayoutsInShapesEqual(keys_shape, operand_shapes[i])); + LayoutUtil::LayoutsInShapesEqual(keys_shape, operands[i].shape)); TF_RET_CHECK( - LayoutUtil::LayoutsInShapesEqual(keys_shape, output_shapes[i])); + LayoutUtil::LayoutsInShapesEqual(keys_shape, outputs[i].shape)); // If possible, we share buffers. If that is not possible, we need to copy // the values, because the emitter does the sorting in-place. @@ -1429,7 +1437,7 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) { Thunk::ThunkInfo(), /*source_address=*/source_address, /*destination_buffer=*/destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(operand_shapes[i]))); + /*mem_size=*/ShapeUtil::ByteSizeOf(operands[i].shape))); } } @@ -1499,10 +1507,10 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) { // we have not enough threads, or not enough shared memory. Also it does not // give a speedup if the tile size is < 128. int64 total_shared_memory_needed = 0; - for (int64 i = 0; i < operand_count; ++i) { + for (int64 i = 0; i < operands.size(); ++i) { total_shared_memory_needed += kTileSize * - ShapeUtil::ByteSizeOfPrimitiveType(operand_shapes[i].element_type()); + ShapeUtil::ByteSizeOfPrimitiveType(operands[i].shape.element_type()); } bool no_tiling = kTileSize < 128 || @@ -1533,15 +1541,15 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) { absl::StrAppendFormat(out, "0x%x", xor_mask); })); thunks.push_back( - BuildKernelThunkForMlir(name, Thunk::ThunkInfo(), slices, &ir_arrays)); + BuildKernelThunkForMlir(name, Thunk::ThunkInfo(), outputs, &ir_arrays)); LaunchDimensions launch_dimensions = xor_masks.size() > 1 ? tiled_launch_dimensions : standard_launch_dimensions; UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), ir_emitter_context_->llvm_module()); std::vector values_arrays; - values_arrays.reserve(operand_count); - for (int64 i = 0; i < operand_count; ++i) { + values_arrays.reserve(operands.size()); + for (int64 i = 0; i < operands.size(); ++i) { values_arrays.push_back(ir_arrays[i]); } TF_ASSIGN_OR_RETURN( @@ -1583,14 +1591,14 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) { VLOG(2) << absl::StreamFormat( "%s requires %d thunks (including any D2D copies)", name, thunks.size()); - AddThunkToThunkSequence( - absl::make_unique(input.thunk_info, std::move(thunks))); - if (operand_count > 1) { + AddThunkToThunkSequence(absl::make_unique( + mlir_input.thunk_info, std::move(thunks))); + if (operands.size() > 1) { // Emit the tuple as part of the last stage of sorting. // We are currently in the block sorted.in_bounds.after. b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); llvm_ir::EmitTuple( - ir_arrays[operand_count], + ir_arrays.back(), absl::MakeSpan(ir_arrays).subspan(0, ir_arrays.size() - 1), &b_); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index a317aac16ec..5cc5e206167 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -160,7 +160,7 @@ class IrEmitterUnnested : public IrEmitter, Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; Status HandleSort(HloInstruction* sort) override; - Status EmitSortFromMlir(MlirEmitterInput input); + Status EmitSortFromMlir(MlirEmitterInput mlir_input); Status HandleTriangularSolve(HloInstruction* hlo) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleAllReduce(HloInstruction* crs) override;