[XLA:GPU] Simplify outfeed thunk to accept leaf slices.
- Since MLIR will not have tuples, change outfeed tuple to accept just the leaf slices. PiperOrigin-RevId: 350778961 Change-Id: I8fd9520a314848d985a2f423d9effc146f3e75a0
This commit is contained in:
parent
adc7dadc74
commit
e0886adfed
@ -32,13 +32,6 @@ namespace gpu {
|
||||
// to the buffer allocated for the infeed op.
|
||||
class InfeedThunk : public Thunk {
|
||||
public:
|
||||
// A struct that defines a shaped slice, i.e., a BufferAllocation::Slice and
|
||||
// its shape.
|
||||
struct ShapedSlice {
|
||||
BufferAllocation::Slice slice;
|
||||
Shape shape;
|
||||
};
|
||||
|
||||
// Constructs a InfeedThunk that copies data from the on-device
|
||||
// infeed queue into the buffers in the given shape tree.
|
||||
InfeedThunk(ThunkInfo thunk_info, std::vector<ShapedSlice>&& dest_slices);
|
||||
|
@ -3175,13 +3175,13 @@ Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {
|
||||
|
||||
auto infeed_op = mlir::dyn_cast<mlir::lmhlo::InfeedOp>(input.op);
|
||||
|
||||
std::vector<InfeedThunk::ShapedSlice> dest_slices;
|
||||
std::vector<ShapedSlice> dest_slices;
|
||||
dest_slices.reserve(infeed_op.outputs().size());
|
||||
|
||||
for (mlir::Value output : infeed_op.outputs()) {
|
||||
TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(output));
|
||||
const Shape& shape = TypeToShape(output.getType());
|
||||
dest_slices.push_back(InfeedThunk::ShapedSlice{slice, shape});
|
||||
dest_slices.push_back(ShapedSlice{slice, shape});
|
||||
}
|
||||
|
||||
AddThunkToThunkSequence(
|
||||
@ -4325,8 +4325,8 @@ void IrEmitterUnnested::EmitPrologueForReduction(
|
||||
}
|
||||
const HloInstruction* init_value = reduce_hlo->operand(1);
|
||||
|
||||
init_ir_value = (*fused_emitter->GetGenerator(init_value))(
|
||||
IrArray::Index(b_.getInt32Ty()))
|
||||
init_ir_value = (*fused_emitter->GetGenerator(
|
||||
init_value))(IrArray::Index(b_.getInt32Ty()))
|
||||
.ValueOrDie();
|
||||
} else {
|
||||
init_ir_value = operand_ir_arrays[1].EmitReadArrayElement(
|
||||
|
@ -31,11 +31,11 @@ OutfeedConfig GetOutfeedConfig(const HloInstruction* instr) {
|
||||
return config;
|
||||
}
|
||||
|
||||
OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config,
|
||||
ShapeTree<BufferAllocation::Slice> outfeed_slices)
|
||||
OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig config,
|
||||
std::vector<ShapedSlice> source_slices)
|
||||
: Thunk(Kind::kOutfeed, thunk_info),
|
||||
config_(std::move(config)),
|
||||
outfeed_slices_(std::move(outfeed_slices)) {}
|
||||
source_slices_(std::move(source_slices)) {}
|
||||
|
||||
Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
auto& stream = *params.stream;
|
||||
@ -43,43 +43,44 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
|
||||
VLOG(2) << "Outfeeding from GPU";
|
||||
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
|
||||
ShapeTree<std::unique_ptr<OutfeedBuffer>>* outfeed_buffers =
|
||||
outfeed_manager->BlockingGetNextDestination();
|
||||
|
||||
// Nothing to be done for empty tuples.
|
||||
if (ShapeUtil::IsEmptyTuple(config_.input_shape)) {
|
||||
return Status::OK();
|
||||
}
|
||||
CHECK(ShapeUtil::Compatible(config_.input_shape, outfeed_buffers->shape()))
|
||||
<< "XLA program outfeed request of shape "
|
||||
<< config_.input_shape.ToString()
|
||||
<< " did not match the runtime's outfeed buffer of shape "
|
||||
<< outfeed_buffers->shape().ToString();
|
||||
|
||||
TF_RETURN_IF_ERROR(outfeed_buffers->ForEachMutableElementWithStatus(
|
||||
[&](const ShapeIndex& index, std::unique_ptr<OutfeedBuffer>* buffer) {
|
||||
if (!*buffer) { // Tuple pointers.
|
||||
return Status::OK();
|
||||
}
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
|
||||
ShapeTree<std::unique_ptr<OutfeedBuffer>>* output_buffers =
|
||||
outfeed_manager->BlockingGetNextDestination();
|
||||
|
||||
BufferAllocation::Slice slice = outfeed_slices_.element(index);
|
||||
if (!slice.allocation())
|
||||
return InternalError("outfeed input missing buffer allocation");
|
||||
se::DeviceMemoryBase data_address =
|
||||
buffer_allocations.GetDeviceAddress(slice);
|
||||
size_t index = 0;
|
||||
for (auto& output : output_buffers->leaves()) {
|
||||
// Assert that the shapes are compatible.
|
||||
const ShapeIndex& shape_index = output.first;
|
||||
std::unique_ptr<OutfeedBuffer>& buffer = output.second;
|
||||
const Shape& output_shape =
|
||||
ShapeUtil::GetSubshape(output_buffers->shape(), shape_index);
|
||||
TF_RET_CHECK(ShapeUtil::Equal(source_slices_[index].shape, output_shape))
|
||||
<< "Mismatch between outfeed output buffer shape "
|
||||
<< ShapeUtil::HumanStringWithLayout(output_shape)
|
||||
<< " and outfeed source buffer shape "
|
||||
<< ShapeUtil::HumanStringWithLayout(source_slices_[index].shape);
|
||||
|
||||
// TODO(b/111309141): Run this on a separate stream so it doesn't block
|
||||
// the GPU from doing work during the transfer. This could be handled by
|
||||
// making StreamAssignment do something intelligent with outfeed thunks.
|
||||
stream
|
||||
.ThenMemcpy((*buffer)->destination()->untyped_data(), data_address,
|
||||
(*buffer)->length())
|
||||
.ThenDoHostCallback([buffer]() { (*buffer)->Done(); });
|
||||
return Status::OK();
|
||||
}));
|
||||
BufferAllocation::Slice source_slice = source_slices_[index++].slice;
|
||||
if (!source_slice.allocation())
|
||||
return InternalError("outfeed source missing buffer allocation");
|
||||
se::DeviceMemoryBase data_address =
|
||||
buffer_allocations.GetDeviceAddress(source_slice);
|
||||
|
||||
// TODO(b/111309141): Run this on a separate stream so it doesn't block
|
||||
// the GPU from doing work during the transfer. This could be handled by
|
||||
// making StreamAssignment do something intelligent with outfeed thunks.
|
||||
stream
|
||||
.ThenMemcpy(buffer->destination()->untyped_data(), data_address,
|
||||
buffer->length())
|
||||
.ThenDoHostCallback([&buffer]() { buffer->Done(); });
|
||||
}
|
||||
|
||||
Status block_status = stream.BlockHostUntilDone();
|
||||
if (!block_status.ok()) {
|
||||
|
@ -38,8 +38,8 @@ class OutfeedThunk : public Thunk {
|
||||
public:
|
||||
// Constructs a OutfeedThunk that copies data to the host-side
|
||||
// outfeed queue from the buffers in the given shape tree.
|
||||
OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config,
|
||||
ShapeTree<BufferAllocation::Slice> outfeed_slices);
|
||||
OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig config,
|
||||
std::vector<ShapedSlice> source_slices);
|
||||
|
||||
OutfeedThunk(const OutfeedThunk&) = delete;
|
||||
OutfeedThunk& operator=(const OutfeedThunk&) = delete;
|
||||
@ -48,7 +48,7 @@ class OutfeedThunk : public Thunk {
|
||||
|
||||
private:
|
||||
const OutfeedConfig config_;
|
||||
const ShapeTree<BufferAllocation::Slice> outfeed_slices_;
|
||||
const std::vector<ShapedSlice> source_slices_;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
|
@ -148,6 +148,13 @@ using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;
|
||||
absl::string_view ThunkKindToString(Thunk::Kind);
|
||||
std::ostream& operator<<(std::ostream& os, Thunk::Kind kind);
|
||||
|
||||
// A struct that defines a shaped slice, i.e., a BufferAllocation::Slice and its
|
||||
// shape.
|
||||
struct ShapedSlice {
|
||||
BufferAllocation::Slice slice;
|
||||
Shape shape;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
||||
|
@ -43,7 +43,6 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
|
||||
std::unique_ptr<Thunk> ThunkEmitter::BuildFftThunk(const HloInstruction* inst) {
|
||||
const HloInstruction* operand = inst->operand(0);
|
||||
return absl::make_unique<FftThunk>(
|
||||
@ -119,24 +118,31 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
|
||||
const HloInstruction* inst) {
|
||||
CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());
|
||||
|
||||
ShapeTree<BufferAllocation::Slice> slices(inst->operand(0)->shape());
|
||||
slices.ForEachMutableElement([&](const ShapeIndex& index,
|
||||
BufferAllocation::Slice* slice) {
|
||||
auto status_or_slice = MaybeGetAllocationSlice(*inst->operand(0), index);
|
||||
if (status_or_slice.ok()) {
|
||||
*slice = status_or_slice.ValueOrDie();
|
||||
}
|
||||
});
|
||||
const HloInstruction* source = inst->operand(0);
|
||||
std::vector<ShapeUtil::IndexedShape> leaf_shapes =
|
||||
ShapeUtil::GetLeafShapes(source->shape());
|
||||
|
||||
std::vector<ShapedSlice> source_slices;
|
||||
source_slices.reserve(leaf_shapes.size());
|
||||
|
||||
for (ShapeUtil::IndexedShape& indexed_shape : leaf_shapes) {
|
||||
BufferAllocation::Slice slice =
|
||||
GetAllocationSlice(*source, indexed_shape.index);
|
||||
const Shape& shape =
|
||||
ShapeUtil::GetSubshape(source->shape(), indexed_shape.index);
|
||||
source_slices.push_back(ShapedSlice{slice, shape});
|
||||
}
|
||||
|
||||
OutfeedConfig config = GetOutfeedConfig(inst);
|
||||
return absl::make_unique<OutfeedThunk>(context_->GetThunkInfo(inst),
|
||||
std::move(config), std::move(slices));
|
||||
std::move(config),
|
||||
std::move(source_slices));
|
||||
}
|
||||
|
||||
Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
|
||||
// A CustomCall on the GPU backend can either be a custom-call to a
|
||||
// user-supplied kernel, or a call into a library like cudnn.
|
||||
|
||||
|
||||
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
|
||||
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
|
||||
if (void* call_target = CustomCallTargetRegistry::Global()->Lookup(
|
||||
|
Loading…
Reference in New Issue
Block a user