[XLA:GPU] Migrate Infeed thunk emission to MLIR
PiperOrigin-RevId: 350169810 Change-Id: I0ce19d7b32b19bf314e9acba1922fcf4700b6545
This commit is contained in:
parent
4788116863
commit
0b83955574
@ -71,6 +71,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
|
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
|
||||||
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
|
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
|
||||||
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
|
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
|
||||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||||
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
|
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
|
||||||
#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
|
#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
|
||||||
@ -3168,7 +3169,22 @@ Status IrEmitterUnnested::HandleAllToAll(HloInstruction* hlo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {
|
Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {
|
||||||
return ThunkEmitter(this).HandleInfeed(xla_infeed);
|
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(xla_infeed));
|
||||||
|
|
||||||
|
auto infeed_op = mlir::dyn_cast<mlir::lmhlo::InfeedOp>(input.op);
|
||||||
|
|
||||||
|
std::vector<InfeedThunk::ShapedSlice> dest_slices;
|
||||||
|
dest_slices.reserve(infeed_op.outputs().size());
|
||||||
|
|
||||||
|
for (mlir::Value output : infeed_op.outputs()) {
|
||||||
|
TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(output));
|
||||||
|
const Shape& shape = TypeToShape(output.getType());
|
||||||
|
dest_slices.push_back(InfeedThunk::ShapedSlice{slice, shape});
|
||||||
|
}
|
||||||
|
|
||||||
|
AddThunkToThunkSequence(
|
||||||
|
absl::make_unique<InfeedThunk>(input.thunk_info, std::move(dest_slices)));
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {
|
Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {
|
||||||
|
@ -115,34 +115,6 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
|
|||||||
/*implements_whole_instruction=*/true);
|
/*implements_whole_instruction=*/true);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Thunk> ThunkEmitter::BuildInfeedThunk(
|
|
||||||
const HloInstruction* inst) {
|
|
||||||
CHECK_EQ(HloOpcode::kInfeed, inst->opcode());
|
|
||||||
|
|
||||||
std::vector<ShapeUtil::IndexedShape> leaf_shapes =
|
|
||||||
ShapeUtil::GetLeafShapes(inst->shape());
|
|
||||||
|
|
||||||
// For an infeed HLO, the output is a 2 element tuple where the first element
|
|
||||||
// of the tuple is all the infeed buffers and the second element is a token.
|
|
||||||
// The infeed thunk does not need to handle this token output, so just drop
|
|
||||||
// it.
|
|
||||||
leaf_shapes.pop_back();
|
|
||||||
|
|
||||||
std::vector<InfeedThunk::ShapedSlice> dest_slices;
|
|
||||||
dest_slices.reserve(leaf_shapes.size());
|
|
||||||
|
|
||||||
for (ShapeUtil::IndexedShape& indexed_shape : leaf_shapes) {
|
|
||||||
BufferAllocation::Slice slice =
|
|
||||||
GetAllocationSlice(*inst, indexed_shape.index);
|
|
||||||
const Shape& shape =
|
|
||||||
ShapeUtil::GetSubshape(inst->shape(), indexed_shape.index);
|
|
||||||
dest_slices.emplace_back(InfeedThunk::ShapedSlice{slice, shape});
|
|
||||||
}
|
|
||||||
|
|
||||||
return absl::make_unique<InfeedThunk>(context_->GetThunkInfo(inst),
|
|
||||||
std::move(dest_slices));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
|
std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
|
||||||
const HloInstruction* inst) {
|
const HloInstruction* inst) {
|
||||||
CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());
|
CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());
|
||||||
@ -258,11 +230,6 @@ Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ThunkEmitter::HandleInfeed(HloInstruction* infeed) {
|
|
||||||
AddThunkToThunkSequence(BuildInfeedThunk(infeed));
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) {
|
Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) {
|
||||||
AddThunkToThunkSequence(BuildOutfeedThunk(outfeed));
|
AddThunkToThunkSequence(BuildOutfeedThunk(outfeed));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -46,7 +46,6 @@ class ThunkEmitter {
|
|||||||
Status HandleCustomCall(HloInstruction* custom_call);
|
Status HandleCustomCall(HloInstruction* custom_call);
|
||||||
Status HandleFft(HloInstruction* fft);
|
Status HandleFft(HloInstruction* fft);
|
||||||
Status HandleTriangularSolve(HloInstruction* hlo);
|
Status HandleTriangularSolve(HloInstruction* hlo);
|
||||||
Status HandleInfeed(HloInstruction* xla_infeed);
|
|
||||||
Status HandleOutfeed(HloInstruction* outfeed);
|
Status HandleOutfeed(HloInstruction* outfeed);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
Loading…
Reference in New Issue
Block a user