Remove the platform field from ShapedBuffer.

This further simplifies the ShapedBuffer object, since it has no logic of its
own that uses the platform field.
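
For reference, a before/after sketch of a typical call site (adapted from the
TransferManager hunk below); only the se::Platform* argument goes away:

    // Before: every caller had to thread the platform through.
    ShapedBuffer shaped_buffer(on_device_shape, stream->parent()->platform(),
                               stream->parent()->device_ordinal());

    // After: the device ordinal alone is enough to construct the buffer.
    ShapedBuffer shaped_buffer(on_device_shape,
                               stream->parent()->device_ordinal());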

Note that this CL also removes a sanity check in the allocation tracker. We
can add that check back if needed: just track the `platform` for each buffer
inside the allocation tracker as a side map.
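
A hedged sketch of what that side map could look like if the check is ever
restored (the member and helper names here are hypothetical, not part of this CL):

    // Hypothetical: AllocationTracker remembers the platform per handle in a
    // side map instead of reading it off the ShapedBuffer.
    absl::flat_hash_map<int64, const se::Platform*> handle_to_platform_;

    Status CheckPlatform(int64 handle, const se::Platform* expected) {
      auto it = handle_to_platform_.find(handle);
      if (it != handle_to_platform_.end() && it->second != expected) {
        return InvalidArgument(
            "AllocationTracker for platform %s cannot use buffer registered "
            "on platform %s",
            expected->Name(), it->second->Name());
      }
      return Status::OK();
    }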

PiperOrigin-RevId: 339938197
Change-Id: I090e603927ed3fccdb51254f972b3af2e1ec1470
Author: Yunxing Dai (2020-10-30 14:16:38 -07:00), committed by TensorFlower Gardener
Parent: cb46f059a5
Commit: 2e02c5e78d
16 changed files with 36 additions and 81 deletions


@ -426,7 +426,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
ShapedBuffer buffer(
xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}),
xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}),
-output.platform(), output.device_ordinal());
+output.device_ordinal());
buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(),
/*source_base_index=*/{},
/*target_base_index=*/{0});


@ -267,9 +267,8 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
}
static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer(
-const ShapeTree<MaybeOwningDeviceMemory>& tree, se::Platform* platform,
-int device_ordinal) {
-ShapedBuffer result(tree.shape(), platform, device_ordinal);
+const ShapeTree<MaybeOwningDeviceMemory>& tree, int device_ordinal) {
+ShapedBuffer result(tree.shape(), device_ordinal);
auto it = tree.begin();
auto out_it = result.buffers().begin();
for (; it != tree.end(); ++it, ++out_it) {
@ -299,8 +298,7 @@ StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
shaped_buffer_ptrs.reserve(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
shaped_buffers.push_back(MaybeOwningShapeTreeToShapedBuffer(
-arguments[i].Buffers(), backend_->platform(),
-stream->parent()->device_ordinal()));
+arguments[i].Buffers(), stream->parent()->device_ordinal()));
shaped_buffer_ptrs.push_back(&shaped_buffers.back());
}


@ -653,8 +653,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
// memory that has already been allocated, and a possible Event
// allocation.
-ShapedBuffer buffer = device_buffer->AsShapedBuffer(
-compact_shape, on_device_shape, local_client->platform());
+ShapedBuffer buffer =
+device_buffer->AsShapedBuffer(compact_shape, on_device_shape);
// If applicable on the backend, stage the transfer via host memory
// allocated via the host_memory_allocator. On GPU, this is pinned
// memory.
@ -753,8 +753,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
// allocation.
se::Stream* h2d_stream = local_device->host_to_device_stream();
-ShapedBuffer buffer = device_buffer->AsShapedBuffer(
-compact_shape, on_device_shape, local_client->platform());
+ShapedBuffer buffer =
+device_buffer->AsShapedBuffer(compact_shape, on_device_shape);
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
h2d_stream, literal, buffer));
@ -1099,8 +1099,8 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
host_shape = on_host_shape_;
}
host_value->value = std::make_shared<Literal>(host_shape);
-ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(
-host_shape, on_device_shape_, client_->client()->platform());
+ShapedBuffer shaped_buffer =
+device_buffer->AsShapedBuffer(host_shape, on_device_shape_);
client_->client()->backend().transfer_manager()->TransferLiteralFromDevice(
stream, shaped_buffer, host_value->value.get(),
[host_value](Status done_status) {
@ -1152,8 +1152,7 @@ StatusOr<ShapedBuffer> PjRtBuffer::AsShapedBuffer() const {
return InvalidArgument(
"Attempted to fetch value of invalid/deleted buffer.");
}
-return device_buffer_->AsShapedBuffer(on_host_shape_, on_device_shape_,
-client_->client()->platform());
+return device_buffer_->AsShapedBuffer(on_host_shape_, on_device_shape_);
}
PjRtBuffer::ScopedHold PjRtBuffer::GetBufferWithHold(ScopedHold::Type type) {
@ -1188,8 +1187,8 @@ PjRtBuffer::CopyToDeviceHelper(
ScopedHold dst_device_buffer(py_buffer->GetBufferWithUsageHold());
CHECK(dst_device_buffer.ok());
-ShapedBuffer dst_buffer = dst_device_buffer->AsShapedBuffer(
-on_host_shape_, on_device_shape_, client_->client()->platform());
+ShapedBuffer dst_buffer =
+dst_device_buffer->AsShapedBuffer(on_host_shape_, on_device_shape_);
// Copy the leaf buffers.
StatusOr<std::shared_ptr<BufferSequencingEvent>> copy_event_or =


@ -117,11 +117,9 @@ TrackedDeviceBuffer::FromScopedShapedBuffer(
/*on_delete_callback=*/nullptr);
}
-ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(const Shape& on_host_shape,
-const Shape& on_device_shape,
-se::Platform* platform) const {
-ShapedBuffer shaped_buffer(on_host_shape, on_device_shape, platform,
-device_ordinal_);
+ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
+const Shape& on_host_shape, const Shape& on_device_shape) const {
+ShapedBuffer shaped_buffer(on_host_shape, on_device_shape, device_ordinal_);
ShapeTree<se::DeviceMemoryBase>::iterator iterator =
shaped_buffer.buffers().begin();
for (const se::DeviceMemoryBase& buf : device_memory_) {


@ -141,8 +141,7 @@ class TrackedDeviceBuffer {
// not verify that TransferManager::HostShapeToDeviceShape(on_host_shape) ==
// on_device_shape().
ShapedBuffer AsShapedBuffer(const Shape& on_host_shape,
-const Shape& on_device_shape,
-se::Platform* platform) const;
+const Shape& on_device_shape) const;
// Adds the owned device buffers in order to 'iterator'. Used to add the
// buffers to an ExecutionInput. We require but do not verify that 'iterator'


@ -66,16 +66,13 @@ TEST(TrackedDeviceBufferTest, AsShapedBuffer) {
c_buffer->device_memory()[0]};
ShapedBuffer shaped_a = a_buffer->AsShapedBuffer(
a_shape,
-client->backend().transfer_manager()->HostShapeToDeviceShape(a_shape),
-client->platform());
+client->backend().transfer_manager()->HostShapeToDeviceShape(a_shape));
ShapedBuffer shaped_b = b_buffer->AsShapedBuffer(
b_shape,
-client->backend().transfer_manager()->HostShapeToDeviceShape(b_shape),
-client->platform());
+client->backend().transfer_manager()->HostShapeToDeviceShape(b_shape));
ShapedBuffer shaped_c = c_buffer->AsShapedBuffer(
c_shape,
-client->backend().transfer_manager()->HostShapeToDeviceShape(c_shape),
-client->platform());
+client->backend().transfer_manager()->HostShapeToDeviceShape(c_shape));
auto expected_it = expected_buffer_sequence.begin();
for (auto it = shaped_a.buffers().begin(); it != shaped_a.buffers().end();
++it) {


@ -64,15 +64,6 @@ StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
VLOG(2) << "RegisterInternal("
<< "tag: \"" << tag << "\" with " << replicated_buffers.size()
<< " shaped_buffers.";
-for (const auto& shaped_buffer : replicated_buffers) {
-VLOG(2) << "shaped_buffer:" << shaped_buffer;
-if (shaped_buffer.platform() != backend_->platform()) {
-return InvalidArgument(
-"AllocationTracker for platform %s cannot register buffer from "
-"platform %s",
-backend_->platform()->Name(), shaped_buffer.platform()->Name());
-}
-}
int64 handle = next_handle_++;
for (auto& shaped_buffer : replicated_buffers) {
@ -158,7 +149,7 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
++i) {
auto element_buffer = ShapedBuffer(
ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i),
-shaped_buffer->platform(), shaped_buffer->device_ordinal());
+shaped_buffer->device_ordinal());
element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}),
/*index=*/{});
std::vector<ShapedBuffer> replicated_buffers;


@ -62,8 +62,7 @@ void ExecutionInput::SetUnownedBuffer(const ShapeIndex& index,
StatusOr<ShapedBuffer> ExecutionInput::ToShapedBuffer(
se::DeviceMemoryAllocator* allocator, int device_ordinal) const {
const Shape& input_shape = shape();
-ShapedBuffer shaped_buffer(input_shape, allocator->platform(),
-device_ordinal);
+ShapedBuffer shaped_buffer(input_shape, device_ordinal);
for (const auto& index_buffer : Buffers()) {
const tensorflow::se::OwningDeviceMemory* mem =
index_buffer.second.AsOwningDeviceMemory();


@ -57,7 +57,6 @@ StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
for (auto& argument : arguments) {
const ShapeTree<MaybeOwningDeviceMemory>& buffers = argument.Buffers();
argument_buffers.push_back(ShapedBuffer(buffers.shape(),
-/*platform=*/nullptr,
/*device_ordinal=*/device_ordinal));
auto in_it = buffers.begin();
auto out_it = argument_buffers.back().buffers().begin();


@ -245,17 +245,6 @@ Service::ResolveAndValidateArguments(
CHECK_EQ(options_.number_of_replicas(), replicated_buffers.size());
for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
const ShapedBuffer* shaped_buffer = replicated_buffers[replica];
-int replica_device_ordinal = stream_executors[replica]->device_ordinal();
-// Verify allocation is same platform and device as the execution.
-if (shaped_buffer->platform() != execute_backend_->platform() ||
-shaped_buffer->device_ordinal() != replica_device_ordinal) {
-return InvalidArgument(
-"argument %lu is on device %s:%d but computation will be executed "
-"on device %s",
-i, shaped_buffer->platform()->Name(),
-shaped_buffer->device_ordinal(),
-execute_backend_->device_name(replica_device_ordinal));
-}
replicated_arguments[replica].push_back(shaped_buffer);
}
}


@ -31,23 +31,20 @@ limitations under the License.
namespace xla {
-ShapedBuffer::ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
-int device_ordinal)
+ShapedBuffer::ShapedBuffer(Shape on_device_shape, int device_ordinal)
: on_device_shape_(std::move(on_device_shape)),
-platform_(platform),
device_ordinal_(device_ordinal),
buffers_(&on_device_shape_) {
on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape_);
}
ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
-const se::Platform* platform, int device_ordinal)
-: ShapedBuffer(on_device_shape, platform, device_ordinal) {}
+int device_ordinal)
+: ShapedBuffer(on_device_shape, device_ordinal) {}
ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
: on_host_shape_(std::move(s.on_host_shape_)),
on_device_shape_(std::move(s.on_device_shape_)),
-platform_(s.platform_),
device_ordinal_(s.device_ordinal_),
buffers_(std::move(s.buffers_)) {
// s.buffers_ has a pointer to s.on_device_shape_. When we move s.buffers_
@ -59,7 +56,6 @@ ShapedBuffer::ShapedBuffer(ShapedBuffer&& s)
ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) {
on_device_shape_ = std::move(s.on_device_shape_);
on_host_shape_ = std::move(s.on_host_shape_);
-platform_ = s.platform_;
device_ordinal_ = s.device_ordinal_;
buffers_ = std::move(s.buffers_);
// buffers_ has a pointer to its on_device_shape_. When we move s.buffers_
@ -75,7 +71,7 @@ StatusOr<ShapedBuffer> ShapedBuffer::SubShapedBuffer(
const ShapeIndex& index) const {
TF_ASSIGN_OR_RETURN(const Shape* device_sub_shape,
ShapeUtil::TryGetSubshape(on_device_shape(), index));
-ShapedBuffer sub_shaped_buffer(*device_sub_shape, platform_, device_ordinal_);
+ShapedBuffer sub_shaped_buffer(*device_sub_shape, device_ordinal_);
TF_ASSIGN_OR_RETURN(ShapeTree<se::DeviceMemoryBase> sub_buffers,
buffers_.SubShapeTree(index));
sub_shaped_buffer.set_buffers(std::move(sub_buffers));
@ -91,7 +87,7 @@ void ShapedBuffer::clear() {
string ShapedBuffer::ToString() const {
string s =
absl::StrCat("ShapedBuffer(", platform_->Name(), ":", device_ordinal(),
absl::StrCat("ShapedBuffer(", device_ordinal(),
"), on-device shape=" +
ShapeUtil::HumanStringWithLayout(on_device_shape()),
":\n");
@ -120,8 +116,7 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) {
ScopedShapedBuffer::ScopedShapedBuffer(Shape on_device_shape,
se::DeviceMemoryAllocator* allocator,
int device_ordinal)
-: ShapedBuffer(std::move(on_device_shape), allocator->platform(),
-device_ordinal),
+: ShapedBuffer(std::move(on_device_shape), device_ordinal),
allocator_(allocator) {}
ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,


@ -43,12 +43,10 @@ class ShapedBuffer {
// both the on-host and on-device shape are required. The on-device shape
// determines the number of device allocations (DeviceMemoryBase) held by the
// ShapedBuffer.
-ShapedBuffer(Shape on_device_shape, const se::Platform* platform,
-int device_ordinal);
+ShapedBuffer(Shape on_device_shape, int device_ordinal);
// TODO(b/170310047): remove this overload.
-ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
-const se::Platform* platform, int device_ordinal);
+ShapedBuffer(Shape on_host_shape, Shape on_device_shape, int device_ordinal);
// Movable, but not copyable.
ShapedBuffer(ShapedBuffer&& s);
@ -70,7 +68,6 @@ class ShapedBuffer {
// ShapedBuffer.
const Shape& on_device_shape() const { return on_device_shape_; }
-const se::Platform* platform() const { return platform_; }
int device_ordinal() const { return device_ordinal_; }
// Return the root buffer of the shape (shape index {}).
@ -132,9 +129,6 @@ class ShapedBuffer {
// The shape of the data on the device.
Shape on_device_shape_;
-// The platform the memory is allocated on.
-const se::Platform* platform_;
// The device the memory is allocated on.
int device_ordinal_;


@ -169,7 +169,7 @@ Status TransferManager::TransferArrayToDeviceAsync(
"%d < %d",
dest.size(), GetByteSizeRequirement(on_device_shape));
}
-ShapedBuffer shaped_buffer(on_device_shape, stream->parent()->platform(),
+ShapedBuffer shaped_buffer(on_device_shape,
stream->parent()->device_ordinal());
shaped_buffer.set_buffer(dest, /*index=*/{});
return TransferLiteralToDevice(stream, literal, shaped_buffer,
@ -193,8 +193,7 @@ void TransferManager::TransferArrayFromDevice(
"%d < %d",
source.size(), GetByteSizeRequirement(shape)));
}
-ShapedBuffer shaped_buffer(shape, stream->parent()->platform(),
-stream->parent()->device_ordinal());
+ShapedBuffer shaped_buffer(shape, stream->parent()->device_ordinal());
shaped_buffer.set_buffer(source, /*index=*/{});
return TransferLiteralFromDevice(stream, shaped_buffer, literal,
std::move(done), transfer_metadata);


@ -585,7 +585,7 @@ void XRTTupleAllocation::InitializeFromShapedBuffer(
xla::StatusOr<xla::ShapedBuffer> XRTTupleAllocation::ToShapedBuffer() {
xla::ShapedBuffer shaped_buffer(on_host_shape(), on_device_shape(),
-allocator_->platform(), device_ordinal_);
+device_ordinal_);
for (const auto& index_buffer : buffers_) {
if (index_buffer.second == nullptr ||
(index_buffer.second->allocation().is_null() &&


@ -172,7 +172,7 @@ struct InputBuffers {
int device_ordinal) {
CHECK_NE(allocator, nullptr);
xla::ShapedBuffer shaped_buffer(std::move(host_shape), buffers.shape(),
-allocator->platform(), device_ordinal);
+device_ordinal);
shaped_buffer.set_buffers(buffers.Map<se::DeviceMemoryBase>(
[](xla::MaybeOwningDeviceMemory* buffer) {
CHECK(buffer);


@ -34,10 +34,8 @@ xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer) {
i++;
}
-xla::ShapedBuffer xla_shaped_buffer(
-xla_on_device_shape,
-tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform(),
-c_buffer->device_ordinal);
+xla::ShapedBuffer xla_shaped_buffer(xla_on_device_shape,
+c_buffer->device_ordinal);
xla_shaped_buffer.set_buffers(xla_shape_tree);
return xla_shaped_buffer;
}