Add cross-host send/recv APIs to PyLocalClient. The default implementations return Unimplemented for now.

PiperOrigin-RevId: 299858816
Change-Id: Id8c14e9ce491281532ff9795b05e10582db6be00
A. Unique TensorFlower 2020-03-09 09:26:14 -07:00 committed by TensorFlower Gardener
parent 60207a00d4
commit add27c7db6
6 changed files with 229 additions and 72 deletions


@ -329,11 +329,12 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> DLPackManagedTensorToBuffer(
if (dlmt->deleter) {
on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
}
absl::Span<const std::shared_ptr<BufferDefinitionEvent>> definition_events;
auto device_buffer = std::make_shared<SharedDeviceBuffer>(
/*allocator=*/nullptr, dlmt->dl_tensor.ctx.device_id,
std::initializer_list<se::DeviceMemoryBase>{buffer},
/*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
/*definition_event=*/nullptr, std::move(on_delete_callback));
definition_events, std::move(on_delete_callback));
// We have taken ownership of the array inside the capsule; make sure the
// capsule cannot be used again.


@ -95,6 +95,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
@ -182,11 +183,12 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
};
se::DeviceMemoryBase buffer(const_cast<void*>(data),
ShapeUtil::ByteSizeOf(shape));
absl::Span<const std::shared_ptr<BufferDefinitionEvent>> definition_events;
auto device_buffer = std::make_shared<SharedDeviceBuffer>(
/*allocator=*/nullptr, local_device->device_ordinal(),
std::initializer_list<se::DeviceMemoryBase>{buffer},
/*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
/*definition_event=*/nullptr, std::move(on_delete_callback));
definition_events, std::move(on_delete_callback));
return absl::make_unique<PyLocalBuffer>(
shape, shape, std::move(device_buffer), std::move(client),
std::move(device));
@ -218,7 +220,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
std::make_shared<BufferDefinitionEvent>();
std::shared_ptr<SharedDeviceBuffer> device_buffer =
SharedDeviceBuffer::FromScopedShapedBuffer(&scoped_buffer,
definition_event);
{definition_event});
Shape on_device_shape = scoped_buffer.on_device_shape();
auto transfer_h2d = [client, transfer_manager, local_device, device_buffer,
@ -263,7 +265,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
// Sets the buffer definition event. Note: this has the side effect of
// unblocking any host threads that may have been waiting to consume the
// buffer.
device_buffer->definition_event()->SetDefinitionEvent(
device_buffer->definition_events()[0]->SetDefinitionEvent(
std::move(event), local_device->host_to_device_stream());
if (local_device->synchronous_deallocation()) {
@ -318,7 +320,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
std::shared_ptr<SharedDeviceBuffer> tuple_buffer,
SharedDeviceBuffer::MakeTuple(
device_buffers, on_host_shape, transfer_manager, allocator,
local_device->device_ordinal(), definition_event));
local_device->device_ordinal(), {definition_event}));
auto buffer = absl::make_unique<PyLocalBuffer>(
std::move(on_host_shape), ShapeUtil::MakeTupleShape(device_shapes),
tuple_buffer, std::move(client), std::move(device));
@ -348,6 +350,80 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
return buffer;
}
StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>
MakeCrossHostReceiveBuffersHelper(absl::Span<const Shape> shapes,
std::shared_ptr<PyLocalClient> client,
std::shared_ptr<Device> device) {
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState());
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
std::vector<std::unique_ptr<PyLocalBuffer>> buffers;
buffers.reserve(shapes.size());
se::Stream* host_to_device_stream = local_device->host_to_device_stream();
for (const auto& shape : shapes) {
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer scoped_buffer,
transfer_manager->AllocateScopedShapedBuffer(
shape, client->allocator(), local_device->device_ordinal()));
if (!transfer_manager->CanShapedBufferBeAccessedNow(
local_device->compute_stream()->parent(), scoped_buffer)) {
return Unimplemented(
"Cross host receive not enabled unless deallocations are deferred");
}
absl::InlinedVector<std::shared_ptr<BufferDefinitionEvent>, 2>
definition_events;
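// Note: a tuple-shaped buffer uses two definition events below.
// definition_events[1] is recorded once the tuple index tables have been
// written, while definition_events[0] is expected to be recorded by the
// receive implementation when the payload itself arrives.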
if (scoped_buffer.on_device_shape().IsTuple()) {
TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync(
host_to_device_stream, scoped_buffer));
definition_events = {std::make_shared<BufferDefinitionEvent>(),
std::make_shared<BufferDefinitionEvent>()};
TF_ASSIGN_OR_RETURN(EventPool::Handle event,
local_device->event_pool().ThenAllocateAndRecordEvent(
host_to_device_stream));
definition_events[1]->SetDefinitionEvent(std::move(event),
host_to_device_stream);
} else {
definition_events = {std::make_shared<BufferDefinitionEvent>()};
}
std::shared_ptr<SharedDeviceBuffer> device_buffer =
SharedDeviceBuffer::FromScopedShapedBuffer(&scoped_buffer,
definition_events);
Shape on_device_shape = scoped_buffer.on_device_shape();
auto buffer = absl::make_unique<PyLocalBuffer>(
shape, std::move(on_device_shape), std::move(device_buffer), client,
device);
buffers.push_back(std::move(buffer));
}
return buffers;
}
/*static*/ void PyLocalBuffer::MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, std::shared_ptr<PyLocalClient> client,
std::shared_ptr<Device> device, PyLocalCrossHostRecvNotifier&& notifier) {
if (shapes.empty()) {
notifier(InvalidArgument(
"shapes parameter empty in MakeCrossHostReceiveBuffers"));
return;
}
PyLocalClient* client_ptr = client.get();
auto buffer_or = MakeCrossHostReceiveBuffersHelper(shapes, std::move(client),
std::move(device));
if (!buffer_or.ok()) {
notifier(buffer_or.status());
return;
}
client_ptr->EnqueueCrossHostReceive(buffer_or.ConsumeValueOrDie(),
std::move(notifier));
}
PyLocalBuffer::PyLocalBuffer(Shape on_host_shape, Shape on_device_shape,
std::shared_ptr<SharedDeviceBuffer> device_buffer,
std::shared_ptr<PyLocalClient> client,
@ -519,12 +595,19 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
definition_event->SetDefinitionEvent(std::move(event), transfer_stream);
std::shared_ptr<SharedDeviceBuffer> dst_device_buffer =
SharedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer, definition_event);
SharedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer,
{definition_event});
return absl::make_unique<PyLocalBuffer>(
dst_buffer.on_host_shape(), dst_buffer.on_device_shape(),
std::move(dst_device_buffer), client_, dst_device);
}
Status PyLocalBuffer::CopyToRemoteDevice(
absl::string_view serialized_descriptor,
std::shared_ptr<Device> dst_device) {
return client_->CopyToRemoteDevice(this, serialized_descriptor, dst_device);
}
Status PyLocalBuffer::BlockHostUntilReady() {
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::BlockHostUntilReady");
std::shared_ptr<SharedDeviceBuffer> device_buffer = DeviceBuffer();
@ -693,7 +776,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
std::shared_ptr<SharedDeviceBuffer> out_buffer =
SharedDeviceBuffer::FromScopedShapedBuffer(&result_buffer,
definition_event);
{definition_event});
if (device_state->synchronous_deallocation()) {
device_buffers.push_back(out_buffer);


@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/status.h"
@ -81,6 +82,19 @@ class Device {
const std::string platform_name_;
};
class PyLocalBuffer;
// Helper struct for cross host transfers, delivered to the callback passed to
// a call to PyLocalBuffer::MakeCrossHostReceiveBuffers.
struct PyLocalCrossHostRecvBuffer {
// serialized_descriptor should be transmitted to the sender and passed to a
// call to src_buffer->CopyToRemoteDevice.
std::string serialized_descriptor;
// The buffer that will hold the result of the transfer.
std::unique_ptr<PyLocalBuffer> buffer;
};
using PyLocalCrossHostRecvNotifier =
std::function<void(StatusOr<std::vector<PyLocalCrossHostRecvBuffer>>&&)>;
// Encapsulates the state of a Python session with XLA.
class PyLocalClient {
public:
@ -134,6 +148,19 @@ class PyLocalClient {
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
protected:
friend class PyLocalBuffer;
virtual void EnqueueCrossHostReceive(
std::vector<std::unique_ptr<PyLocalBuffer>>&& buffers,
PyLocalCrossHostRecvNotifier&& notifier) const {
notifier(Unimplemented("Cross host receives not implemented."));
}
virtual Status CopyToRemoteDevice(PyLocalBuffer* buffer,
absl::string_view serialized_descriptor,
std::shared_ptr<Device> device) const {
return Unimplemented("Cross host sends not implemented.");
}
std::string platform_name_;
LocalClient* client_;
@ -181,6 +208,19 @@ class PyLocalBuffer {
const std::vector<PyLocalBuffer*> buffers,
std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device);
// Asynchronously makes a vector of PyLocalBuffers that can be used to receive
// cross host transfers using `client` on `device`. `shapes` must be the exact
// shapes, with identical layouts, corresponding to the buffers that will be
// sent. When resources for the transfer are available, notifier will be
// called with a vector of PyLocalCrossHostRecvBuffer structs, one for each
// shape in `shapes`. Each struct contains a buffer that will contain the
// received value, and an opaque string that should be transmitted to the
// sending host and used in a call to CopyToRemoteDevice. None of the recv
// buffers will become ready until *all* of the sends have completed.
static void MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, std::shared_ptr<PyLocalClient> client,
std::shared_ptr<Device> device, PyLocalCrossHostRecvNotifier&& notifier);
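A minimal receiver-side sketch of the intended flow (hypothetical `client`, `device`, and `SendDescriptorToSender` transport helper; note that the base PyLocalClient still returns Unimplemented from its cross-host hooks):

std::vector<Shape> shapes = {ShapeUtil::MakeShape(F32, {1024})};
PyLocalBuffer::MakeCrossHostReceiveBuffers(
    shapes, client, device,
    [&](StatusOr<std::vector<PyLocalCrossHostRecvBuffer>>&& result) {
      if (!result.ok()) {
        LOG(ERROR) << result.status();
        return;
      }
      for (PyLocalCrossHostRecvBuffer& recv : result.ValueOrDie()) {
        // Ship the opaque descriptor to the sending host over some
        // out-of-band channel (hypothetical helper).
        SendDescriptorToSender(recv.serialized_descriptor);
        // recv.buffer will not become ready until *all* sends complete.
      }
    });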
PyLocalBuffer(Shape on_host_shape, Shape on_device_shape,
std::shared_ptr<SharedDeviceBuffer> device_buffer,
std::shared_ptr<PyLocalClient> client,
@ -227,6 +267,18 @@ class PyLocalBuffer {
StatusOr<std::unique_ptr<PyLocalBuffer>> CopyToDevice(
std::shared_ptr<Device> dst_device);
// Copies the buffer to remote device `dst_device`. This call must be preceded
// by a call to MakeCrossHostReceiveBuffers on the remote host's
// dst_device. MakeCrossHostReceiveBuffers takes an array of shapes to
// construct the destination buffers, and a callback supplies an array
// containing both the destination buffers and a serialized descriptor for
// each buffer. For each destination buffer there should be a matching call to
// src->CopyToRemoteDevice on a remote host for a src buffer of the
// corresponding shape. serialized_descriptor is the string returned by the
// callback along with the corresponding destination buffer.
Status CopyToRemoteDevice(absl::string_view serialized_descriptor,
std::shared_ptr<Device> dst_device);
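The matching sender-side sketch (the descriptor string arrives from the receiving host over an out-of-band channel; `ReceiveDescriptorFromReceiver` is hypothetical, and the call must run inside a function returning Status for TF_RETURN_IF_ERROR to work):

// src_buffer's shape must exactly match the corresponding shape passed to
// MakeCrossHostReceiveBuffers on the receiving host.
std::string descriptor = ReceiveDescriptorFromReceiver();
TF_RETURN_IF_ERROR(src_buffer->CopyToRemoteDevice(descriptor, dst_device));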
// Blocks the host until the buffer's value has been computed and is ready for
// immediate use on the device. Useful in particular for timing benchmarks.
Status BlockHostUntilReady();
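For completeness, a sketch of how a platform backend might override the two protected PyLocalClient hooks shown earlier (illustrative class name and bodies only; this commit deliberately leaves both defaults returning Unimplemented):

class MyPlatformClient : public PyLocalClient {
 protected:
  void EnqueueCrossHostReceive(
      std::vector<std::unique_ptr<PyLocalBuffer>>&& buffers,
      PyLocalCrossHostRecvNotifier&& notifier) const override {
    // Register `buffers` with the platform's transfer service, then invoke
    // `notifier` with one PyLocalCrossHostRecvBuffer per buffer.
  }
  Status CopyToRemoteDevice(PyLocalBuffer* buffer,
                            absl::string_view serialized_descriptor,
                            std::shared_ptr<Device> device) const override {
    // Push buffer's contents to the destination identified by
    // serialized_descriptor.
    return Status::OK();
  }
};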


@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/shared_device_buffer.h"
#include <iterator>
#include <memory>
#include "tensorflow/stream_executor/device_memory.h"
@ -60,7 +61,8 @@ static std::shared_ptr<SharedDeviceBuffer> BufferFromScopedShapedBufferIterator(
int device_ordinal, se::DeviceMemoryAllocator* allocator,
ShapeTree<se::DeviceMemoryBase>::iterator* iterator,
const ShapeTree<se::DeviceMemoryBase>::iterator& end,
const std::shared_ptr<BufferDefinitionEvent>& definition_event) {
absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
definition_events) {
std::vector<se::OwningDeviceMemory> buffers;
buffers.reserve(1);
std::vector<std::shared_ptr<SharedDeviceBuffer>> children;
@ -78,7 +80,7 @@ static std::shared_ptr<SharedDeviceBuffer> BufferFromScopedShapedBufferIterator(
for (int i = 0; i < num_children; ++i) {
children.push_back(BufferFromScopedShapedBufferIterator(
on_host_shape.tuple_shapes(i), on_device_shape.tuple_shapes(i),
device_ordinal, allocator, iterator, end, definition_event));
device_ordinal, allocator, iterator, end, definition_events));
}
} else {
// An on-host array may be an on-device tuple. For example, a complex tensor
@ -88,20 +90,21 @@ static std::shared_ptr<SharedDeviceBuffer> BufferFromScopedShapedBufferIterator(
[&](const Shape&, const ShapeIndex&) { consume_buffer(); });
}
return std::make_shared<SharedDeviceBuffer>(
absl::Span<se::OwningDeviceMemory>(buffers), children, definition_event);
absl::Span<se::OwningDeviceMemory>(buffers), children, definition_events);
}
/* static */ std::shared_ptr<SharedDeviceBuffer>
SharedDeviceBuffer::FromScopedShapedBuffer(
ScopedShapedBuffer* shaped_buffer,
const std::shared_ptr<BufferDefinitionEvent>& definition_event) {
absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
definition_events) {
ShapeTree<se::DeviceMemoryBase>::iterator iterator =
shaped_buffer->buffers().begin();
std::shared_ptr<SharedDeviceBuffer> output =
BufferFromScopedShapedBufferIterator(
shaped_buffer->on_host_shape(), shaped_buffer->on_device_shape(),
shaped_buffer->device_ordinal(), shaped_buffer->memory_allocator(),
&iterator, shaped_buffer->buffers().end(), definition_event);
&iterator, shaped_buffer->buffers().end(), definition_events);
CHECK(iterator == shaped_buffer->buffers().end());
return output;
}
@ -111,7 +114,8 @@ SharedDeviceBuffer::MakeTuple(
std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
const Shape& on_host_shape, TransferManager* transfer_manager,
se::DeviceMemoryAllocator* allocator, int device_ordinal,
std::shared_ptr<BufferDefinitionEvent> definition_event) {
absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
definition_events) {
CHECK(on_host_shape.IsTuple() &&
on_host_shape.tuple_shapes_size() == children.size());
TF_ASSIGN_OR_RETURN(
@ -122,7 +126,7 @@ SharedDeviceBuffer::MakeTuple(
return std::make_shared<SharedDeviceBuffer>(
allocator, device_ordinal,
std::initializer_list<se::DeviceMemoryBase>{device_memory.Release()},
std::move(children), std::move(definition_event),
std::move(children), definition_events,
/*on_delete_callback=*/nullptr);
}
@ -130,7 +134,8 @@ SharedDeviceBuffer::MakeTuple(
SharedDeviceBuffer::MakeArray(
Shape on_device_shape, TransferManager* transfer_manager,
se::DeviceMemoryAllocator* allocator, int device_ordinal,
std::shared_ptr<BufferDefinitionEvent> definition_event) {
absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
definition_events) {
std::vector<se::OwningDeviceMemory> device_buffers;
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
on_device_shape, [&](const Shape& subshape, const ShapeIndex&) -> Status {
@ -145,7 +150,7 @@ SharedDeviceBuffer::MakeArray(
return std::make_shared<SharedDeviceBuffer>(
absl::Span<se::OwningDeviceMemory>(device_buffers),
/*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
std::move(definition_event));
definition_events);
}
// Populates a buffer tree from a ShapeTree iterator.
@ -176,25 +181,36 @@ ShapedBuffer SharedDeviceBuffer::AsShapedBuffer(const Shape& on_host_shape,
return shaped_buffer;
}
namespace {
using MoveIterator =
absl::Span<const std::shared_ptr<BufferDefinitionEvent>>::iterator;
} // namespace
SharedDeviceBuffer::SharedDeviceBuffer(
se::DeviceMemoryAllocator* allocator, int device_ordinal,
absl::Span<se::DeviceMemoryBase const> device_memory,
std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
std::shared_ptr<BufferDefinitionEvent> definition_event,
absl::Span<const std::shared_ptr<BufferDefinitionEvent>> definition_events,
std::function<void()> on_delete_callback)
: allocator_(allocator),
device_ordinal_(device_ordinal),
device_memory_(device_memory.begin(), device_memory.end()),
children_(std::move(children)),
definition_event_(std::move(definition_event)),
definition_events_(
std::move_iterator<MoveIterator>(definition_events.begin()),
std::move_iterator<MoveIterator>(definition_events.end())),
on_delete_callback_(std::move(on_delete_callback)) {}
SharedDeviceBuffer::SharedDeviceBuffer(
absl::Span<se::OwningDeviceMemory> device_memory,
std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
std::shared_ptr<BufferDefinitionEvent> definition_event)
absl::Span<const std::shared_ptr<BufferDefinitionEvent>> definition_events)
: children_(std::move(children)),
definition_event_(std::move(definition_event)) {
definition_events_(
std::move_iterator<MoveIterator>(definition_events.begin()),
std::move_iterator<MoveIterator>(definition_events.end())) {
CHECK(!device_memory.empty());
allocator_ = device_memory.front().allocator();
device_ordinal_ = device_memory.front().device_ordinal();
@ -222,8 +238,8 @@ SharedDeviceBuffer::~SharedDeviceBuffer() {
void GetDeviceBufferDefinitionEvents(
const SharedDeviceBuffer& buffer,
absl::flat_hash_set<BufferDefinitionEvent*>* events) {
if (buffer.definition_event()) {
events->insert(buffer.definition_event().get());
for (const auto& e : buffer.definition_events()) {
events->insert(e.get());
}
for (const auto& child : buffer.children()) {
GetDeviceBufferDefinitionEvents(*child, events);


@ -93,20 +93,23 @@ class SharedDeviceBuffer {
// buffers of the shaped_buffer.
static std::shared_ptr<SharedDeviceBuffer> FromScopedShapedBuffer(
ScopedShapedBuffer* shaped_buffer,
const std::shared_ptr<BufferDefinitionEvent>& definition_event);
absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
definition_events);
// Makes a tuple buffer. Does not initialize the tuple table.
static StatusOr<std::shared_ptr<SharedDeviceBuffer>> MakeTuple(
std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
const Shape& on_host_shape, TransferManager* transfer_manager,
se::DeviceMemoryAllocator* allocator, int device_ordinal,
std::shared_ptr<BufferDefinitionEvent> definition_event);
absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
definition_events);
// Makes an uninitialized array buffer.
static StatusOr<std::shared_ptr<SharedDeviceBuffer>> MakeArray(
Shape on_device_shape, TransferManager* transfer_manager,
se::DeviceMemoryAllocator* allocator, int device_ordinal,
std::shared_ptr<BufferDefinitionEvent> definition_event);
absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
definition_events);
// Builds a ShapedBuffer view onto the buffers of 'tree'. We require but do
// not verify that TransferManager::HostShapeToDeviceShape(on_host_shape) ==
@ -126,19 +129,22 @@ class SharedDeviceBuffer {
const absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() const {
return device_memory_;
}
const std::shared_ptr<BufferDefinitionEvent> definition_event() const {
return definition_event_;
absl::Span<const std::shared_ptr<BufferDefinitionEvent>> definition_events()
const {
return definition_events_;
}
SharedDeviceBuffer() = default;
SharedDeviceBuffer(se::DeviceMemoryAllocator* allocator, int device_ordinal,
absl::Span<se::DeviceMemoryBase const> device_memory,
std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
std::shared_ptr<BufferDefinitionEvent> definition_event,
absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
definition_events,
std::function<void()> on_delete_callback);
SharedDeviceBuffer(absl::Span<se::OwningDeviceMemory> device_memory,
std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
std::shared_ptr<BufferDefinitionEvent> definition_event);
absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
definition_events);
~SharedDeviceBuffer();
private:
@ -155,7 +161,8 @@ class SharedDeviceBuffer {
// ready during multistream execution. May be empty, which is used in the
// single-stream execution case where events are not necessary for buffer
// event sequencing.
std::shared_ptr<BufferDefinitionEvent> definition_event_;
absl::InlinedVector<std::shared_ptr<BufferDefinitionEvent>, 2>
definition_events_;
// A callback to call when the SharedDeviceBuffer is about to be destroyed.
std::function<void()> on_delete_callback_;


@ -28,10 +28,10 @@ TEST(SharedDeviceBufferTest, MakeArray) {
LocalClient* client = ClientLibrary::LocalClientOrDie();
Shape shape = ShapeUtil::MakeShape(F32, {3, 101, 4});
TF_ASSERT_OK_AND_ASSIGN(
auto buffer, SharedDeviceBuffer::MakeArray(
shape, client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, nullptr));
TF_ASSERT_OK_AND_ASSIGN(auto buffer,
SharedDeviceBuffer::MakeArray(
shape, client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, {}));
EXPECT_EQ(buffer->children().size(), 0);
EXPECT_EQ(buffer->device_ordinal(), 0);
EXPECT_EQ(buffer->allocator(), client->backend().memory_allocator());
@ -45,19 +45,19 @@ TEST(SharedDeviceBufferTest, MakeTuple) {
Shape a_shape = ShapeUtil::MakeShape(F32, {3, 101, 4});
Shape b_shape = ShapeUtil::MakeShape(S8, {77});
Shape tuple_shape = ShapeUtil::MakeTupleShape({a_shape, b_shape});
TF_ASSERT_OK_AND_ASSIGN(
auto a_buffer, SharedDeviceBuffer::MakeArray(
a_shape, client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, nullptr));
TF_ASSERT_OK_AND_ASSIGN(
auto b_buffer, SharedDeviceBuffer::MakeArray(
b_shape, client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, nullptr));
TF_ASSERT_OK_AND_ASSIGN(
auto tuple_buffer, SharedDeviceBuffer::MakeTuple(
{a_buffer, b_buffer}, tuple_shape,
client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, nullptr));
TF_ASSERT_OK_AND_ASSIGN(auto a_buffer,
SharedDeviceBuffer::MakeArray(
a_shape, client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, {}));
TF_ASSERT_OK_AND_ASSIGN(auto b_buffer,
SharedDeviceBuffer::MakeArray(
b_shape, client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, {}));
TF_ASSERT_OK_AND_ASSIGN(auto tuple_buffer,
SharedDeviceBuffer::MakeTuple(
{a_buffer, b_buffer}, tuple_shape,
client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, {}));
ASSERT_EQ(tuple_buffer->children().size(), 2);
EXPECT_EQ(tuple_buffer->children()[0], a_buffer);
EXPECT_EQ(tuple_buffer->children()[1], b_buffer);
@ -75,30 +75,28 @@ TEST(SharedDeviceBufferTest, AsShapedBuffer) {
Shape ab_tuple_shape = ShapeUtil::MakeTupleShape({a_shape, b_shape});
Shape c_shape = ShapeUtil::MakeShape(S64, {});
Shape abc_tuple_shape = ShapeUtil::MakeTupleShape({c_shape, ab_tuple_shape});
TF_ASSERT_OK_AND_ASSIGN(
auto a_buffer, SharedDeviceBuffer::MakeArray(
a_shape, client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, nullptr));
TF_ASSERT_OK_AND_ASSIGN(
auto b_buffer, SharedDeviceBuffer::MakeArray(
b_shape, client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, nullptr));
TF_ASSERT_OK_AND_ASSIGN(
auto ab_tuple_buffer,
SharedDeviceBuffer::MakeTuple({a_buffer, b_buffer}, ab_tuple_shape,
client->backend().transfer_manager(),
client->backend().memory_allocator(), 0,
nullptr));
TF_ASSERT_OK_AND_ASSIGN(
auto c_buffer, SharedDeviceBuffer::MakeArray(
c_shape, client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, nullptr));
TF_ASSERT_OK_AND_ASSIGN(
auto abc_tuple_buffer,
SharedDeviceBuffer::MakeTuple(
{c_buffer, ab_tuple_buffer}, abc_tuple_shape,
client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, nullptr));
TF_ASSERT_OK_AND_ASSIGN(auto a_buffer,
SharedDeviceBuffer::MakeArray(
a_shape, client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, {}));
TF_ASSERT_OK_AND_ASSIGN(auto b_buffer,
SharedDeviceBuffer::MakeArray(
b_shape, client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, {}));
TF_ASSERT_OK_AND_ASSIGN(auto ab_tuple_buffer,
SharedDeviceBuffer::MakeTuple(
{a_buffer, b_buffer}, ab_tuple_shape,
client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, {}));
TF_ASSERT_OK_AND_ASSIGN(auto c_buffer,
SharedDeviceBuffer::MakeArray(
c_shape, client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, {}));
TF_ASSERT_OK_AND_ASSIGN(auto abc_tuple_buffer,
SharedDeviceBuffer::MakeTuple(
{c_buffer, ab_tuple_buffer}, abc_tuple_shape,
client->backend().transfer_manager(),
client->backend().memory_allocator(), 0, {}));
Shape abc_tuple_device_shape =
client->backend().transfer_manager()->HostShapeToDeviceShape(
abc_tuple_shape);
@ -140,7 +138,7 @@ TEST(SharedDeviceBufferTest, FromScopedShapedBuffer) {
ScopedShapedBuffer shaped_buffer,
client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0));
std::shared_ptr<SharedDeviceBuffer> device_buffer =
SharedDeviceBuffer::FromScopedShapedBuffer(&shaped_buffer, nullptr);
SharedDeviceBuffer::FromScopedShapedBuffer(&shaped_buffer, {});
ASSERT_EQ(device_buffer->device_memory().size(), 1);
ASSERT_EQ(device_buffer->children().size(), 2);