[XLA:GPU] Simplify Outfeed thunk to eliminate code for handling dynamic tuples.
- XLA:GPU does not support dynamic tuples (which result from tuple select), so remove code from outfeed thunk that was required to support them.
- Without tuple select, we should be able to find the allocation slice for each Outfeed input.

PiperOrigin-RevId: 349443816
Change-Id: I410059ecb844c6f9a8145d1751453ffe1967d072
This commit is contained in:
parent
49fc208865
commit
aa92d1d7a9
@ -66,36 +66,10 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
}
|
||||
|
||||
BufferAllocation::Slice slice = outfeed_slices_.element(index);
|
||||
se::DeviceMemoryBase data_address;
|
||||
if (slice.allocation()) {
|
||||
// If we have a static allocation, read it from there. This avoids
|
||||
// synchronizing the host and device just to read a pointer.
|
||||
data_address = buffer_allocations.GetDeviceAddress(slice);
|
||||
} else {
|
||||
// Otherwise we have to read the tuple pointer first.
|
||||
CHECK(!index.empty());
|
||||
// Copy the parent buffer to the host.
|
||||
BufferAllocation::Slice tuple_slice =
|
||||
outfeed_slices_.element(ShapeIndexView(index).ConsumeFront());
|
||||
if (!tuple_slice.allocation()) {
|
||||
return Unimplemented(
|
||||
"Nested dynamic tuples are not supported on GPU");
|
||||
}
|
||||
se::DeviceMemoryBase tuple_address =
|
||||
buffer_allocations.GetDeviceAddress(tuple_slice);
|
||||
CHECK(tuple_slice.size() % sizeof(void*) == 0)
|
||||
<< "Tuple size must be a multiple of pointer size";
|
||||
std::vector<void*> tuple_element_buffer_addresses(tuple_slice.size() /
|
||||
sizeof(void*));
|
||||
stream.ThenMemcpy(tuple_element_buffer_addresses.data(),
|
||||
tuple_address, tuple_slice.size());
|
||||
TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
|
||||
// The data address is specified by the element of the tuple pointer
|
||||
// buffer.
|
||||
data_address =
|
||||
se::DeviceMemoryBase(tuple_element_buffer_addresses[index.back()],
|
||||
(*buffer)->length());
|
||||
}
|
||||
if (!slice.allocation())
|
||||
return InternalError("outfeed input missing buffer allocation");
|
||||
se::DeviceMemoryBase data_address =
|
||||
buffer_allocations.GetDeviceAddress(slice);
|
||||
|
||||
// TODO(b/111309141): Run this on a separate stream so it doesn't block
|
||||
// the GPU from doing work during the transfer. This could be handled by
|
||||
|
@ -32,7 +32,7 @@ struct OutfeedConfig {
|
||||
OutfeedConfig GetOutfeedConfig(const HloInstruction* instr);
|
||||
|
||||
// A thunk that outfeeds data. Data must be already resident on the host. This
|
||||
// thunk performs a host to device copy from the buffer allocated for the
|
||||
// thunk performs a device to host copy from the buffer allocated for the
|
||||
// outfeed op to the host location.
|
||||
class OutfeedThunk : public Thunk {
|
||||
public:
|
||||
|
Loading…
x
Reference in New Issue
Block a user