[XLA:GPU] Simplify Outfeed thunk to eliminate code for handling dynamic tuples.
- XLA:GPU does not support dynamic tuples (which result from tuple select), so remove code from outfeed thunk that was required to support them. - Without tuple select, we should be able to find the allocation slice for each Outfeed input. PiperOrigin-RevId: 349443816 Change-Id: I410059ecb844c6f9a8145d1751453ffe1967d072
This commit is contained in:
parent
49fc208865
commit
aa92d1d7a9
@ -66,36 +66,10 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
BufferAllocation::Slice slice = outfeed_slices_.element(index);
|
BufferAllocation::Slice slice = outfeed_slices_.element(index);
|
||||||
se::DeviceMemoryBase data_address;
|
if (!slice.allocation())
|
||||||
if (slice.allocation()) {
|
return InternalError("outfeed input missing buffer allocation");
|
||||||
// If we have a static allocation, read it from there. This avoids
|
se::DeviceMemoryBase data_address =
|
||||||
// synchronizing the host and device just to read a pointer.
|
buffer_allocations.GetDeviceAddress(slice);
|
||||||
data_address = buffer_allocations.GetDeviceAddress(slice);
|
|
||||||
} else {
|
|
||||||
// Otherwise we have to read the tuple pointer first.
|
|
||||||
CHECK(!index.empty());
|
|
||||||
// Copy the parent buffer to the host.
|
|
||||||
BufferAllocation::Slice tuple_slice =
|
|
||||||
outfeed_slices_.element(ShapeIndexView(index).ConsumeFront());
|
|
||||||
if (!tuple_slice.allocation()) {
|
|
||||||
return Unimplemented(
|
|
||||||
"Nested dynamic tuples are not supported on GPU");
|
|
||||||
}
|
|
||||||
se::DeviceMemoryBase tuple_address =
|
|
||||||
buffer_allocations.GetDeviceAddress(tuple_slice);
|
|
||||||
CHECK(tuple_slice.size() % sizeof(void*) == 0)
|
|
||||||
<< "Tuple size must be a multiple of pointer size";
|
|
||||||
std::vector<void*> tuple_element_buffer_addresses(tuple_slice.size() /
|
|
||||||
sizeof(void*));
|
|
||||||
stream.ThenMemcpy(tuple_element_buffer_addresses.data(),
|
|
||||||
tuple_address, tuple_slice.size());
|
|
||||||
TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
|
|
||||||
// The data address is specified by the element of the tuple pointer
|
|
||||||
// buffer.
|
|
||||||
data_address =
|
|
||||||
se::DeviceMemoryBase(tuple_element_buffer_addresses[index.back()],
|
|
||||||
(*buffer)->length());
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(b/111309141): Run this on a separate stream so it doesn't block
|
// TODO(b/111309141): Run this on a separate stream so it doesn't block
|
||||||
// the GPU from doing work during the transfer. This could be handled by
|
// the GPU from doing work during the transfer. This could be handled by
|
||||||
|
@ -32,7 +32,7 @@ struct OutfeedConfig {
|
|||||||
OutfeedConfig GetOutfeedConfig(const HloInstruction* instr);
|
OutfeedConfig GetOutfeedConfig(const HloInstruction* instr);
|
||||||
|
|
||||||
// A thunk that outfeeds data. Data must be already resident on the host. This
|
// A thunk that outfeeds data. Data must be already resident on the host. This
|
||||||
// thunk performs a host to device copy from the buffer allocated for the
|
// thunk performs a device to host copy from the buffer allocated for the
|
||||||
// outfeed op to the host location.
|
// outfeed op to the host location.
|
||||||
class OutfeedThunk : public Thunk {
|
class OutfeedThunk : public Thunk {
|
||||||
public:
|
public:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user