[XLA] Compute ShapedBuffer::on_host_shape from ShapedBuffer::on_device_shape.
This change is in preparation for removing ShapedBuffer::on_host_shape entirely: on all known XLA backends, the host shape can now be computed from the device shape. We therefore have no need to maintain two representations of a buffer's shape.

As a transitional measure, start by ignoring the on_host_shape supplied by the user and returning an on_host_shape formed by stripping tiling and memory space information from the on_device_shape(). We can then refactor clients of ShapedBuffer to avoid producing or consuming the on_host_shape in subsequent CLs.

PiperOrigin-RevId: 335989587
Change-Id: I309c2e69b001c0777769dafdfc6c4ffe8dfef18e
parent 6c3538bf6b
commit cc33e1f05b
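Note: the heart of this CL is the new ShapeUtil::DeviceShapeToHostShape helper added in the shape_util hunks below. The following is a minimal sketch of the intended end state, not code from this commit; Example, the platform argument, and the tile/memory-space values are illustrative, and the Layout::add_tiles/Tile accessors are assumed from the layout API of this era.

#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

// Sketch only: builds a device shape carrying backend-specific layout
// details, then checks that the derived host shape has them stripped.
void Example(const stream_executor::Platform* platform) {
  xla::Shape device_shape =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {8, 128}, {1, 0});
  *device_shape.mutable_layout()->add_tiles() = xla::Tile({8, 128});
  device_shape.mutable_layout()->set_memory_space(1);

  // After this CL the host shape is derived, not supplied: it is the device
  // shape with tiles cleared and the memory space reset to the default.
  xla::ShapedBuffer buffer(device_shape, platform, /*device_ordinal=*/0);
  CHECK(buffer.on_host_shape().layout().tiles().empty());
  CHECK_EQ(buffer.on_host_shape().layout().memory_space(),
           xla::Layout::kDefaultMemorySpace);
}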
@@ -60,15 +60,24 @@ namespace xla {
 // with their indices absent from unowned_indices_.
 class ExecutionInput {
  public:
-  explicit ExecutionInput(xla::Shape shape, xla::Shape host_shape)
+  explicit ExecutionInput(xla::Shape shape) : buffers_(std::move(shape)) {
+    SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape()));
+  }
+  // TODO(b/170310047): remove this overload.
+  ExecutionInput(xla::Shape shape, xla::Shape host_shape)
       : buffers_(std::move(shape)) {
-    SetHostShape(std::move(host_shape));
+    SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape()));
   }

-  explicit ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers,
-                          xla::Shape host_shape)
+  explicit ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers)
       : buffers_(std::move(buffers)) {
-    SetHostShape(std::move(host_shape));
+    SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape()));
+  }
+  // TODO(b/170310047): remove this overload.
+  ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers,
+                 xla::Shape host_shape)
+      : buffers_(std::move(buffers)) {
+    SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape()));
   }

   ExecutionInput(ExecutionInput&&) = default;
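A hedged call-site sketch of the two ExecutionInput paths above; MakeInput, MakeInputLegacy, and their arguments are illustrative, not part of the commit.

#include "tensorflow/compiler/xla/service/executable.h"

// Preferred after this hunk: the single-shape constructor, which computes
// the host shape internally.
xla::ExecutionInput MakeInput(xla::Shape device_shape) {
  return xla::ExecutionInput(std::move(device_shape));
}

// Transitional overload (TODO(b/170310047)): still accepted, but the
// host_shape argument is now ignored in favor of the derived shape.
xla::ExecutionInput MakeInputLegacy(xla::Shape device_shape,
                                    xla::Shape stale_host_shape) {
  return xla::ExecutionInput(std::move(device_shape),
                             std::move(stale_host_shape));
}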
@@ -106,7 +106,8 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync(
   // The on-host and on-device shape should always be the same for the generic
   // transfer manager.
   TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
-                                device_buffer.on_host_shape()));
+                                device_buffer.on_host_shape()))
+      << device_buffer.ToString();

   TF_RET_CHECK(
       ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape()));
@@ -33,11 +33,16 @@ namespace xla {

 ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
                            const se::Platform* platform, int device_ordinal)
-    : on_host_shape_(std::move(on_host_shape)),
-      on_device_shape_(std::move(on_device_shape)),
+    : ShapedBuffer(on_device_shape, platform, device_ordinal) {}
+
+ShapedBuffer::ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
+                           int device_ordinal)
+    : on_device_shape_(std::move(on_device_shape)),
       platform_(platform),
       device_ordinal_(device_ordinal),
-      buffers_(&on_device_shape_) {}
+      buffers_(&on_device_shape_) {
+  on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape_);
+}

 ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
     : on_host_shape_(std::move(s.on_host_shape_)),
@@ -52,8 +57,8 @@ ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
 }

 ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) {
-  on_host_shape_ = std::move(s.on_host_shape_);
   on_device_shape_ = std::move(s.on_device_shape_);
+  on_host_shape_ = std::move(s.on_host_shape_);
   platform_ = s.platform_;
   device_ordinal_ = s.device_ordinal_;
   buffers_ = std::move(s.buffers_);
@@ -68,12 +73,9 @@ ShapedBuffer::~ShapedBuffer() {}

 StatusOr<ShapedBuffer> ShapedBuffer::SubShapedBuffer(
     const ShapeIndex& index) const {
-  TF_ASSIGN_OR_RETURN(const Shape* host_sub_shape,
-                      ShapeUtil::TryGetSubshape(on_host_shape(), index));
   TF_ASSIGN_OR_RETURN(const Shape* device_sub_shape,
                       ShapeUtil::TryGetSubshape(on_device_shape(), index));
-  ShapedBuffer sub_shaped_buffer(*host_sub_shape, *device_sub_shape, platform_,
-                                 device_ordinal_);
+  ShapedBuffer sub_shaped_buffer(*device_sub_shape, platform_, device_ordinal_);
   TF_ASSIGN_OR_RETURN(ShapeTree<se::DeviceMemoryBase> sub_buffers,
                       buffers_.SubShapeTree(index));
   sub_shaped_buffer.set_buffers(std::move(sub_buffers));
@@ -120,8 +122,15 @@ ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
                                        Shape on_device_shape,
                                        se::DeviceMemoryAllocator* allocator,
                                        int device_ordinal)
-    : ShapedBuffer(std::move(on_host_shape), std::move(on_device_shape),
-                   allocator->platform(), device_ordinal),
+    : ShapedBuffer(std::move(on_device_shape), allocator->platform(),
+                   device_ordinal),
+      allocator_(allocator) {}
+
+ScopedShapedBuffer::ScopedShapedBuffer(Shape on_device_shape,
+                                       se::DeviceMemoryAllocator* allocator,
+                                       int device_ordinal)
+    : ShapedBuffer(std::move(on_device_shape), allocator->platform(),
+                   device_ordinal),
       allocator_(allocator) {}

 ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,
@@ -171,13 +180,11 @@ void ScopedShapedBuffer::Deallocate() {
 }

 ScopedShapedBuffer ScopedShapedBuffer::TakeSubTree(ShapeIndexView index) {
-  const xla::Shape& sub_on_host_shape =
-      xla::ShapeUtil::GetSubshape(on_host_shape(), {index});
   const xla::Shape& sub_on_device_shape =
       xla::ShapeUtil::GetSubshape(on_device_shape(), {index});

-  ScopedShapedBuffer output(sub_on_host_shape, sub_on_device_shape,
-                            memory_allocator(), device_ordinal());
+  ScopedShapedBuffer output(sub_on_device_shape, memory_allocator(),
+                            device_ordinal());
   auto src_it = buffers().find(index);
   auto dst_it = output.buffers().begin();
   while (dst_it != output.buffers().end()) {
@@ -43,6 +43,9 @@ class ShapedBuffer {
   // both the on-host and on-device shape are required. The on-device shape
   // determines the number of device allocations (DeviceMemoryBase) held by the
   // ShapedBuffer.
+  ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
+               int device_ordinal);
+  // TODO(b/170310047): remove this overload.
   ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
                const se::Platform* platform, int device_ordinal);

@@ -101,7 +104,7 @@ class ShapedBuffer {
     CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_))
         << "Structures are not the same. new: " << on_device_shape
         << ", old: " << on_device_shape_;
-    on_host_shape_ = on_host_shape;
+    on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape);
     on_device_shape_ = on_device_shape;
     buffers_.replace_shape_ptr(&on_device_shape_);
   }
@@ -119,7 +122,6 @@ class ShapedBuffer {
   string ToString() const;

  protected:
-  // The shape of the data when represented on the host.
   Shape on_host_shape_;

   // The shape of the data on the device.
@@ -148,6 +150,10 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer);
 class ScopedShapedBuffer : public ShapedBuffer {
  public:
   // Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index.
+  explicit ScopedShapedBuffer(Shape on_device_shape,
+                              se::DeviceMemoryAllocator* allocator,
+                              int device_ordinal);
+  // TODO(b/170310047): remove this overload.
   explicit ScopedShapedBuffer(Shape on_host_shape, Shape on_device_shape,
                               se::DeviceMemoryAllocator* allocator,
                               int device_ordinal);

@@ -51,7 +51,11 @@ class TransferManager {
   // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user
   // needing to consider device-specific behaviors.
   virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const {
-    return host_shape;
+    // Strips off any preexisting tiling or memory space information.
+    // TODO(phawkins): fix clients not to include tiling or memory space
+    // information in shapes passed to this function and turn this into an
+    // assertion.
+    return ShapeUtil::DeviceShapeToHostShape(host_shape);
   }

   // Base class for specifying platform specific transfer metadata that can be
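One consequence of the transfer_manager.h hunk above, stated as a hedged property check: CheckDefaultStripping, tm, and shape are illustrative, and the equality only holds for TransferManager subclasses that keep the default HostShapeToDeviceShape.

#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

// With the new default, tiling or memory-space data on the incoming shape is
// stripped rather than copied into the device shape.
void CheckDefaultStripping(const xla::TransferManager* tm,
                           const xla::Shape& shape) {
  CHECK(xla::ShapeUtil::Equal(
      tm->HostShapeToDeviceShape(shape),
      xla::ShapeUtil::DeviceShapeToHostShape(shape)));
}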
@@ -1624,4 +1624,14 @@ static Shape MergeDimensions(absl::Span<const size_t> segs,
   return absl::nullopt;
 }

+Shape ShapeUtil::DeviceShapeToHostShape(Shape s) {
+  ForEachMutableSubshape(&s, [](Shape* subshape, const ShapeIndex& index) {
+    if (subshape->IsArray()) {
+      subshape->mutable_layout()->clear_tiles();
+      subshape->mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace);
+    }
+  });
+  return s;
+}
+
 }  // namespace xla
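Because the implementation above walks every subshape via ForEachMutableSubshape, tuple shapes are normalized leaf by leaf. A small hedged illustration; TupleExample and the element shapes are made up.

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

// Sketch: DeviceShapeToHostShape recurses into tuples, so nested array
// layouts are cleaned as well.
void TupleExample() {
  xla::Shape elem =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {4, 4}, {1, 0});
  elem.mutable_layout()->set_memory_space(1);
  xla::Shape tuple = xla::ShapeUtil::MakeTupleShape({elem, elem});

  xla::Shape host = xla::ShapeUtil::DeviceShapeToHostShape(tuple);
  for (int i = 0; i < 2; ++i) {
    CHECK_EQ(xla::ShapeUtil::GetTupleElementShape(host, i)
                 .layout()
                 .memory_space(),
             xla::Layout::kDefaultMemorySpace);
  }
}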
@@ -783,6 +783,10 @@ class ShapeUtil {
   static absl::optional<std::vector<int64>> FindTranspose021(const Shape& a,
                                                              const Shape& b);

+  // Strips device-specific information, namely tiling and memory-space
+  // information, from a shape.
+  static Shape DeviceShapeToHostShape(Shape s);
+
  private:
   // Validates the shape size is sane. This makes sure it's safe to do
   // calculations in int64 without overflowing.