From aa92d1d7a9bd24e157326bd3dc18945ab5ca5b78 Mon Sep 17 00:00:00 2001
From: Rahul Joshi
Date: Tue, 29 Dec 2020 10:56:39 -0800
Subject: [PATCH] [XLA:GPU] Simplify Outfeed thunk to eliminate code for
 handling dynamic tuples.

- XLA:GPU does not support dynamic tuples (which result from tuple select),
  so remove code from outfeed thunk that was required to support them.
- Without tuple select, we should be able to find the allocation slice for
  each Outfeed input.

PiperOrigin-RevId: 349443816
Change-Id: I410059ecb844c6f9a8145d1751453ffe1967d072
---
 .../compiler/xla/service/gpu/outfeed_thunk.cc | 34 +++----------------
 .../compiler/xla/service/gpu/outfeed_thunk.h  |  2 +-
 2 files changed, 5 insertions(+), 31 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
index 6eef1b9f0b9..40a00748273 100644
--- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
@@ -66,36 +66,10 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
     }
 
     BufferAllocation::Slice slice = outfeed_slices_.element(index);
-    se::DeviceMemoryBase data_address;
-    if (slice.allocation()) {
-      // If we have a static allocation, read it from there. This avoids
-      // synchronizing the host and device just to read a pointer.
-      data_address = buffer_allocations.GetDeviceAddress(slice);
-    } else {
-      // Otherwise we have to read the tuple pointer first.
-      CHECK(!index.empty());
-      // Copy the parent buffer to the host.
-      BufferAllocation::Slice tuple_slice =
-          outfeed_slices_.element(ShapeIndexView(index).ConsumeFront());
-      if (!tuple_slice.allocation()) {
-        return Unimplemented(
-            "Nested dynamic tuples are not supported on GPU");
-      }
-      se::DeviceMemoryBase tuple_address =
-          buffer_allocations.GetDeviceAddress(tuple_slice);
-      CHECK(tuple_slice.size() % sizeof(void*) == 0)
-          << "Tuple size must be a multiple of pointer size";
-      std::vector<void*> tuple_element_buffer_addresses(tuple_slice.size() /
-                                                        sizeof(void*));
-      stream.ThenMemcpy(tuple_element_buffer_addresses.data(),
-                        tuple_address, tuple_slice.size());
-      TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
-      // The data address is specified by the element of the tuple pointer
-      // buffer.
-      data_address =
-          se::DeviceMemoryBase(tuple_element_buffer_addresses[index.back()],
-                               (*buffer)->length());
-    }
+    if (!slice.allocation())
+      return InternalError("outfeed input missing buffer allocation");
+    se::DeviceMemoryBase data_address =
+        buffer_allocations.GetDeviceAddress(slice);
 
     // TODO(b/111309141): Run this on a separate stream so it doesn't block
     // the GPU from doing work during the transfer. This could be handled by
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h
index 60c64858ee7..eec336407cd 100644
--- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h
@@ -32,7 +32,7 @@ struct OutfeedConfig {
 OutfeedConfig GetOutfeedConfig(const HloInstruction* instr);
 
 // A thunk that outfeeds data. Data must be already resident on the host. This
-// thunk performs a host to device copy from the buffer allocated for the
+// thunk performs a device to host copy from the buffer allocated for the
 // outfeed op to the host location.
 class OutfeedThunk : public Thunk {
  public:
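
Reviewer note: the sketch below is not part of the patch. It is a minimal, self-contained C++ illustration of the control flow this change leaves behind, using hypothetical stand-ins (Slice, DeviceAddress, ResolveOutfeedInput) rather than the real XLA types such as BufferAllocation::Slice and se::DeviceMemoryBase. The point is that once every outfeed input is required to carry a statically assigned allocation, resolving its device address becomes a pure address computation, with no host/device synchronization to chase tuple pointers.

// Toy illustration only: these names are hypothetical stand-ins, not the XLA API.
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <optional>
#include <stdexcept>

// Stand-in for BufferAllocation::Slice: after buffer assignment, every outfeed
// input is expected to carry a static allocation plus an offset into it.
struct Slice {
  std::optional<std::uintptr_t> allocation_base;  // empty == no static allocation
  std::size_t offset = 0;
};

using DeviceAddress = std::uintptr_t;  // stand-in for se::DeviceMemoryBase

// Mirrors the simplified lookup in the patched OutfeedThunk::ExecuteOnStream:
// either the slice has a static allocation and the address is computed
// directly, or the situation is treated as an internal error (instead of the
// deleted fallback that copied the parent tuple buffer to the host to read
// the element pointer).
DeviceAddress ResolveOutfeedInput(const Slice& slice) {
  if (!slice.allocation_base) {
    throw std::runtime_error("outfeed input missing buffer allocation");
  }
  return *slice.allocation_base + slice.offset;
}

int main() {
  const Slice static_slice{0x1000, 64};
  return ResolveOutfeedInput(static_slice) == 0x1040 ? EXIT_SUCCESS : EXIT_FAILURE;
}

Compiled with -std=c++17, main() exercises the happy path; the missing-allocation branch corresponds to the new InternalError return in the patched thunk.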