[XLA:GPU] Simplify outfeed thunk to accept leaf slices.

- Since MLIR will not have tuples, change the outfeed thunk to accept just the
  leaf slices of the outfeed operand instead of a ShapeTree of slices. A rough
  sketch of the new interface follows the change summary below.

PiperOrigin-RevId: 350778961
Change-Id: I8fd9520a314848d985a2f423d9effc146f3e75a0
Rahul Joshi authored 2021-01-08 09:32:34 -08:00, committed by TensorFlower Gardener
parent adc7dadc74
commit e0886adfed
6 changed files with 65 additions and 58 deletions
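
For context, a minimal sketch of the interface this change moves to. ShapedSlice
(added to thunk.h in this commit) pairs a BufferAllocation::Slice with its Shape,
and the outfeed thunk now receives one ShapedSlice per leaf of the possibly
tuple-shaped outfeed operand instead of a ShapeTree of slices. The helper below
(BuildLeafSlices) is illustrative only and is not part of this commit; it assumes
the ShapeUtil::GetLeafShapes and BufferAllocation APIs that appear in the diffs.

// Illustrative sketch, not part of this commit: how a caller could build the
// flat leaf-slice list that OutfeedThunk now consumes.
#include <vector>

#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {
namespace gpu {

// Mirrors the ShapedSlice struct this commit adds to thunk.h.
struct ShapedSlice {
  BufferAllocation::Slice slice;
  Shape shape;
};

// Hypothetical helper: enumerate the leaves of `shape` and pair each leaf's
// shape with the buffer slice produced by `get_slice` for that leaf's index.
template <typename GetSliceFn>
std::vector<ShapedSlice> BuildLeafSlices(const Shape& shape,
                                         GetSliceFn get_slice) {
  std::vector<ShapeUtil::IndexedShape> leaves = ShapeUtil::GetLeafShapes(shape);
  std::vector<ShapedSlice> result;
  result.reserve(leaves.size());
  for (const ShapeUtil::IndexedShape& leaf : leaves) {
    // leaf.index identifies the leaf within the tuple; leaf.shape is its shape.
    result.push_back(ShapedSlice{get_slice(leaf.index), leaf.shape});
  }
  return result;
}

}  // namespace gpu
}  // namespace xla

Flattening to leaf slices keeps the thunk independent of tuple structure, which
is in line with the commit description's note that MLIR will not have tuples.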

tensorflow/compiler/xla/service/gpu/infeed_thunk.h

@@ -32,13 +32,6 @@ namespace gpu {
 // to the buffer allocated for the infeed op.
 class InfeedThunk : public Thunk {
  public:
-  // A struct that defines a shaped slice, i.e., a BufferAllocation::Slice and
-  // its shape.
-  struct ShapedSlice {
-    BufferAllocation::Slice slice;
-    Shape shape;
-  };
-
   // Constructs a InfeedThunk that copies data from the on-device
   // infeed queue into the buffers in the given shape tree.
   InfeedThunk(ThunkInfo thunk_info, std::vector<ShapedSlice>&& dest_slices);

tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc

@@ -3175,13 +3175,13 @@ Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {
   auto infeed_op = mlir::dyn_cast<mlir::lmhlo::InfeedOp>(input.op);

-  std::vector<InfeedThunk::ShapedSlice> dest_slices;
+  std::vector<ShapedSlice> dest_slices;
   dest_slices.reserve(infeed_op.outputs().size());

   for (mlir::Value output : infeed_op.outputs()) {
     TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(output));
     const Shape& shape = TypeToShape(output.getType());
-    dest_slices.push_back(InfeedThunk::ShapedSlice{slice, shape});
+    dest_slices.push_back(ShapedSlice{slice, shape});
   }

   AddThunkToThunkSequence(
@@ -4325,8 +4325,8 @@ void IrEmitterUnnested::EmitPrologueForReduction(
     }
     const HloInstruction* init_value = reduce_hlo->operand(1);
-    init_ir_value = (*fused_emitter->GetGenerator(init_value))(
-                        IrArray::Index(b_.getInt32Ty()))
+    init_ir_value = (*fused_emitter->GetGenerator(
+        init_value))(IrArray::Index(b_.getInt32Ty()))
                         .ValueOrDie();
   } else {
     init_ir_value = operand_ir_arrays[1].EmitReadArrayElement(

tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc

@@ -31,11 +31,11 @@ OutfeedConfig GetOutfeedConfig(const HloInstruction* instr) {
   return config;
 }

-OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config,
-                           ShapeTree<BufferAllocation::Slice> outfeed_slices)
+OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig config,
+                           std::vector<ShapedSlice> source_slices)
     : Thunk(Kind::kOutfeed, thunk_info),
       config_(std::move(config)),
-      outfeed_slices_(std::move(outfeed_slices)) {}
+      source_slices_(std::move(source_slices)) {}

 Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
   auto& stream = *params.stream;
@@ -43,43 +43,44 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
   VLOG(2) << "Outfeeding from GPU";

-  auto op_profiler =
-      params.profiler->MakeScopedInstructionProfiler(profile_index());
-  OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
-  ShapeTree<std::unique_ptr<OutfeedBuffer>>* outfeed_buffers =
-      outfeed_manager->BlockingGetNextDestination();
-
   // Nothing to be done for empty tuples.
   if (ShapeUtil::IsEmptyTuple(config_.input_shape)) {
     return Status::OK();
   }

-  CHECK(ShapeUtil::Compatible(config_.input_shape, outfeed_buffers->shape()))
-      << "XLA program outfeed request of shape "
-      << config_.input_shape.ToString()
-      << " did not match the runtime's outfeed buffer of shape "
-      << outfeed_buffers->shape().ToString();
-
-  TF_RETURN_IF_ERROR(outfeed_buffers->ForEachMutableElementWithStatus(
-      [&](const ShapeIndex& index, std::unique_ptr<OutfeedBuffer>* buffer) {
-        if (!*buffer) {  // Tuple pointers.
-          return Status::OK();
-        }
-        BufferAllocation::Slice slice = outfeed_slices_.element(index);
-        if (!slice.allocation())
-          return InternalError("outfeed input missing buffer allocation");
+  auto op_profiler =
+      params.profiler->MakeScopedInstructionProfiler(profile_index());
+  OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
+  ShapeTree<std::unique_ptr<OutfeedBuffer>>* output_buffers =
+      outfeed_manager->BlockingGetNextDestination();
+
+  size_t index = 0;
+  for (auto& output : output_buffers->leaves()) {
+    // Assert that the shapes are compatible.
+    const ShapeIndex& shape_index = output.first;
+    std::unique_ptr<OutfeedBuffer>& buffer = output.second;
+    const Shape& output_shape =
+        ShapeUtil::GetSubshape(output_buffers->shape(), shape_index);
+    TF_RET_CHECK(ShapeUtil::Equal(source_slices_[index].shape, output_shape))
+        << "Mismatch between outfeed output buffer shape "
+        << ShapeUtil::HumanStringWithLayout(output_shape)
+        << " and outfeed source buffer shape "
+        << ShapeUtil::HumanStringWithLayout(source_slices_[index].shape);
+
+    BufferAllocation::Slice source_slice = source_slices_[index++].slice;
+    if (!source_slice.allocation())
+      return InternalError("outfeed source missing buffer allocation");
     se::DeviceMemoryBase data_address =
-        buffer_allocations.GetDeviceAddress(slice);
+        buffer_allocations.GetDeviceAddress(source_slice);

     // TODO(b/111309141): Run this on a separate stream so it doesn't block
     // the GPU from doing work during the transfer. This could be handled by
     // making StreamAssignment do something intelligent with outfeed thunks.
     stream
-        .ThenMemcpy((*buffer)->destination()->untyped_data(), data_address,
-                    (*buffer)->length())
-        .ThenDoHostCallback([buffer]() { (*buffer)->Done(); });
-        return Status::OK();
-      }));
+        .ThenMemcpy(buffer->destination()->untyped_data(), data_address,
+                    buffer->length())
+        .ThenDoHostCallback([&buffer]() { buffer->Done(); });
+  }

   Status block_status = stream.BlockHostUntilDone();
   if (!block_status.ok()) {

tensorflow/compiler/xla/service/gpu/outfeed_thunk.h

@@ -38,8 +38,8 @@ class OutfeedThunk : public Thunk {
  public:
   // Constructs a OutfeedThunk that copies data to the host-side
   // outfeed queue from the buffers in the given shape tree.
-  OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config,
-               ShapeTree<BufferAllocation::Slice> outfeed_slices);
+  OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig config,
+               std::vector<ShapedSlice> source_slices);

   OutfeedThunk(const OutfeedThunk&) = delete;
   OutfeedThunk& operator=(const OutfeedThunk&) = delete;
@@ -48,7 +48,7 @@ class OutfeedThunk : public Thunk {
  private:
   const OutfeedConfig config_;
-  const ShapeTree<BufferAllocation::Slice> outfeed_slices_;
+  const std::vector<ShapedSlice> source_slices_;
 };

 }  // namespace gpu

tensorflow/compiler/xla/service/gpu/thunk.h

@@ -148,6 +148,13 @@ using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;
 absl::string_view ThunkKindToString(Thunk::Kind);
 std::ostream& operator<<(std::ostream& os, Thunk::Kind kind);

+// A struct that defines a shaped slice, i.e., a BufferAllocation::Slice and its
+// shape.
+struct ShapedSlice {
+  BufferAllocation::Slice slice;
+  Shape shape;
+};
+
 }  // namespace gpu
 }  // namespace xla

tensorflow/compiler/xla/service/gpu/thunk_emitter.cc

@@ -43,7 +43,6 @@ limitations under the License.
 namespace xla {
 namespace gpu {
-
 std::unique_ptr<Thunk> ThunkEmitter::BuildFftThunk(const HloInstruction* inst) {
   const HloInstruction* operand = inst->operand(0);
   return absl::make_unique<FftThunk>(
@@ -119,24 +118,31 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
     const HloInstruction* inst) {
   CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());

-  ShapeTree<BufferAllocation::Slice> slices(inst->operand(0)->shape());
-  slices.ForEachMutableElement([&](const ShapeIndex& index,
-                                   BufferAllocation::Slice* slice) {
-    auto status_or_slice = MaybeGetAllocationSlice(*inst->operand(0), index);
-    if (status_or_slice.ok()) {
-      *slice = status_or_slice.ValueOrDie();
-    }
-  });
+  const HloInstruction* source = inst->operand(0);
+  std::vector<ShapeUtil::IndexedShape> leaf_shapes =
+      ShapeUtil::GetLeafShapes(source->shape());
+
+  std::vector<ShapedSlice> source_slices;
+  source_slices.reserve(leaf_shapes.size());
+  for (ShapeUtil::IndexedShape& indexed_shape : leaf_shapes) {
+    BufferAllocation::Slice slice =
+        GetAllocationSlice(*source, indexed_shape.index);
+    const Shape& shape =
+        ShapeUtil::GetSubshape(source->shape(), indexed_shape.index);
+    source_slices.push_back(ShapedSlice{slice, shape});
+  }

   OutfeedConfig config = GetOutfeedConfig(inst);
   return absl::make_unique<OutfeedThunk>(context_->GetThunkInfo(inst),
-                                         std::move(config), std::move(slices));
+                                         std::move(config),
+                                         std::move(source_slices));
 }

 Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
   // A CustomCall on the GPU backend can either be a custom-call to a
   // user-supplied kernel, or a call into a library like cudnn.
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
   if (void* call_target = CustomCallTargetRegistry::Global()->Lookup(