[XLA:GPU] Simplify outfeed thunk to accept leaf slices.

- Since MLIR will not have tuples, change the outfeed thunk to accept just the
  leaf slices.

PiperOrigin-RevId: 350778961
Change-Id: I8fd9520a314848d985a2f423d9effc146f3e75a0
Rahul Joshi authored on 2021-01-08 09:32:34 -08:00 (committed by TensorFlower Gardener)
parent adc7dadc74
commit e0886adfed
6 changed files with 65 additions and 58 deletions
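
To illustrate the interface change described in the commit message: the outfeed (and infeed) thunks now take a flat std::vector<ShapedSlice>, one entry per leaf buffer, instead of a ShapeTree<BufferAllocation::Slice> mirroring the tuple shape. The sketch below is a minimal, standalone illustration of that data layout; Slice, Shape, and CollectLeafSlices are simplified stand-ins and a hypothetical helper, not the real XLA classes or API.

#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified stand-in for BufferAllocation::Slice: an (allocation, offset, size) triple.
struct Slice {
  int64_t allocation_index = -1;
  int64_t offset = 0;
  int64_t size = 0;
};

// Simplified stand-in for xla::Shape: either an array shape or a tuple of shapes.
struct Shape {
  std::vector<int64_t> dimensions;   // used when this is an array shape
  std::vector<Shape> tuple_shapes;   // non-empty only for tuple shapes
  bool IsTuple() const { return !tuple_shapes.empty(); }
};

// Mirrors the ShapedSlice struct this commit moves into thunk.h: one leaf
// buffer slice paired with the shape of the data stored in it.
struct ShapedSlice {
  Slice slice;
  Shape shape;
};

// Depth-first walk over a (possibly nested) tuple shape that emits one
// ShapedSlice per leaf, in leaf order; this is roughly what
// ThunkEmitter::BuildOutfeedThunk now does via ShapeUtil::GetLeafShapes
// and GetAllocationSlice.
void CollectLeafSlices(const Shape& shape,
                       const std::vector<Slice>& leaf_buffers,
                       size_t* next_leaf, std::vector<ShapedSlice>* out) {
  if (shape.IsTuple()) {
    for (const Shape& element : shape.tuple_shapes) {
      CollectLeafSlices(element, leaf_buffers, next_leaf, out);
    }
    return;
  }
  out->push_back(ShapedSlice{leaf_buffers[(*next_leaf)++], shape});
}

With this flattened form, OutfeedThunk::ExecuteOnStream only has to pair source_slices_[i] with the i-th leaf of the runtime's destination ShapeTree, which is the loop introduced in outfeed_thunk.cc below.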

File: tensorflow/compiler/xla/service/gpu/infeed_thunk.h

@@ -32,13 +32,6 @@ namespace gpu {
 // to the buffer allocated for the infeed op.
 class InfeedThunk : public Thunk {
  public:
-  // A struct that defines a shaped slice, i.e., a BufferAllocation::Slice and
-  // its shape.
-  struct ShapedSlice {
-    BufferAllocation::Slice slice;
-    Shape shape;
-  };
-
   // Constructs a InfeedThunk that copies data from the on-device
   // infeed queue into the buffers in the given shape tree.
   InfeedThunk(ThunkInfo thunk_info, std::vector<ShapedSlice>&& dest_slices);

File: tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc

@@ -3175,13 +3175,13 @@ Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {
   auto infeed_op = mlir::dyn_cast<mlir::lmhlo::InfeedOp>(input.op);
 
-  std::vector<InfeedThunk::ShapedSlice> dest_slices;
+  std::vector<ShapedSlice> dest_slices;
   dest_slices.reserve(infeed_op.outputs().size());
 
   for (mlir::Value output : infeed_op.outputs()) {
     TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(output));
     const Shape& shape = TypeToShape(output.getType());
-    dest_slices.push_back(InfeedThunk::ShapedSlice{slice, shape});
+    dest_slices.push_back(ShapedSlice{slice, shape});
   }
 
   AddThunkToThunkSequence(
@@ -4325,8 +4325,8 @@ void IrEmitterUnnested::EmitPrologueForReduction(
     }
     const HloInstruction* init_value = reduce_hlo->operand(1);
-    init_ir_value = (*fused_emitter->GetGenerator(init_value))(
-                        IrArray::Index(b_.getInt32Ty()))
+    init_ir_value = (*fused_emitter->GetGenerator(
+        init_value))(IrArray::Index(b_.getInt32Ty()))
                         .ValueOrDie();
   } else {
     init_ir_value = operand_ir_arrays[1].EmitReadArrayElement(

File: tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc

@@ -31,11 +31,11 @@ OutfeedConfig GetOutfeedConfig(const HloInstruction* instr) {
   return config;
 }
 
-OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config,
-                           ShapeTree<BufferAllocation::Slice> outfeed_slices)
+OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig config,
+                           std::vector<ShapedSlice> source_slices)
     : Thunk(Kind::kOutfeed, thunk_info),
       config_(std::move(config)),
-      outfeed_slices_(std::move(outfeed_slices)) {}
+      source_slices_(std::move(source_slices)) {}
 
 Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
   auto& stream = *params.stream;
@@ -43,43 +43,44 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
   VLOG(2) << "Outfeeding from GPU";
 
-  auto op_profiler =
-      params.profiler->MakeScopedInstructionProfiler(profile_index());
-  OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
-  ShapeTree<std::unique_ptr<OutfeedBuffer>>* outfeed_buffers =
-      outfeed_manager->BlockingGetNextDestination();
-
   // Nothing to be done for empty tuples.
   if (ShapeUtil::IsEmptyTuple(config_.input_shape)) {
     return Status::OK();
   }
-  CHECK(ShapeUtil::Compatible(config_.input_shape, outfeed_buffers->shape()))
-      << "XLA program outfeed request of shape "
-      << config_.input_shape.ToString()
-      << " did not match the runtime's outfeed buffer of shape "
-      << outfeed_buffers->shape().ToString();
-
-  TF_RETURN_IF_ERROR(outfeed_buffers->ForEachMutableElementWithStatus(
-      [&](const ShapeIndex& index, std::unique_ptr<OutfeedBuffer>* buffer) {
-        if (!*buffer) {  // Tuple pointers.
-          return Status::OK();
-        }
+
+  auto op_profiler =
+      params.profiler->MakeScopedInstructionProfiler(profile_index());
+  OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
+  ShapeTree<std::unique_ptr<OutfeedBuffer>>* output_buffers =
+      outfeed_manager->BlockingGetNextDestination();
 
-        BufferAllocation::Slice slice = outfeed_slices_.element(index);
-        if (!slice.allocation())
-          return InternalError("outfeed input missing buffer allocation");
-        se::DeviceMemoryBase data_address =
-            buffer_allocations.GetDeviceAddress(slice);
+  size_t index = 0;
+  for (auto& output : output_buffers->leaves()) {
+    // Assert that the shapes are compatible.
+    const ShapeIndex& shape_index = output.first;
+    std::unique_ptr<OutfeedBuffer>& buffer = output.second;
+    const Shape& output_shape =
+        ShapeUtil::GetSubshape(output_buffers->shape(), shape_index);
+    TF_RET_CHECK(ShapeUtil::Equal(source_slices_[index].shape, output_shape))
+        << "Mismatch between outfeed output buffer shape "
+        << ShapeUtil::HumanStringWithLayout(output_shape)
+        << " and outfeed source buffer shape "
+        << ShapeUtil::HumanStringWithLayout(source_slices_[index].shape);
 
-        // TODO(b/111309141): Run this on a separate stream so it doesn't block
-        // the GPU from doing work during the transfer. This could be handled by
-        // making StreamAssignment do something intelligent with outfeed thunks.
-        stream
-            .ThenMemcpy((*buffer)->destination()->untyped_data(), data_address,
-                        (*buffer)->length())
-            .ThenDoHostCallback([buffer]() { (*buffer)->Done(); });
-        return Status::OK();
-      }));
+    BufferAllocation::Slice source_slice = source_slices_[index++].slice;
+    if (!source_slice.allocation())
+      return InternalError("outfeed source missing buffer allocation");
+    se::DeviceMemoryBase data_address =
+        buffer_allocations.GetDeviceAddress(source_slice);
+
+    // TODO(b/111309141): Run this on a separate stream so it doesn't block
+    // the GPU from doing work during the transfer. This could be handled by
+    // making StreamAssignment do something intelligent with outfeed thunks.
+    stream
+        .ThenMemcpy(buffer->destination()->untyped_data(), data_address,
+                    buffer->length())
+        .ThenDoHostCallback([&buffer]() { buffer->Done(); });
+  }
 
   Status block_status = stream.BlockHostUntilDone();
   if (!block_status.ok()) {

File: tensorflow/compiler/xla/service/gpu/outfeed_thunk.h

@@ -38,8 +38,8 @@ class OutfeedThunk : public Thunk {
  public:
   // Constructs a OutfeedThunk that copies data to the host-side
   // outfeed queue from the buffers in the given shape tree.
-  OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config,
-               ShapeTree<BufferAllocation::Slice> outfeed_slices);
+  OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig config,
+               std::vector<ShapedSlice> source_slices);
 
   OutfeedThunk(const OutfeedThunk&) = delete;
   OutfeedThunk& operator=(const OutfeedThunk&) = delete;
@@ -48,7 +48,7 @@ class OutfeedThunk : public Thunk {
  private:
   const OutfeedConfig config_;
-  const ShapeTree<BufferAllocation::Slice> outfeed_slices_;
+  const std::vector<ShapedSlice> source_slices_;
 };
 
 }  // namespace gpu

File: tensorflow/compiler/xla/service/gpu/thunk.h

@@ -148,6 +148,13 @@ using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;
 absl::string_view ThunkKindToString(Thunk::Kind);
 std::ostream& operator<<(std::ostream& os, Thunk::Kind kind);
 
+// A struct that defines a shaped slice, i.e., a BufferAllocation::Slice and its
+// shape.
+struct ShapedSlice {
+  BufferAllocation::Slice slice;
+  Shape shape;
+};
+
 }  // namespace gpu
 }  // namespace xla

File: tensorflow/compiler/xla/service/gpu/thunk_emitter.cc

@@ -43,7 +43,6 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-
 std::unique_ptr<Thunk> ThunkEmitter::BuildFftThunk(const HloInstruction* inst) {
   const HloInstruction* operand = inst->operand(0);
   return absl::make_unique<FftThunk>(
@@ -119,24 +118,31 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
     const HloInstruction* inst) {
   CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());
 
-  ShapeTree<BufferAllocation::Slice> slices(inst->operand(0)->shape());
-  slices.ForEachMutableElement([&](const ShapeIndex& index,
-                                   BufferAllocation::Slice* slice) {
-    auto status_or_slice = MaybeGetAllocationSlice(*inst->operand(0), index);
-    if (status_or_slice.ok()) {
-      *slice = status_or_slice.ValueOrDie();
-    }
-  });
+  const HloInstruction* source = inst->operand(0);
+  std::vector<ShapeUtil::IndexedShape> leaf_shapes =
+      ShapeUtil::GetLeafShapes(source->shape());
+
+  std::vector<ShapedSlice> source_slices;
+  source_slices.reserve(leaf_shapes.size());
+  for (ShapeUtil::IndexedShape& indexed_shape : leaf_shapes) {
+    BufferAllocation::Slice slice =
+        GetAllocationSlice(*source, indexed_shape.index);
+    const Shape& shape =
+        ShapeUtil::GetSubshape(source->shape(), indexed_shape.index);
+    source_slices.push_back(ShapedSlice{slice, shape});
+  }
+
   OutfeedConfig config = GetOutfeedConfig(inst);
   return absl::make_unique<OutfeedThunk>(context_->GetThunkInfo(inst),
-                                         std::move(config), std::move(slices));
+                                         std::move(config),
+                                         std::move(source_slices));
 }
 
 Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
   // A CustomCall on the GPU backend can either be a custom-call to a
   // user-supplied kernel, or a call into a library like cudnn.
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
   if (void* call_target = CustomCallTargetRegistry::Global()->Lookup(