[XLA] Compute ShapedBuffer::on_host_shape from ShapedBuffer::on_device_shape.

This change is in preparation for removing ShapedBuffer::on_host_shape entirely: on all known XLA backends, the host shape can now be computed from the device shape. We therefore have no need to maintain two representations of a buffer's shape.

As a transitional measure, this change ignores the on_host_shape supplied by the caller and instead returns an on_host_shape formed by stripping tiling and memory-space information from on_device_shape(). Clients of ShapedBuffer can then be refactored in subsequent CLs so they no longer produce or consume on_host_shape.
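
For illustration only, a minimal sketch (not part of this CL) of the intended derivation, using the ShapeUtil::DeviceShapeToHostShape helper added below; the concrete shape and the non-default memory space are made-up examples:

    // Sketch: derive a host shape from a device shape by dropping
    // device-specific layout details (tiling, memory space).
    #include "tensorflow/compiler/xla/shape_util.h"

    namespace xla {

    Shape ExampleHostShapeFromDeviceShape() {
      // Hypothetical device shape: f32[16,1024] that some backend has placed
      // in a non-default memory space.
      Shape device_shape = ShapeUtil::MakeShapeWithLayout(
          F32, /*dimensions=*/{16, 1024}, /*minor_to_major=*/{1, 0});
      device_shape.mutable_layout()->set_memory_space(1);

      // The host shape keeps element type, dimensions, and minor-to-major
      // order, but has no tiles and uses the default memory space.
      return ShapeUtil::DeviceShapeToHostShape(device_shape);
    }

    }  // namespace xla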

PiperOrigin-RevId: 335989587
Change-Id: I309c2e69b001c0777769dafdfc6c4ffe8dfef18e
Author: Peter Hawkins, 2020-10-07 17:45:57 -07:00 (committed by TensorFlower Gardener)
Parent: 6c3538bf6b
Commit: cc33e1f05b
7 changed files with 64 additions and 23 deletions


@@ -60,15 +60,24 @@ namespace xla {
 // with their indices absent from unowned_indices_.
 class ExecutionInput {
  public:
-  explicit ExecutionInput(xla::Shape shape, xla::Shape host_shape)
+  explicit ExecutionInput(xla::Shape shape) : buffers_(std::move(shape)) {
+    SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape()));
+  }
+  // TODO(b/170310047): remove this overload.
+  ExecutionInput(xla::Shape shape, xla::Shape host_shape)
       : buffers_(std::move(shape)) {
-    SetHostShape(std::move(host_shape));
+    SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape()));
   }
-  explicit ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers,
-                          xla::Shape host_shape)
+  explicit ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers)
       : buffers_(std::move(buffers)) {
-    SetHostShape(std::move(host_shape));
+    SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape()));
   }
+  // TODO(b/170310047): remove this overload.
+  ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers,
+                 xla::Shape host_shape)
+      : buffers_(std::move(buffers)) {
+    SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape()));
+  }
   ExecutionInput(ExecutionInput&&) = default;
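
As a usage sketch (the helper name MakeInput and its arguments are placeholders, not part of this CL; ExecutionInput::SetBuffer is the existing per-index setter), the new single-argument overload lets callers drop the host shape entirely:

    #include "tensorflow/compiler/xla/service/executable.h"

    // Sketch: build an ExecutionInput from a device shape only; the host
    // shape is derived internally via ShapeUtil::DeviceShapeToHostShape.
    xla::ExecutionInput MakeInput(const xla::Shape& device_shape,
                                  xla::MaybeOwningDeviceMemory root_buffer) {
      xla::ExecutionInput input(device_shape);
      // Attach device memory for the root buffer (shape index {}).
      input.SetBuffer(/*index=*/{}, std::move(root_buffer));
      return input;
    }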


@@ -106,7 +106,8 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
   // The on-host and on-device shape should always be the same for the generic
   // transfer manager.
   TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
-                                device_buffer.on_host_shape()));
+                                device_buffer.on_host_shape()))
+      << device_buffer.ToString();
   TF_RET_CHECK(
       ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape()));


@@ -33,11 +33,16 @@ namespace xla {
 ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
                            const se::Platform* platform, int device_ordinal)
-    : on_host_shape_(std::move(on_host_shape)),
-      on_device_shape_(std::move(on_device_shape)),
+    : ShapedBuffer(on_device_shape, platform, device_ordinal) {}
+
+ShapedBuffer::ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
+                           int device_ordinal)
+    : on_device_shape_(std::move(on_device_shape)),
       platform_(platform),
       device_ordinal_(device_ordinal),
-      buffers_(&on_device_shape_) {}
+      buffers_(&on_device_shape_) {
+  on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape_);
+}
 ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
     : on_host_shape_(std::move(s.on_host_shape_)),
@@ -52,8 +57,8 @@ ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
 }
 ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) {
-  on_host_shape_ = std::move(s.on_host_shape_);
   on_device_shape_ = std::move(s.on_device_shape_);
+  on_host_shape_ = std::move(s.on_host_shape_);
   platform_ = s.platform_;
   device_ordinal_ = s.device_ordinal_;
   buffers_ = std::move(s.buffers_);
@@ -68,12 +73,9 @@ ShapedBuffer::~ShapedBuffer() {}
 StatusOr<ShapedBuffer> ShapedBuffer::SubShapedBuffer(
     const ShapeIndex& index) const {
-  TF_ASSIGN_OR_RETURN(const Shape* host_sub_shape,
-                      ShapeUtil::TryGetSubshape(on_host_shape(), index));
   TF_ASSIGN_OR_RETURN(const Shape* device_sub_shape,
                       ShapeUtil::TryGetSubshape(on_device_shape(), index));
-  ShapedBuffer sub_shaped_buffer(*host_sub_shape, *device_sub_shape, platform_,
-                                 device_ordinal_);
+  ShapedBuffer sub_shaped_buffer(*device_sub_shape, platform_, device_ordinal_);
   TF_ASSIGN_OR_RETURN(ShapeTree<se::DeviceMemoryBase> sub_buffers,
                       buffers_.SubShapeTree(index));
   sub_shaped_buffer.set_buffers(std::move(sub_buffers));
@@ -120,8 +122,15 @@ ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
                                        Shape on_device_shape,
                                        se::DeviceMemoryAllocator* allocator,
                                        int device_ordinal)
-    : ShapedBuffer(std::move(on_host_shape), std::move(on_device_shape),
-                   allocator->platform(), device_ordinal),
+    : ShapedBuffer(std::move(on_device_shape), allocator->platform(),
+                   device_ordinal),
       allocator_(allocator) {}
+
+ScopedShapedBuffer::ScopedShapedBuffer(Shape on_device_shape,
+                                       se::DeviceMemoryAllocator* allocator,
+                                       int device_ordinal)
+    : ShapedBuffer(std::move(on_device_shape), allocator->platform(),
+                   device_ordinal),
+      allocator_(allocator) {}
 ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,
@@ -171,13 +180,11 @@ void ScopedShapedBuffer::Deallocate() {
 }
 ScopedShapedBuffer ScopedShapedBuffer::TakeSubTree(ShapeIndexView index) {
-  const xla::Shape& sub_on_host_shape =
-      xla::ShapeUtil::GetSubshape(on_host_shape(), {index});
   const xla::Shape& sub_on_device_shape =
       xla::ShapeUtil::GetSubshape(on_device_shape(), {index});
-  ScopedShapedBuffer output(sub_on_host_shape, sub_on_device_shape,
-                            memory_allocator(), device_ordinal());
+  ScopedShapedBuffer output(sub_on_device_shape, memory_allocator(),
+                            device_ordinal());
   auto src_it = buffers().find(index);
   auto dst_it = output.buffers().begin();
   while (dst_it != output.buffers().end()) {


@@ -43,6 +43,9 @@ class ShapedBuffer {
   // both the on-host and on-device shape are required. The on-device shape
   // determines the number of device allocations (DeviceMemoryBase) held by the
   // ShapedBuffer.
+  ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
+               int device_ordinal);
+  // TODO(b/170310047): remove this overload.
   ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
                const se::Platform* platform, int device_ordinal);
@@ -101,7 +104,7 @@
     CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_))
         << "Structures are not the same. new: " << on_device_shape
        << ", old: " << on_device_shape_;
-    on_host_shape_ = on_host_shape;
+    on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape);
     on_device_shape_ = on_device_shape;
     buffers_.replace_shape_ptr(&on_device_shape_);
   }
@@ -119,7 +122,6 @@
   string ToString() const;
  protected:
-  // The shape of the data when represented on the host.
   Shape on_host_shape_;
   // The shape of the data on the device.
@@ -148,6 +150,10 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer);
 class ScopedShapedBuffer : public ShapedBuffer {
  public:
   // Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index.
+  explicit ScopedShapedBuffer(Shape on_device_shape,
+                              se::DeviceMemoryAllocator* allocator,
+                              int device_ordinal);
+  // TODO(b/170310047): remove this overload.
   explicit ScopedShapedBuffer(Shape on_host_shape, Shape on_device_shape,
                               se::DeviceMemoryAllocator* allocator,
                               int device_ordinal);
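
A small usage sketch of the new ScopedShapedBuffer overload (the function name and its arguments are placeholders supplied by the caller; not part of this CL):

    #include "tensorflow/compiler/xla/service/shaped_buffer.h"

    // Sketch: create an empty ScopedShapedBuffer from a device shape alone;
    // its on_host_shape() is derived by stripping tiling and memory-space
    // information from the device shape.
    xla::ScopedShapedBuffer MakeEmptyBuffer(
        const xla::Shape& device_shape,
        stream_executor::DeviceMemoryAllocator* allocator, int device_ordinal) {
      return xla::ScopedShapedBuffer(device_shape, allocator, device_ordinal);
    }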


@@ -51,7 +51,11 @@ class TransferManager {
   // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user
   // needing to consider device-specific behaviors.
   virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const {
-    return host_shape;
+    // Strips off any preexisting tiling or memory space information.
+    // TODO(phawkins): fix clients not to including tiling or memory space
+    // information in shapes passed to this function and turn this into an
+    // assertion.
+    return ShapeUtil::DeviceShapeToHostShape(host_shape);
   }
   // Base class for specifying platform specific transfer metadata that can be


@@ -1624,4 +1624,14 @@ static Shape MergeDimensions(absl::Span<const size_t> segs,
   return absl::nullopt;
 }
+Shape ShapeUtil::DeviceShapeToHostShape(Shape s) {
+  ForEachMutableSubshape(&s, [](Shape* subshape, const ShapeIndex& index) {
+    if (subshape->IsArray()) {
+      subshape->mutable_layout()->clear_tiles();
+      subshape->mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace);
+    }
+  });
+  return s;
+}
+
 }  // namespace xla
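
One property worth noting (an observation, not stated in the CL): DeviceShapeToHostShape is idempotent, so a shape that already has no tiling and uses the default memory space passes through unchanged. That is what keeps the transitional TransferManager::HostShapeToDeviceShape behavior above a no-op for well-formed host shapes. A quick sketch:

    #include "tensorflow/compiler/xla/shape_util.h"
    #include "tensorflow/core/platform/logging.h"

    // Sketch: applying the helper twice gives the same result as applying it
    // once.
    void CheckIdempotent(const xla::Shape& device_shape) {
      xla::Shape host = xla::ShapeUtil::DeviceShapeToHostShape(device_shape);
      CHECK(xla::ShapeUtil::Equal(
          host, xla::ShapeUtil::DeviceShapeToHostShape(host)));
    }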


@@ -783,6 +783,10 @@ class ShapeUtil {
   static absl::optional<std::vector<int64>> FindTranspose021(const Shape& a,
                                                              const Shape& b);
+  // Strips device-specific information, namely tiling and memory-space
+  // information, from a shape.
+  static Shape DeviceShapeToHostShape(Shape s);
+
  private:
   // Validates the shape size is sane. This makes sure it's safe to do
   // calculations in int64 without overflowing.