[XLA:GPU] Simplify outfeed thunk to accept leaf slices.

Since MLIR will not have tuples, change the outfeed thunk to accept just the
leaf slices.

PiperOrigin-RevId: 350778961
Change-Id: I8fd9520a314848d985a2f423d9effc146f3e75a0

commit e0886adfed
parent adc7dadc74
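In short, the outfeed thunk no longer carries a ShapeTree<BufferAllocation::Slice>; it takes a flat vector of leaf slices, each paired with its shape, and the ShapedSlice helper moves out of the InfeedThunk class into the common gpu namespace next to the other Thunk declarations. The following is a condensed view of the interface change, excerpted from the hunks below rather than new code:

// Shared helper (moved next to the Thunk declarations):
struct ShapedSlice {
  BufferAllocation::Slice slice;
  Shape shape;
};

// Old constructor:
//   OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config,
//                ShapeTree<BufferAllocation::Slice> outfeed_slices);
// New constructor:
//   OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig config,
//                std::vector<ShapedSlice> source_slices);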
@@ -32,13 +32,6 @@ namespace gpu {
 // to the buffer allocated for the infeed op.
 class InfeedThunk : public Thunk {
  public:
-  // A struct that defines a shaped slice, i.e., a BufferAllocation::Slice and
-  // its shape.
-  struct ShapedSlice {
-    BufferAllocation::Slice slice;
-    Shape shape;
-  };
-
   // Constructs a InfeedThunk that copies data from the on-device
   // infeed queue into the buffers in the given shape tree.
   InfeedThunk(ThunkInfo thunk_info, std::vector<ShapedSlice>&& dest_slices);
@@ -3175,13 +3175,13 @@ Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {

   auto infeed_op = mlir::dyn_cast<mlir::lmhlo::InfeedOp>(input.op);

-  std::vector<InfeedThunk::ShapedSlice> dest_slices;
+  std::vector<ShapedSlice> dest_slices;
   dest_slices.reserve(infeed_op.outputs().size());

   for (mlir::Value output : infeed_op.outputs()) {
     TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(output));
     const Shape& shape = TypeToShape(output.getType());
-    dest_slices.push_back(InfeedThunk::ShapedSlice{slice, shape});
+    dest_slices.push_back(ShapedSlice{slice, shape});
   }

   AddThunkToThunkSequence(
@@ -4325,8 +4325,8 @@ void IrEmitterUnnested::EmitPrologueForReduction(
     }
     const HloInstruction* init_value = reduce_hlo->operand(1);

-    init_ir_value = (*fused_emitter->GetGenerator(init_value))(
-                        IrArray::Index(b_.getInt32Ty()))
+    init_ir_value = (*fused_emitter->GetGenerator(
+                        init_value))(IrArray::Index(b_.getInt32Ty()))
                         .ValueOrDie();
   } else {
     init_ir_value = operand_ir_arrays[1].EmitReadArrayElement(
@@ -31,11 +31,11 @@ OutfeedConfig GetOutfeedConfig(const HloInstruction* instr) {
   return config;
 }

-OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config,
-                           ShapeTree<BufferAllocation::Slice> outfeed_slices)
+OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig config,
+                           std::vector<ShapedSlice> source_slices)
     : Thunk(Kind::kOutfeed, thunk_info),
       config_(std::move(config)),
-      outfeed_slices_(std::move(outfeed_slices)) {}
+      source_slices_(std::move(source_slices)) {}

 Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
   auto& stream = *params.stream;
@@ -43,43 +43,44 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {

   VLOG(2) << "Outfeeding from GPU";

-  auto op_profiler =
-      params.profiler->MakeScopedInstructionProfiler(profile_index());
-  OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
-  ShapeTree<std::unique_ptr<OutfeedBuffer>>* outfeed_buffers =
-      outfeed_manager->BlockingGetNextDestination();
-
   // Nothing to be done for empty tuples.
   if (ShapeUtil::IsEmptyTuple(config_.input_shape)) {
     return Status::OK();
   }
-  CHECK(ShapeUtil::Compatible(config_.input_shape, outfeed_buffers->shape()))
-      << "XLA program outfeed request of shape "
-      << config_.input_shape.ToString()
-      << " did not match the runtime's outfeed buffer of shape "
-      << outfeed_buffers->shape().ToString();

-  TF_RETURN_IF_ERROR(outfeed_buffers->ForEachMutableElementWithStatus(
-      [&](const ShapeIndex& index, std::unique_ptr<OutfeedBuffer>* buffer) {
-        if (!*buffer) {  // Tuple pointers.
-          return Status::OK();
-        }
-
-        BufferAllocation::Slice slice = outfeed_slices_.element(index);
-        if (!slice.allocation())
-          return InternalError("outfeed input missing buffer allocation");
-        se::DeviceMemoryBase data_address =
-            buffer_allocations.GetDeviceAddress(slice);
-
-        // TODO(b/111309141): Run this on a separate stream so it doesn't block
-        // the GPU from doing work during the transfer. This could be handled by
-        // making StreamAssignment do something intelligent with outfeed thunks.
-        stream
-            .ThenMemcpy((*buffer)->destination()->untyped_data(), data_address,
-                        (*buffer)->length())
-            .ThenDoHostCallback([buffer]() { (*buffer)->Done(); });
-        return Status::OK();
-      }));
+  auto op_profiler =
+      params.profiler->MakeScopedInstructionProfiler(profile_index());
+  OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
+  ShapeTree<std::unique_ptr<OutfeedBuffer>>* output_buffers =
+      outfeed_manager->BlockingGetNextDestination();
+
+  size_t index = 0;
+  for (auto& output : output_buffers->leaves()) {
+    // Assert that the shapes are compatible.
+    const ShapeIndex& shape_index = output.first;
+    std::unique_ptr<OutfeedBuffer>& buffer = output.second;
+    const Shape& output_shape =
+        ShapeUtil::GetSubshape(output_buffers->shape(), shape_index);
+    TF_RET_CHECK(ShapeUtil::Equal(source_slices_[index].shape, output_shape))
+        << "Mismatch between outfeed output buffer shape "
+        << ShapeUtil::HumanStringWithLayout(output_shape)
+        << " and outfeed source buffer shape "
+        << ShapeUtil::HumanStringWithLayout(source_slices_[index].shape);
+
+    BufferAllocation::Slice source_slice = source_slices_[index++].slice;
+    if (!source_slice.allocation())
+      return InternalError("outfeed source missing buffer allocation");
+    se::DeviceMemoryBase data_address =
+        buffer_allocations.GetDeviceAddress(source_slice);
+
+    // TODO(b/111309141): Run this on a separate stream so it doesn't block
+    // the GPU from doing work during the transfer. This could be handled by
+    // making StreamAssignment do something intelligent with outfeed thunks.
+    stream
+        .ThenMemcpy(buffer->destination()->untyped_data(), data_address,
+                    buffer->length())
+        .ThenDoHostCallback([&buffer]() { buffer->Done(); });
+  }

   Status block_status = stream.BlockHostUntilDone();
   if (!block_status.ok()) {
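A note on why a flat vector suffices here: ExecuteOnStream now walks the runtime's outfeed ShapeTree leaf by leaf and pairs each leaf with source_slices_ purely by position, with the TF_RET_CHECK above guarding that the shapes line up. This assumes the emitter builds source_slices_ in the same leaf order in which the shape tree enumerates its leaves (which BuildOutfeedThunk below does via ShapeUtil::GetLeafShapes). A minimal standalone sketch of that positional pairing, using toy stand-ins rather than the XLA types:

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy analogue of ShapedSlice: a buffer id plus the shape it should hold.
struct ToyShapedSlice {
  int buffer_id;
  std::string shape;
};

int main() {
  // Leaves of the outfeed destination, in enumeration order.
  std::vector<std::string> output_leaf_shapes = {"f32[2]", "s32[3]", "f32[]"};
  // Source slices built by the emitter in the same leaf order.
  std::vector<ToyShapedSlice> source_slices = {
      {0, "f32[2]"}, {1, "s32[3]"}, {2, "f32[]"}};

  // Positional pairing, mirroring the index counter in ExecuteOnStream.
  size_t index = 0;
  for (const std::string& output_shape : output_leaf_shapes) {
    const ToyShapedSlice& src = source_slices[index++];
    assert(src.shape == output_shape);  // analogue of the TF_RET_CHECK
    // ...copy src.buffer_id's device memory into the host outfeed buffer...
  }
  return 0;
}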
@@ -38,8 +38,8 @@ class OutfeedThunk : public Thunk {
  public:
   // Constructs a OutfeedThunk that copies data to the host-side
   // outfeed queue from the buffers in the given shape tree.
-  OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config,
-               ShapeTree<BufferAllocation::Slice> outfeed_slices);
+  OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig config,
+               std::vector<ShapedSlice> source_slices);

   OutfeedThunk(const OutfeedThunk&) = delete;
   OutfeedThunk& operator=(const OutfeedThunk&) = delete;
@@ -48,7 +48,7 @@ class OutfeedThunk : public Thunk {

  private:
   const OutfeedConfig config_;
-  const ShapeTree<BufferAllocation::Slice> outfeed_slices_;
+  const std::vector<ShapedSlice> source_slices_;
 };

 }  // namespace gpu
@@ -148,6 +148,13 @@ using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;
 absl::string_view ThunkKindToString(Thunk::Kind);
 std::ostream& operator<<(std::ostream& os, Thunk::Kind kind);

+// A struct that defines a shaped slice, i.e., a BufferAllocation::Slice and its
+// shape.
+struct ShapedSlice {
+  BufferAllocation::Slice slice;
+  Shape shape;
+};
+
 }  // namespace gpu
 }  // namespace xla

@@ -43,7 +43,6 @@ limitations under the License.
 namespace xla {
 namespace gpu {

-
 std::unique_ptr<Thunk> ThunkEmitter::BuildFftThunk(const HloInstruction* inst) {
   const HloInstruction* operand = inst->operand(0);
   return absl::make_unique<FftThunk>(
@@ -119,24 +118,31 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
     const HloInstruction* inst) {
   CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());

-  ShapeTree<BufferAllocation::Slice> slices(inst->operand(0)->shape());
-  slices.ForEachMutableElement([&](const ShapeIndex& index,
-                                   BufferAllocation::Slice* slice) {
-    auto status_or_slice = MaybeGetAllocationSlice(*inst->operand(0), index);
-    if (status_or_slice.ok()) {
-      *slice = status_or_slice.ValueOrDie();
-    }
-  });
+  const HloInstruction* source = inst->operand(0);
+  std::vector<ShapeUtil::IndexedShape> leaf_shapes =
+      ShapeUtil::GetLeafShapes(source->shape());
+
+  std::vector<ShapedSlice> source_slices;
+  source_slices.reserve(leaf_shapes.size());
+
+  for (ShapeUtil::IndexedShape& indexed_shape : leaf_shapes) {
+    BufferAllocation::Slice slice =
+        GetAllocationSlice(*source, indexed_shape.index);
+    const Shape& shape =
+        ShapeUtil::GetSubshape(source->shape(), indexed_shape.index);
+    source_slices.push_back(ShapedSlice{slice, shape});
+  }
+
   OutfeedConfig config = GetOutfeedConfig(inst);
   return absl::make_unique<OutfeedThunk>(context_->GetThunkInfo(inst),
-                                         std::move(config), std::move(slices));
+                                         std::move(config),
+                                         std::move(source_slices));
 }

 Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
   // A CustomCall on the GPU backend can either be a custom-call to a
   // user-supplied kernel, or a call into a library like cudnn.

 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
   if (void* call_target = CustomCallTargetRegistry::Global()->Lookup(