From cc33e1f05b21cd86ce98375acfd5cab368aef5ba Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Wed, 7 Oct 2020 17:45:57 -0700
Subject: [PATCH] [XLA] Compute ShapedBuffer::on_host_shape from
 ShapedBuffer::on_device_shape.

This change is in preparation for removing ShapedBuffer::on_host_shape
entirely: on all known XLA backends, the host shape can now be computed from
the device shape. We therefore have no need to maintain two representations
of a buffer's shape.

As a transitional measure, start by ignoring the on_host_shape supplied by
the user and returning an on_host_shape formed by stripping tiling and memory
space information from the on_device_shape(). We can then refactor clients of
ShapedBuffer to avoid producing or consuming the on_host_shape in subsequent
CLs.

PiperOrigin-RevId: 335989587
Change-Id: I309c2e69b001c0777769dafdfc6c4ffe8dfef18e
---
 tensorflow/compiler/xla/service/executable.h  | 19 +++++++---
 .../xla/service/generic_transfer_manager.cc   |  3 +-
 .../compiler/xla/service/shaped_buffer.cc     | 35 +++++++++++--------
 .../compiler/xla/service/shaped_buffer.h      | 10 ++++--
 .../compiler/xla/service/transfer_manager.h   |  6 +++-
 tensorflow/compiler/xla/shape_util.cc         | 10 ++++++
 tensorflow/compiler/xla/shape_util.h          |  4 +++
 7 files changed, 64 insertions(+), 23 deletions(-)
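
(Illustration, not part of the patch: a minimal sketch of the new host-shape
derivation. ShapeUtil::MakeShapeWithLayout, the dimensions, and the memory
space value below are assumptions chosen for the example; only
ShapeUtil::DeviceShapeToHostShape is introduced by this change, the Layout
accessors are existing XLA API.)

  #include "tensorflow/compiler/xla/shape_util.h"
  #include "tensorflow/core/platform/logging.h"

  namespace xla {

  void DeviceShapeToHostShapeSketch() {
    // f32[128,64] with an explicit minor-to-major layout.
    Shape device_shape =
        ShapeUtil::MakeShapeWithLayout(F32, {128, 64}, {1, 0});
    // Pretend a backend annotated the buffer with a non-default memory space.
    device_shape.mutable_layout()->set_memory_space(1);

    // The derived host view keeps element type, dimensions, and dimension
    // order, but drops tiling and resets the memory space to the default.
    // This is the same derivation ShapedBuffer now performs internally, so
    // callers no longer need to pass a separate host shape.
    Shape host_shape = ShapeUtil::DeviceShapeToHostShape(device_shape);
    CHECK_EQ(host_shape.layout().memory_space(), Layout::kDefaultMemorySpace);
  }

  }  // namespace xla
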
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index 2e3ddedfb8c..9216e5de85d 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -60,15 +60,24 @@ namespace xla {
 // with their indices absent from unowned_indices_.
 class ExecutionInput {
  public:
-  explicit ExecutionInput(xla::Shape shape, xla::Shape host_shape)
+  explicit ExecutionInput(xla::Shape shape) : buffers_(std::move(shape)) {
+    SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape()));
+  }
+  // TODO(b/170310047): remove this overload.
+  ExecutionInput(xla::Shape shape, xla::Shape host_shape)
       : buffers_(std::move(shape)) {
-    SetHostShape(std::move(host_shape));
+    SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape()));
   }
 
-  explicit ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers,
-                          xla::Shape host_shape)
+  explicit ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers)
       : buffers_(std::move(buffers)) {
-    SetHostShape(std::move(host_shape));
+    SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape()));
+  }
+  // TODO(b/170310047): remove this overload.
+  ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers,
+                 xla::Shape host_shape)
+      : buffers_(std::move(buffers)) {
+    SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape()));
   }
 
   ExecutionInput(ExecutionInput&&) = default;
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index c09757fe1af..b451c3ab1e7 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -106,7 +106,8 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
   // The on-host and on-device shape should always be the same for the generic
   // transfer manager.
   TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
-                                device_buffer.on_host_shape()));
+                                device_buffer.on_host_shape()))
+      << device_buffer.ToString();
 
   TF_RET_CHECK(
       ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape()));
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
index 473a9ca7456..cd7a4022234 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -33,11 +33,16 @@ namespace xla {
 
 ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
                            const se::Platform* platform, int device_ordinal)
-    : on_host_shape_(std::move(on_host_shape)),
-      on_device_shape_(std::move(on_device_shape)),
+    : ShapedBuffer(on_device_shape, platform, device_ordinal) {}
+
+ShapedBuffer::ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
+                           int device_ordinal)
+    : on_device_shape_(std::move(on_device_shape)),
       platform_(platform),
       device_ordinal_(device_ordinal),
-      buffers_(&on_device_shape_) {}
+      buffers_(&on_device_shape_) {
+  on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape_);
+}
 
 ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
     : on_host_shape_(std::move(s.on_host_shape_)),
@@ -52,8 +57,8 @@ ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
 }
 
 ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) {
-  on_host_shape_ = std::move(s.on_host_shape_);
   on_device_shape_ = std::move(s.on_device_shape_);
+  on_host_shape_ = std::move(s.on_host_shape_);
   platform_ = s.platform_;
   device_ordinal_ = s.device_ordinal_;
   buffers_ = std::move(s.buffers_);
@@ -68,12 +73,9 @@ ShapedBuffer::~ShapedBuffer() {}
 
 StatusOr<ShapedBuffer> ShapedBuffer::SubShapedBuffer(
     const ShapeIndex& index) const {
-  TF_ASSIGN_OR_RETURN(const Shape* host_sub_shape,
-                      ShapeUtil::TryGetSubshape(on_host_shape(), index));
   TF_ASSIGN_OR_RETURN(const Shape* device_sub_shape,
                       ShapeUtil::TryGetSubshape(on_device_shape(), index));
-  ShapedBuffer sub_shaped_buffer(*host_sub_shape, *device_sub_shape, platform_,
-                                 device_ordinal_);
+  ShapedBuffer sub_shaped_buffer(*device_sub_shape, platform_, device_ordinal_);
   TF_ASSIGN_OR_RETURN(ShapeTree<se::DeviceMemoryBase> sub_buffers,
                       buffers_.SubShapeTree(index));
   sub_shaped_buffer.set_buffers(std::move(sub_buffers));
@@ -120,8 +122,15 @@ ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
                                        Shape on_device_shape,
                                        se::DeviceMemoryAllocator* allocator,
                                        int device_ordinal)
-    : ShapedBuffer(std::move(on_host_shape), std::move(on_device_shape),
-                   allocator->platform(), device_ordinal),
+    : ShapedBuffer(std::move(on_device_shape), allocator->platform(),
+                   device_ordinal),
+      allocator_(allocator) {}
+
+ScopedShapedBuffer::ScopedShapedBuffer(Shape on_device_shape,
+                                       se::DeviceMemoryAllocator* allocator,
+                                       int device_ordinal)
+    : ShapedBuffer(std::move(on_device_shape), allocator->platform(),
+                   device_ordinal),
       allocator_(allocator) {}
 
 ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,
@@ -171,13 +180,11 @@ void ScopedShapedBuffer::Deallocate() {
 }
 
 ScopedShapedBuffer ScopedShapedBuffer::TakeSubTree(ShapeIndexView index) {
-  const xla::Shape& sub_on_host_shape =
-      xla::ShapeUtil::GetSubshape(on_host_shape(), {index});
   const xla::Shape& sub_on_device_shape =
       xla::ShapeUtil::GetSubshape(on_device_shape(), {index});
 
-  ScopedShapedBuffer output(sub_on_host_shape, sub_on_device_shape,
-                            memory_allocator(), device_ordinal());
+  ScopedShapedBuffer output(sub_on_device_shape, memory_allocator(),
+                            device_ordinal());
   auto src_it = buffers().find(index);
   auto dst_it = output.buffers().begin();
   while (dst_it != output.buffers().end()) {
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h
index 995b0ece7cd..f24efa4eed2 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.h
+++ b/tensorflow/compiler/xla/service/shaped_buffer.h
@@ -43,6 +43,9 @@ class ShapedBuffer {
   // both the on-host and on-device shape are required. The on-device shape
   // determines the number of device allocations (DeviceMemoryBase) held by the
   // ShapedBuffer.
+  ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
+               int device_ordinal);
+  // TODO(b/170310047): remove this overload.
   ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
                const se::Platform* platform, int device_ordinal);
 
@@ -101,7 +104,7 @@ class ShapedBuffer {
     CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_))
         << "Structures are not the same. new: " << on_device_shape
         << ", old: " << on_device_shape_;
-    on_host_shape_ = on_host_shape;
+    on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape);
     on_device_shape_ = on_device_shape;
     buffers_.replace_shape_ptr(&on_device_shape_);
   }
@@ -119,7 +122,6 @@ class ShapedBuffer {
   string ToString() const;
 
  protected:
-  // The shape of the data when represented on the host.
   Shape on_host_shape_;
 
   // The shape of the data on the device.
@@ -148,6 +150,10 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer);
 class ScopedShapedBuffer : public ShapedBuffer {
  public:
   // Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index.
+  explicit ScopedShapedBuffer(Shape on_device_shape,
+                              se::DeviceMemoryAllocator* allocator,
+                              int device_ordinal);
+  // TODO(b/170310047): remove this overload.
   explicit ScopedShapedBuffer(Shape on_host_shape, Shape on_device_shape,
                               se::DeviceMemoryAllocator* allocator,
                               int device_ordinal);
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index c0670d26eee..8b7100d9b2f 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -51,7 +51,11 @@ class TransferManager {
   // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user
   // needing to consider device-specific behaviors.
   virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const {
-    return host_shape;
+    // Strips off any preexisting tiling or memory space information.
+    // TODO(phawkins): fix clients not to include tiling or memory space
+    // information in shapes passed to this function and turn this into an
+    // assertion.
+    return ShapeUtil::DeviceShapeToHostShape(host_shape);
   }
 
   // Base class for specifying platform specific transfer metadata that can be
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index c9f37bdc430..238879ebdc0 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -1624,4 +1624,14 @@ static Shape MergeDimensions(absl::Span<const size_t> segs,
   return absl::nullopt;
 }
 
+Shape ShapeUtil::DeviceShapeToHostShape(Shape s) {
+  ForEachMutableSubshape(&s, [](Shape* subshape, const ShapeIndex& index) {
+    if (subshape->IsArray()) {
+      subshape->mutable_layout()->clear_tiles();
+      subshape->mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace);
+    }
+  });
+  return s;
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 3f69a8b0aca..5a5695d32ee 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -783,6 +783,10 @@ class ShapeUtil {
   static absl::optional<std::vector<int64>> FindTranspose021(const Shape& a,
                                                              const Shape& b);
 
+  // Strips device-specific information, namely tiling and memory-space
+  // information, from a shape.
+  static Shape DeviceShapeToHostShape(Shape s);
+
  private:
   // Validates the shape size is sane. This makes sure it's safe to do
   // calculations in int64 without overflowing.
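
(Illustration, not part of the patch: a sketch of how a caller might migrate
to the single-shape constructors added above. The helper name
MakeEmptyDeviceBuffer and its parameter list are hypothetical; only the new
ScopedShapedBuffer constructor and TransferManager::HostShapeToDeviceShape
used here appear in the diff.)

  #include "tensorflow/compiler/xla/service/shaped_buffer.h"
  #include "tensorflow/compiler/xla/service/transfer_manager.h"

  namespace xla {

  // Hypothetical helper: build an (unallocated) ScopedShapedBuffer for a
  // given host shape. Before this change the caller had to thread both
  // shapes through:
  //   ScopedShapedBuffer(host_shape, device_shape, allocator, device_ordinal)
  // Now the device shape alone suffices; the host view is derived from it
  // internally via ShapeUtil::DeviceShapeToHostShape.
  ScopedShapedBuffer MakeEmptyDeviceBuffer(
      const Shape& host_shape, TransferManager* transfer_manager,
      se::DeviceMemoryAllocator* allocator, int device_ordinal) {
    Shape device_shape = transfer_manager->HostShapeToDeviceShape(host_shape);
    // New single-shape constructor; every DeviceMemoryBase starts out null.
    return ScopedShapedBuffer(device_shape, allocator, device_ordinal);
  }

  }  // namespace xla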