From e0886adfed9865f766e6c14b40a87e0ab19c5141 Mon Sep 17 00:00:00 2001
From: Rahul Joshi
Date: Fri, 8 Jan 2021 09:32:34 -0800
Subject: [PATCH] [XLA:GPU] Simplify outfeed thunk to accept leaf slices.

- Since MLIR will not have tuples, change the outfeed thunk to accept just
  the leaf slices.

PiperOrigin-RevId: 350778961
Change-Id: I8fd9520a314848d985a2f423d9effc146f3e75a0
---
 .../compiler/xla/service/gpu/infeed_thunk.h   |  7 --
 .../xla/service/gpu/ir_emitter_unnested.cc    |  8 +--
 .../compiler/xla/service/gpu/outfeed_thunk.cc | 67 ++++++++++---------
 .../compiler/xla/service/gpu/outfeed_thunk.h  |  6 +-
 tensorflow/compiler/xla/service/gpu/thunk.h   |  7 ++
 .../compiler/xla/service/gpu/thunk_emitter.cc | 28 +++++---
 6 files changed, 65 insertions(+), 58 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h
index afac883df67..6994bd5e54a 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h
@@ -32,13 +32,6 @@ namespace gpu {
 // to the buffer allocated for the infeed op.
 class InfeedThunk : public Thunk {
  public:
-  // A struct that defines a shaped slice, i.e., a BufferAllocation::Slice and
-  // its shape.
-  struct ShapedSlice {
-    BufferAllocation::Slice slice;
-    Shape shape;
-  };
-
   // Constructs a InfeedThunk that copies data from the on-device
   // infeed queue into the buffers in the given shape tree.
   InfeedThunk(ThunkInfo thunk_info, std::vector<ShapedSlice>&& dest_slices);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index b26b241046f..5591e1c0170 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -3175,13 +3175,13 @@ Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {
 
   auto infeed_op = mlir::dyn_cast<mlir::lmhlo::InfeedOp>(input.op);
 
-  std::vector<InfeedThunk::ShapedSlice> dest_slices;
+  std::vector<ShapedSlice> dest_slices;
   dest_slices.reserve(infeed_op.outputs().size());
 
   for (mlir::Value output : infeed_op.outputs()) {
     TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(output));
     const Shape& shape = TypeToShape(output.getType());
-    dest_slices.push_back(InfeedThunk::ShapedSlice{slice, shape});
+    dest_slices.push_back(ShapedSlice{slice, shape});
   }
 
   AddThunkToThunkSequence(
@@ -4325,8 +4325,8 @@ void IrEmitterUnnested::EmitPrologueForReduction(
     }
     const HloInstruction* init_value = reduce_hlo->operand(1);
-    init_ir_value = (*fused_emitter->GetGenerator(init_value))(
-                        IrArray::Index(b_.getInt32Ty()))
+    init_ir_value = (*fused_emitter->GetGenerator(
+        init_value))(IrArray::Index(b_.getInt32Ty()))
                         .ValueOrDie();
   } else {
     init_ir_value = operand_ir_arrays[1].EmitReadArrayElement(
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
index 40a00748273..dd1743b557e 100644
--- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
@@ -31,11 +31,11 @@ OutfeedConfig GetOutfeedConfig(const HloInstruction* instr) {
   return config;
 }
 
-OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config,
-                           ShapeTree<BufferAllocation::Slice> outfeed_slices)
+OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig config,
+                           std::vector<ShapedSlice> source_slices)
     : Thunk(Kind::kOutfeed, thunk_info),
       config_(std::move(config)),
-      outfeed_slices_(std::move(outfeed_slices)) {}
+      source_slices_(std::move(source_slices)) {}
 
 Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
   auto& stream = *params.stream;
@@ -43,43 +43,44 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
 
   VLOG(2) << "Outfeeding from GPU";
 
-  auto op_profiler =
-      params.profiler->MakeScopedInstructionProfiler(profile_index());
-  OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
-  ShapeTree<std::unique_ptr<OutfeedBuffer>>* outfeed_buffers =
-      outfeed_manager->BlockingGetNextDestination();
-
   // Nothing to be done for empty tuples.
   if (ShapeUtil::IsEmptyTuple(config_.input_shape)) {
     return Status::OK();
   }
-  CHECK(ShapeUtil::Compatible(config_.input_shape, outfeed_buffers->shape()))
-      << "XLA program outfeed request of shape "
-      << config_.input_shape.ToString()
-      << " did not match the runtime's outfeed buffer of shape "
-      << outfeed_buffers->shape().ToString();
-  TF_RETURN_IF_ERROR(outfeed_buffers->ForEachMutableElementWithStatus(
-      [&](const ShapeIndex& index, std::unique_ptr<OutfeedBuffer>* buffer) {
-        if (!*buffer) {  // Tuple pointers.
-          return Status::OK();
-        }
+  auto op_profiler =
+      params.profiler->MakeScopedInstructionProfiler(profile_index());
+  OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
+  ShapeTree<std::unique_ptr<OutfeedBuffer>>* output_buffers =
+      outfeed_manager->BlockingGetNextDestination();
 
-        BufferAllocation::Slice slice = outfeed_slices_.element(index);
-        if (!slice.allocation())
-          return InternalError("outfeed input missing buffer allocation");
-        se::DeviceMemoryBase data_address =
-            buffer_allocations.GetDeviceAddress(slice);
+  size_t index = 0;
+  for (auto& output : output_buffers->leaves()) {
+    // Assert that the shapes are compatible.
+    const ShapeIndex& shape_index = output.first;
+    std::unique_ptr<OutfeedBuffer>& buffer = output.second;
+    const Shape& output_shape =
+        ShapeUtil::GetSubshape(output_buffers->shape(), shape_index);
+    TF_RET_CHECK(ShapeUtil::Equal(source_slices_[index].shape, output_shape))
+        << "Mismatch between outfeed output buffer shape "
+        << ShapeUtil::HumanStringWithLayout(output_shape)
+        << " and outfeed source buffer shape "
+        << ShapeUtil::HumanStringWithLayout(source_slices_[index].shape);
 
-        // TODO(b/111309141): Run this on a separate stream so it doesn't block
-        // the GPU from doing work during the transfer. This could be handled by
-        // making StreamAssignment do something intelligent with outfeed thunks.
-        stream
-            .ThenMemcpy((*buffer)->destination()->untyped_data(), data_address,
-                        (*buffer)->length())
-            .ThenDoHostCallback([buffer]() { (*buffer)->Done(); });
-        return Status::OK();
-      }));
+    BufferAllocation::Slice source_slice = source_slices_[index++].slice;
+    if (!source_slice.allocation())
+      return InternalError("outfeed source missing buffer allocation");
+    se::DeviceMemoryBase data_address =
+        buffer_allocations.GetDeviceAddress(source_slice);
+
+    // TODO(b/111309141): Run this on a separate stream so it doesn't block
+    // the GPU from doing work during the transfer. This could be handled by
+    // making StreamAssignment do something intelligent with outfeed thunks.
+    stream
+        .ThenMemcpy(buffer->destination()->untyped_data(), data_address,
+                    buffer->length())
+        .ThenDoHostCallback([&buffer]() { buffer->Done(); });
+  }
 
   Status block_status = stream.BlockHostUntilDone();
   if (!block_status.ok()) {
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h
index eec336407cd..0acfb9bc9a4 100644
--- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h
@@ -38,8 +38,8 @@ class OutfeedThunk : public Thunk {
  public:
   // Constructs a OutfeedThunk that copies data to the host-side
   // outfeed queue from the buffers in the given shape tree.
-  OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config,
-               ShapeTree<BufferAllocation::Slice> outfeed_slices);
+  OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig config,
+               std::vector<ShapedSlice> source_slices);
 
   OutfeedThunk(const OutfeedThunk&) = delete;
   OutfeedThunk& operator=(const OutfeedThunk&) = delete;
@@ -48,7 +48,7 @@ class OutfeedThunk : public Thunk {
 
  private:
   const OutfeedConfig config_;
-  const ShapeTree<BufferAllocation::Slice> outfeed_slices_;
+  const std::vector<ShapedSlice> source_slices_;
 };
 
 }  // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index ed79f1c45f6..9166ee1878a 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -148,6 +148,13 @@ using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;
 absl::string_view ThunkKindToString(Thunk::Kind);
 std::ostream& operator<<(std::ostream& os, Thunk::Kind kind);
 
+// A struct that defines a shaped slice, i.e., a BufferAllocation::Slice and its
+// shape.
+struct ShapedSlice {
+  BufferAllocation::Slice slice;
+  Shape shape;
+};
+
 }  // namespace gpu
 }  // namespace xla
 
diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
index 7fbfeff8b89..6fb667c5d9c 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
@@ -43,7 +43,6 @@ limitations under the License.
 
 namespace xla {
 namespace gpu {
-
 std::unique_ptr<Thunk> ThunkEmitter::BuildFftThunk(const HloInstruction* inst) {
   const HloInstruction* operand = inst->operand(0);
   return absl::make_unique<FftThunk>(
@@ -119,24 +118,31 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
     const HloInstruction* inst) {
   CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());
 
-  ShapeTree<BufferAllocation::Slice> slices(inst->operand(0)->shape());
-  slices.ForEachMutableElement([&](const ShapeIndex& index,
-                                   BufferAllocation::Slice* slice) {
-    auto status_or_slice = MaybeGetAllocationSlice(*inst->operand(0), index);
-    if (status_or_slice.ok()) {
-      *slice = status_or_slice.ValueOrDie();
-    }
-  });
+  const HloInstruction* source = inst->operand(0);
+  std::vector<ShapeUtil::IndexedShape> leaf_shapes =
+      ShapeUtil::GetLeafShapes(source->shape());
+
+  std::vector<ShapedSlice> source_slices;
+  source_slices.reserve(leaf_shapes.size());
+
+  for (ShapeUtil::IndexedShape& indexed_shape : leaf_shapes) {
+    BufferAllocation::Slice slice =
+        GetAllocationSlice(*source, indexed_shape.index);
+    const Shape& shape =
+        ShapeUtil::GetSubshape(source->shape(), indexed_shape.index);
+    source_slices.push_back(ShapedSlice{slice, shape});
+  }
+
   OutfeedConfig config = GetOutfeedConfig(inst);
   return absl::make_unique<OutfeedThunk>(context_->GetThunkInfo(inst),
-                                         std::move(config), std::move(slices));
+                                         std::move(config),
+                                         std::move(source_slices));
 }
 
 Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
   // A CustomCall on the GPU backend can either be a custom-call to a
   // user-supplied kernel, or a call into a library like cudnn.
-
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
   if (void* call_target = CustomCallTargetRegistry::Global()->Lookup(
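
Note: the core of this change is flattening a (possibly nested) tuple shape
into its array leaves, each addressed by a ShapeIndex, so the thunk can carry
a flat vector instead of a ShapeTree. Below is a minimal standalone sketch of
that flattening, assuming simplified stand-in types; Shape, ShapeIndex, and
ShapedLeaf here are illustrative mock-ups, not the real XLA classes, which
get the same result from ShapeUtil::GetLeafShapes.

#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-ins for xla::Shape and xla::ShapeIndex, just enough to
// make the sketch self-contained.
struct Shape {
  bool is_tuple = false;
  std::string array_repr;       // e.g. "f32[4,8]" for an array leaf
  std::vector<Shape> elements;  // populated only for tuple shapes
};
using ShapeIndex = std::vector<int64_t>;

// Analogue of gpu::ShapedSlice: one entry per array leaf of the source.
struct ShapedLeaf {
  ShapeIndex index;  // path of tuple-element positions to the leaf
  Shape shape;       // the leaf's array shape
};

// Depth-first walk, mirroring what ShapeUtil::GetLeafShapes does for real
// shapes: tuples recurse, array leaves are appended in traversal order.
void CollectLeaves(const Shape& shape, ShapeIndex* prefix,
                   std::vector<ShapedLeaf>* out) {
  if (!shape.is_tuple) {
    out->push_back(ShapedLeaf{*prefix, shape});
    return;
  }
  for (size_t i = 0; i < shape.elements.size(); ++i) {
    prefix->push_back(static_cast<int64_t>(i));
    CollectLeaves(shape.elements[i], prefix, out);
    prefix->pop_back();
  }
}

int main() {
  // ((f32[4,8], s32[2]), f32[16]) -- a nested tuple with three array leaves.
  Shape shape{true,
              "",
              {Shape{true, "", {Shape{false, "f32[4,8]", {}},
                                Shape{false, "s32[2]", {}}}},
               Shape{false, "f32[16]", {}}}};
  ShapeIndex prefix;
  std::vector<ShapedLeaf> leaves;
  CollectLeaves(shape, &prefix, &leaves);
  // Leaves come out at indices {0,0}, {0,1}, {1} -- a flat depth-first order.
  return leaves.size() == 3 ? 0 : 1;
}

Because ShapeUtil::GetLeafShapes in BuildOutfeedThunk and
output_buffers->leaves() in OutfeedThunk::ExecuteOnStream traverse leaves in
the same order, the single index counter in ExecuteOnStream keeps source
slices and destination buffers in sync without any tuple bookkeeping.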