[XLA:GPU] Migrate Infeed thunk emission to MLIR

PiperOrigin-RevId: 350169810
Change-Id: I0ce19d7b32b19bf314e9acba1922fcf4700b6545
This commit is contained in:
Rahul Joshi 2021-01-05 10:30:20 -08:00 committed by TensorFlower Gardener
parent 4788116863
commit 0b83955574
3 changed files with 17 additions and 35 deletions

View File

@ -71,6 +71,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h" #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
@ -3168,7 +3169,22 @@ Status IrEmitterUnnested::HandleAllToAll(HloInstruction* hlo) {
} }
Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) { Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {
return ThunkEmitter(this).HandleInfeed(xla_infeed); TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(xla_infeed));
auto infeed_op = mlir::dyn_cast<mlir::lmhlo::InfeedOp>(input.op);
std::vector<InfeedThunk::ShapedSlice> dest_slices;
dest_slices.reserve(infeed_op.outputs().size());
for (mlir::Value output : infeed_op.outputs()) {
TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(output));
const Shape& shape = TypeToShape(output.getType());
dest_slices.push_back(InfeedThunk::ShapedSlice{slice, shape});
}
AddThunkToThunkSequence(
absl::make_unique<InfeedThunk>(input.thunk_info, std::move(dest_slices)));
return Status::OK();
} }
Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) { Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {

View File

@ -115,34 +115,6 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
/*implements_whole_instruction=*/true); /*implements_whole_instruction=*/true);
} }
std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk(
const HloInstruction* inst) {
CHECK_EQ(HloOpcode::kInfeed, inst->opcode());
std::vector<ShapeUtil::IndexedShape> leaf_shapes =
ShapeUtil::GetLeafShapes(inst->shape());
// For an infeed HLO, the output is a 2 element tuple where the first element
// of the tuple is all the infeed buffers and the second element is a token.
// The infeed thunk does not need to handle this token output, so just drop
// it.
leaf_shapes.pop_back();
std::vector<InfeedThunk::ShapedSlice> dest_slices;
dest_slices.reserve(leaf_shapes.size());
for (ShapeUtil::IndexedShape& indexed_shape : leaf_shapes) {
BufferAllocation::Slice slice =
GetAllocationSlice(*inst, indexed_shape.index);
const Shape& shape =
ShapeUtil::GetSubshape(inst->shape(), indexed_shape.index);
dest_slices.emplace_back(InfeedThunk::ShapedSlice{slice, shape});
}
return absl::make_unique<InfeedThunk>(context_->GetThunkInfo(inst),
std::move(dest_slices));
}
std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk( std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
const HloInstruction* inst) { const HloInstruction* inst) {
CHECK_EQ(HloOpcode::kOutfeed, inst->opcode()); CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());
@ -258,11 +230,6 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
return Status::OK(); return Status::OK();
} }
Status ThunkEmitter::HandleInfeed(HloInstruction* infeed) {
AddThunkToThunkSequence(BuildInfeedThunk(infeed));
return Status::OK();
}
Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) { Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) {
AddThunkToThunkSequence(BuildOutfeedThunk(outfeed)); AddThunkToThunkSequence(BuildOutfeedThunk(outfeed));
return Status::OK(); return Status::OK();

View File

@ -46,7 +46,6 @@ class ThunkEmitter {
Status HandleCustomCall(HloInstruction* custom_call); Status HandleCustomCall(HloInstruction* custom_call);
Status HandleFft(HloInstruction* fft); Status HandleFft(HloInstruction* fft);
Status HandleTriangularSolve(HloInstruction* hlo); Status HandleTriangularSolve(HloInstruction* hlo);
Status HandleInfeed(HloInstruction* xla_infeed);
Status HandleOutfeed(HloInstruction* outfeed); Status HandleOutfeed(HloInstruction* outfeed);
private: private: