From 0b83955574fd77b4960e270f4b1d6373d2bec504 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Tue, 5 Jan 2021 10:30:20 -0800 Subject: [PATCH] [XLA:GPU] Migrate Infeed thunk emission to MLIR PiperOrigin-RevId: 350169810 Change-Id: I0ce19d7b32b19bf314e9acba1922fcf4700b6545 --- .../xla/service/gpu/ir_emitter_unnested.cc | 18 +++++++++- .../compiler/xla/service/gpu/thunk_emitter.cc | 33 ------------------- .../compiler/xla/service/gpu/thunk_emitter.h | 1 - 3 files changed, 17 insertions(+), 35 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 21967794dd6..c860e565baf 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -71,6 +71,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h" @@ -3168,7 +3169,22 @@ Status IrEmitterUnnested::HandleAllToAll(HloInstruction* hlo) { } Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) { - return ThunkEmitter(this).HandleInfeed(xla_infeed); + TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(xla_infeed)); + + auto infeed_op = mlir::dyn_cast(input.op); + + std::vector dest_slices; + dest_slices.reserve(infeed_op.outputs().size()); + + for (mlir::Value output : infeed_op.outputs()) { + TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(output)); + const Shape& shape = TypeToShape(output.getType()); + dest_slices.push_back(InfeedThunk::ShapedSlice{slice, shape}); + } + + AddThunkToThunkSequence( + absl::make_unique(input.thunk_info, std::move(dest_slices))); + return Status::OK(); } Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) { diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc index e6d6e0d3be8..7fbfeff8b89 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc @@ -115,34 +115,6 @@ std::unique_ptr ThunkEmitter::BuildGemmThunk( /*implements_whole_instruction=*/true); } -std::unique_ptr ThunkEmitter::BuildInfeedThunk( - const HloInstruction* inst) { - CHECK_EQ(HloOpcode::kInfeed, inst->opcode()); - - std::vector leaf_shapes = - ShapeUtil::GetLeafShapes(inst->shape()); - - // For an infeed HLO, the output is a 2 element tuple where the first element - // of the tuple is all the infeed buffers and the second element is a token. - // The infeed thunk does not need to handle this token output, so just drop - // it. - leaf_shapes.pop_back(); - - std::vector dest_slices; - dest_slices.reserve(leaf_shapes.size()); - - for (ShapeUtil::IndexedShape& indexed_shape : leaf_shapes) { - BufferAllocation::Slice slice = - GetAllocationSlice(*inst, indexed_shape.index); - const Shape& shape = - ShapeUtil::GetSubshape(inst->shape(), indexed_shape.index); - dest_slices.emplace_back(InfeedThunk::ShapedSlice{slice, shape}); - } - - return absl::make_unique(context_->GetThunkInfo(inst), - std::move(dest_slices)); -} - std::unique_ptr ThunkEmitter::BuildOutfeedThunk( const HloInstruction* inst) { CHECK_EQ(HloOpcode::kOutfeed, inst->opcode()); @@ -258,11 +230,6 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) { return Status::OK(); } -Status ThunkEmitter::HandleInfeed(HloInstruction* infeed) { - AddThunkToThunkSequence(BuildInfeedThunk(infeed)); - return Status::OK(); -} - Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) { AddThunkToThunkSequence(BuildOutfeedThunk(outfeed)); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.h b/tensorflow/compiler/xla/service/gpu/thunk_emitter.h index 16b11a4d5e2..13056819bfe 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.h @@ -46,7 +46,6 @@ class ThunkEmitter { Status HandleCustomCall(HloInstruction* custom_call); Status HandleFft(HloInstruction* fft); Status HandleTriangularSolve(HloInstruction* hlo); - Status HandleInfeed(HloInstruction* xla_infeed); Status HandleOutfeed(HloInstruction* outfeed); private: