[XLA:GPU] Simplify Outfeed thunk to eliminate code for handling dynamic tuples.

- XLA:GPU does not support dynamic tuples (which result from tuple select), so remove
  code from outfeed thunk that was required to support them.
- Without tuple select, we should be able to find the allocation slice for each Outfeed
  input.

PiperOrigin-RevId: 349443816
Change-Id: I410059ecb844c6f9a8145d1751453ffe1967d072
This commit is contained in:
Rahul Joshi 2020-12-29 10:56:39 -08:00 committed by TensorFlower Gardener
parent 49fc208865
commit aa92d1d7a9
2 changed files with 5 additions and 31 deletions

View File

@@ -66,36 +66,10 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
}
BufferAllocation::Slice slice = outfeed_slices_.element(index);
-se::DeviceMemoryBase data_address;
-if (slice.allocation()) {
-// If we have a static allocation, read it from there. This avoids
-// synchronizing the host and device just to read a pointer.
-data_address = buffer_allocations.GetDeviceAddress(slice);
-} else {
-// Otherwise we have to read the tuple pointer first.
-CHECK(!index.empty());
-// Copy the parent buffer to the host.
-BufferAllocation::Slice tuple_slice =
-outfeed_slices_.element(ShapeIndexView(index).ConsumeFront());
-if (!tuple_slice.allocation()) {
-return Unimplemented(
-"Nested dynamic tuples are not supported on GPU");
-}
-se::DeviceMemoryBase tuple_address =
-buffer_allocations.GetDeviceAddress(tuple_slice);
-CHECK(tuple_slice.size() % sizeof(void*) == 0)
-<< "Tuple size must be a multiple of pointer size";
-std::vector<void*> tuple_element_buffer_addresses(tuple_slice.size() /
-sizeof(void*));
-stream.ThenMemcpy(tuple_element_buffer_addresses.data(),
-tuple_address, tuple_slice.size());
-TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
-// The data address is specified by the element of the tuple pointer
-// buffer.
-data_address =
-se::DeviceMemoryBase(tuple_element_buffer_addresses[index.back()],
-(*buffer)->length());
-}
+if (!slice.allocation())
+return InternalError("outfeed input missing buffer allocation");
+se::DeviceMemoryBase data_address =
+buffer_allocations.GetDeviceAddress(slice);
// TODO(b/111309141): Run this on a separate stream so it doesn't block
// the GPU from doing work during the transfer. This could be handled by

View File

@@ -32,7 +32,7 @@ struct OutfeedConfig {
OutfeedConfig GetOutfeedConfig(const HloInstruction* instr);
// A thunk that outfeeds data. Data must be already resident on the device. This
-// thunk performs a host to device copy from the buffer allocated for the
+// thunk performs a device to host copy from the buffer allocated for the
// outfeed op to the host location.
class OutfeedThunk : public Thunk {
public: