From add27c7db6fedb2e9f32a7f3d17ef213559e70b4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 9 Mar 2020 09:26:14 -0700 Subject: [PATCH] Add cross host send/recv to PyLocalClient. Not implemented for now. PiperOrigin-RevId: 299858816 Change-Id: Id8c14e9ce491281532ff9795b05e10582db6be00 --- tensorflow/compiler/xla/python/dlpack.cc | 3 +- .../compiler/xla/python/local_client.cc | 95 +++++++++++++++++-- tensorflow/compiler/xla/python/local_client.h | 52 ++++++++++ .../xla/python/shared_device_buffer.cc | 46 ++++++--- .../xla/python/shared_device_buffer.h | 23 +++-- .../xla/python/shared_device_buffer_test.cc | 82 ++++++++-------- 6 files changed, 229 insertions(+), 72 deletions(-) diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc index b4ae503ba4c..ae6e54ce48f 100644 --- a/tensorflow/compiler/xla/python/dlpack.cc +++ b/tensorflow/compiler/xla/python/dlpack.cc @@ -329,11 +329,12 @@ StatusOr> DLPackManagedTensorToBuffer( if (dlmt->deleter) { on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; } + absl::Span> definition_events; auto device_buffer = std::make_shared( /*allocator=*/nullptr, dlmt->dl_tensor.ctx.device_id, std::initializer_list{buffer}, /*children=*/std::vector>{}, - /*definition_event=*/nullptr, std::move(on_delete_callback)); + definition_events, std::move(on_delete_callback)); // We have taken ownership of the array inside the capsule; make sure the // capsule it cannot be used again. diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc index e6734396afc..05caf011728 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/python/local_client.cc @@ -95,6 +95,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/traceme.h" @@ -182,11 +183,12 @@ StatusOr> PyLocalBuffer::FromHostBuffer( }; se::DeviceMemoryBase buffer(const_cast(data), ShapeUtil::ByteSizeOf(shape)); + absl::Span> definition_events; auto device_buffer = std::make_shared( /*allocator=*/nullptr, local_device->device_ordinal(), std::initializer_list{buffer}, /*children=*/std::vector>{}, - /*definition_event=*/nullptr, std::move(on_delete_callback)); + definition_events, std::move(on_delete_callback)); return absl::make_unique( shape, shape, std::move(device_buffer), std::move(client), std::move(device)); @@ -218,7 +220,7 @@ StatusOr> PyLocalBuffer::FromHostBuffer( std::make_shared(); std::shared_ptr device_buffer = SharedDeviceBuffer::FromScopedShapedBuffer(&scoped_buffer, - definition_event); + {definition_event}); Shape on_device_shape = scoped_buffer.on_device_shape(); auto transfer_h2d = [client, transfer_manager, local_device, device_buffer, @@ -263,7 +265,7 @@ StatusOr> PyLocalBuffer::FromHostBuffer( // Sets the buffer definition event. Note: this has the side effect of // unblocking any host threads that may have been waiting to consume the // buffer. 
- device_buffer->definition_event()->SetDefinitionEvent( + device_buffer->definition_events()[0]->SetDefinitionEvent( std::move(event), local_device->host_to_device_stream()); if (local_device->synchronous_deallocation()) { @@ -318,7 +320,7 @@ StatusOr> PyLocalBuffer::FromHostBuffer( std::shared_ptr tuple_buffer, SharedDeviceBuffer::MakeTuple( device_buffers, on_host_shape, transfer_manager, allocator, - local_device->device_ordinal(), definition_event)); + local_device->device_ordinal(), {definition_event})); auto buffer = absl::make_unique( std::move(on_host_shape), ShapeUtil::MakeTupleShape(device_shapes), tuple_buffer, std::move(client), std::move(device)); @@ -348,6 +350,80 @@ StatusOr> PyLocalBuffer::FromHostBuffer( return buffer; } +StatusOr>> +MakeCrossHostReceiveBuffersHelper(absl::Span shapes, + std::shared_ptr client, + std::shared_ptr device) { + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + device->GetLocalDeviceState()); + TransferManager* transfer_manager = + client->client()->backend().transfer_manager(); + std::vector> buffers; + buffers.reserve(shapes.size()); + se::Stream* host_to_device_stream = local_device->host_to_device_stream(); + for (const auto& shape : shapes) { + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer scoped_buffer, + transfer_manager->AllocateScopedShapedBuffer( + shape, client->allocator(), local_device->device_ordinal())); + + if (!transfer_manager->CanShapedBufferBeAccessedNow( + local_device->compute_stream()->parent(), scoped_buffer)) { + return Unimplemented( + "Cross host receive not enabled unless deallocations are deferred"); + } + + absl::InlinedVector, 2> + definition_events; + + if (scoped_buffer.on_device_shape().IsTuple()) { + TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync( + host_to_device_stream, scoped_buffer)); + definition_events = {std::make_shared(), + std::make_shared()}; + TF_ASSIGN_OR_RETURN(EventPool::Handle event, + local_device->event_pool().ThenAllocateAndRecordEvent( + host_to_device_stream)); + definition_events[1]->SetDefinitionEvent(std::move(event), + host_to_device_stream); + } else { + definition_events = {std::make_shared()}; + } + + std::shared_ptr device_buffer = + SharedDeviceBuffer::FromScopedShapedBuffer(&scoped_buffer, + definition_events); + Shape on_device_shape = scoped_buffer.on_device_shape(); + + auto buffer = absl::make_unique( + shape, std::move(on_device_shape), std::move(device_buffer), client, + device); + + buffers.push_back(std::move(buffer)); + } + return buffers; +} + +/*static*/ void PyLocalBuffer::MakeCrossHostReceiveBuffers( + absl::Span shapes, std::shared_ptr client, + std::shared_ptr device, PyLocalCrossHostRecvNotifier&& notifier) { + if (shapes.empty()) { + notifier(InvalidArgument( + "shapes parameter empty in MakeCrossHostReceiveBuffers")); + return; + } + PyLocalClient* client_ptr = client.get(); + auto buffer_or = MakeCrossHostReceiveBuffersHelper(shapes, std::move(client), + std::move(device)); + if (!buffer_or.ok()) { + notifier(buffer_or.status()); + return; + } + + client_ptr->EnqueueCrossHostReceive(buffer_or.ConsumeValueOrDie(), + std::move(notifier)); +} + PyLocalBuffer::PyLocalBuffer(Shape on_host_shape, Shape on_device_shape, std::shared_ptr device_buffer, std::shared_ptr client, @@ -519,12 +595,19 @@ StatusOr> PyLocalBuffer::CopyToDevice( definition_event->SetDefinitionEvent(std::move(event), transfer_stream); std::shared_ptr dst_device_buffer = - SharedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer, definition_event); + 
SharedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer, + {definition_event}); return absl::make_unique( dst_buffer.on_host_shape(), dst_buffer.on_device_shape(), std::move(dst_device_buffer), client_, dst_device); } +Status PyLocalBuffer::CopyToRemoteDevice( + absl::string_view serialized_descriptor, + std::shared_ptr dst_device) { + return client_->CopyToRemoteDevice(this, serialized_descriptor, dst_device); +} + Status PyLocalBuffer::BlockHostUntilReady() { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::BlockHostUntilReady"); std::shared_ptr device_buffer = DeviceBuffer(); @@ -693,7 +776,7 @@ StatusOr> PyLocalExecutable::ExecuteHelper( std::shared_ptr out_buffer = SharedDeviceBuffer::FromScopedShapedBuffer(&result_buffer, - definition_event); + {definition_event}); if (device_state->synchronous_deallocation()) { device_buffers.push_back(out_buffer); diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h index 726c2d9e8e5..834141de13d 100644 --- a/tensorflow/compiler/xla/python/local_client.h +++ b/tensorflow/compiler/xla/python/local_client.h @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/lib/core/status.h" @@ -81,6 +82,19 @@ class Device { const std::string platform_name_; }; +class PyLocalBuffer; +// Helper struct for cross host transfers, returned by the callback from a call +// to PyLocalBuffer::MakeCrossHostReceiveBuffers. +struct PyLocalCrossHostRecvBuffer { + // serialized_descriptor should be transmitted to the sender and passed to a + // call to src_buffer->CopyToRemoteDevice. + std::string serialized_descriptor; + // The buffer that will hold the result of the transfer. + std::unique_ptr buffer; +}; +using PyLocalCrossHostRecvNotifier = + std::function>&&)>; + // Encapsulates the state of Python session with XLA. class PyLocalClient { public: @@ -134,6 +148,19 @@ class PyLocalClient { virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; } protected: + friend class PyLocalBuffer; + virtual void EnqueueCrossHostReceive( + std::vector>&& buffers, + PyLocalCrossHostRecvNotifier&& notifier) const { + notifier(Unimplemented("Cross host receives not implemented.")); + } + + virtual Status CopyToRemoteDevice(PyLocalBuffer* buffer, + absl::string_view serialized_descriptor, + std::shared_ptr device) const { + return Unimplemented("Cross host sends not implemented."); + } + std::string platform_name_; LocalClient* client_; @@ -181,6 +208,19 @@ class PyLocalBuffer { const std::vector buffers, std::shared_ptr client, std::shared_ptr device); + // Asynchronously makes a vector of PyLocalBuffers that can be used to receive + // cross host transfers using `client` on `device'. `shapes` must be the exact + // shapes, with identical layouts, corresponding to the buffers that will be + // sent. When resources for the transfer are available, notifier will be + // called with a vector of PyLocalCrossHostRecvBuffer structs, one for each + // shape in `shapes`. Each struct contains a buffer that will contain the + // received value, and an opaque string that should be transmitted to the + // sending host and used in a call to CopyToRemoteDevice. None of the recv + // buffers will become ready until *all* of the sends have completed. 
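+  //
+  // Hypothetical usage sketch: the channel that carries the descriptor from
+  // the receiving host to the sending host is outside this API, so
+  // SendDescriptorToSender below is an assumed helper, and the buffer,
+  // client, and device names are placeholders.
+  //
+  //   // Receiving host:
+  //   PyLocalBuffer::MakeCrossHostReceiveBuffers(
+  //       {shape}, client, device,
+  //       [](StatusOr<std::vector<PyLocalCrossHostRecvBuffer>>&& recv_or) {
+  //         TF_CHECK_OK(recv_or.status());
+  //         auto recv_buffers = recv_or.ConsumeValueOrDie();
+  //         SendDescriptorToSender(recv_buffers[0].serialized_descriptor);
+  //         // recv_buffers[0].buffer becomes ready once the send completes.
+  //       });
+  //
+  //   // Sending host, once `descriptor` has arrived out of band:
+  //   TF_CHECK_OK(src_buffer->CopyToRemoteDevice(descriptor, dst_device));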
+ static void MakeCrossHostReceiveBuffers( + absl::Span shapes, std::shared_ptr client, + std::shared_ptr device, PyLocalCrossHostRecvNotifier&& notifier); + PyLocalBuffer(Shape on_host_shape, Shape on_device_shape, std::shared_ptr device_buffer, std::shared_ptr client, @@ -227,6 +267,18 @@ class PyLocalBuffer { StatusOr> CopyToDevice( std::shared_ptr dst_device); + // Copies the buffer to remote device `dst_device`. This call must be preceded + // by a call to MakeCrossHostReceiveBuffers on the remote host's + // dst_device. MakeCrossHostReceiveBuffers takes an array of shapes to + // construct the destination buffers, and a callback supplies an array + // containing both the destination buffers, and a serialized descriptor for + // each buffer. For each destination buffer there should be a matching call to + // src->CopyToRemoteDevice on a remote host for a src buffer of the + // corresponding shape. serialized_descriptor is the string returned by the + // callback along with the corresponding destination buffer. + Status CopyToRemoteDevice(absl::string_view serialized_descriptor, + std::shared_ptr dst_device); + // Blocks the host until the buffer's value has been computed and is ready for // immediate use on the device. Useful in particular for timing benchmarks. Status BlockHostUntilReady(); diff --git a/tensorflow/compiler/xla/python/shared_device_buffer.cc b/tensorflow/compiler/xla/python/shared_device_buffer.cc index ca6da645024..91f2b434a61 100644 --- a/tensorflow/compiler/xla/python/shared_device_buffer.cc +++ b/tensorflow/compiler/xla/python/shared_device_buffer.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/python/shared_device_buffer.h" +#include #include #include "tensorflow/stream_executor/device_memory.h" @@ -60,7 +61,8 @@ static std::shared_ptr BufferFromScopedShapedBufferIterator( int device_ordinal, se::DeviceMemoryAllocator* allocator, ShapeTree::iterator* iterator, const ShapeTree::iterator& end, - const std::shared_ptr& definition_event) { + absl::Span> + definition_events) { std::vector buffers; buffers.reserve(1); std::vector> children; @@ -78,7 +80,7 @@ static std::shared_ptr BufferFromScopedShapedBufferIterator( for (int i = 0; i < num_children; ++i) { children.push_back(BufferFromScopedShapedBufferIterator( on_host_shape.tuple_shapes(i), on_device_shape.tuple_shapes(i), - device_ordinal, allocator, iterator, end, definition_event)); + device_ordinal, allocator, iterator, end, definition_events)); } } else { // An on-host array may be an on-device tuple. 
For example, a complex tensor @@ -88,20 +90,21 @@ static std::shared_ptr BufferFromScopedShapedBufferIterator( [&](const Shape&, const ShapeIndex&) { consume_buffer(); }); } return std::make_shared( - absl::Span(buffers), children, definition_event); + absl::Span(buffers), children, definition_events); } /* static */ std::shared_ptr SharedDeviceBuffer::FromScopedShapedBuffer( ScopedShapedBuffer* shaped_buffer, - const std::shared_ptr& definition_event) { + absl::Span> + definition_events) { ShapeTree::iterator iterator = shaped_buffer->buffers().begin(); std::shared_ptr output = BufferFromScopedShapedBufferIterator( shaped_buffer->on_host_shape(), shaped_buffer->on_device_shape(), shaped_buffer->device_ordinal(), shaped_buffer->memory_allocator(), - &iterator, shaped_buffer->buffers().end(), definition_event); + &iterator, shaped_buffer->buffers().end(), definition_events); CHECK(iterator == shaped_buffer->buffers().end()); return output; } @@ -111,7 +114,8 @@ SharedDeviceBuffer::MakeTuple( std::vector> children, const Shape& on_host_shape, TransferManager* transfer_manager, se::DeviceMemoryAllocator* allocator, int device_ordinal, - std::shared_ptr definition_event) { + absl::Span> + definition_events) { CHECK(on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == children.size()); TF_ASSIGN_OR_RETURN( @@ -122,7 +126,7 @@ SharedDeviceBuffer::MakeTuple( return std::make_shared( allocator, device_ordinal, std::initializer_list{device_memory.Release()}, - std::move(children), std::move(definition_event), + std::move(children), definition_events, /*on_delete_callback=*/nullptr); } @@ -130,7 +134,8 @@ SharedDeviceBuffer::MakeTuple( SharedDeviceBuffer::MakeArray( Shape on_device_shape, TransferManager* transfer_manager, se::DeviceMemoryAllocator* allocator, int device_ordinal, - std::shared_ptr definition_event) { + absl::Span> + definition_events) { std::vector device_buffers; TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( on_device_shape, [&](const Shape& subshape, const ShapeIndex&) -> Status { @@ -145,7 +150,7 @@ SharedDeviceBuffer::MakeArray( return std::make_shared( absl::Span(device_buffers), /*children=*/std::vector>{}, - std::move(definition_event)); + definition_events); } // Populates a buffer tree from a ShapeTree iterator. 
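Because definition_events is now a span, a single buffer can be guarded by several events, and a consumer stream must wait on all of them before it may read the device memory. The cross host receive path is the motivating case: a tuple's index table is defined by a locally recorded event while the payload is defined by an event the sender triggers. A minimal consumer-side sketch under that reading, using GetDeviceBufferDefinitionEvents from this file and the existing WaitForEventOnStream primitive (`buffer` and `stream` are assumed to be in scope):

    // Collect every BufferDefinitionEvent that guards `buffer` or any of its
    // children, then make `stream` wait on each one before consuming the data.
    absl::flat_hash_set<BufferDefinitionEvent*> events;
    GetDeviceBufferDefinitionEvents(buffer, &events);
    for (BufferDefinitionEvent* event : events) {
      event->WaitForEventOnStream(stream);
    }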
@@ -176,25 +181,36 @@ ShapedBuffer SharedDeviceBuffer::AsShapedBuffer(const Shape& on_host_shape, return shaped_buffer; } +namespace { + +using MoveIterator = + absl::Span>::iterator; + +} // namespace + SharedDeviceBuffer::SharedDeviceBuffer( se::DeviceMemoryAllocator* allocator, int device_ordinal, absl::Span device_memory, std::vector> children, - std::shared_ptr definition_event, + absl::Span> definition_events, std::function on_delete_callback) : allocator_(allocator), device_ordinal_(device_ordinal), device_memory_(device_memory.begin(), device_memory.end()), children_(std::move(children)), - definition_event_(std::move(definition_event)), + definition_events_( + std::move_iterator(definition_events.begin()), + std::move_iterator(definition_events.end())), on_delete_callback_(std::move(on_delete_callback)) {} SharedDeviceBuffer::SharedDeviceBuffer( absl::Span device_memory, std::vector> children, - std::shared_ptr definition_event) + absl::Span> definition_events) : children_(std::move(children)), - definition_event_(std::move(definition_event)) { + definition_events_( + std::move_iterator(definition_events.begin()), + std::move_iterator(definition_events.end())) { CHECK(!device_memory.empty()); allocator_ = device_memory.front().allocator(); device_ordinal_ = device_memory.front().device_ordinal(); @@ -222,8 +238,8 @@ SharedDeviceBuffer::~SharedDeviceBuffer() { void GetDeviceBufferDefinitionEvents( const SharedDeviceBuffer& buffer, absl::flat_hash_set* events) { - if (buffer.definition_event()) { - events->insert(buffer.definition_event().get()); + for (const auto& e : buffer.definition_events()) { + events->insert(e.get()); } for (const auto& child : buffer.children()) { GetDeviceBufferDefinitionEvents(*child, events); diff --git a/tensorflow/compiler/xla/python/shared_device_buffer.h b/tensorflow/compiler/xla/python/shared_device_buffer.h index bd4cd8f7079..3aa122c535d 100644 --- a/tensorflow/compiler/xla/python/shared_device_buffer.h +++ b/tensorflow/compiler/xla/python/shared_device_buffer.h @@ -93,20 +93,23 @@ class SharedDeviceBuffer { // buffers of the shaped_buffer. static std::shared_ptr FromScopedShapedBuffer( ScopedShapedBuffer* shaped_buffer, - const std::shared_ptr& definition_event); + absl::Span> + definition_events); // Makes a tuple buffer. Does not initialize the tuple table. static StatusOr> MakeTuple( std::vector> children, const Shape& on_host_shape, TransferManager* transfer_manager, se::DeviceMemoryAllocator* allocator, int device_ordinal, - std::shared_ptr definition_event); + absl::Span> + definition_events); // Makes an uninitialized array buffer. static StatusOr> MakeArray( Shape on_device_shape, TransferManager* transfer_manager, se::DeviceMemoryAllocator* allocator, int device_ordinal, - std::shared_ptr definition_event); + absl::Span> + definition_events); // Builds a ShapedBuffer view onto the buffers of 'tree'. 
We require but do // not verify that TransferManager::HostShapeToDeviceShape(on_host_shape) == @@ -126,19 +129,22 @@ class SharedDeviceBuffer { const absl::InlinedVector& device_memory() const { return device_memory_; } - const std::shared_ptr definition_event() const { - return definition_event_; + absl::Span> definition_events() + const { + return definition_events_; } SharedDeviceBuffer() = default; SharedDeviceBuffer(se::DeviceMemoryAllocator* allocator, int device_ordinal, absl::Span device_memory, std::vector> children, - std::shared_ptr definition_event, + absl::Span> + definition_events, std::function on_delete_callback); SharedDeviceBuffer(absl::Span device_memory, std::vector> children, - std::shared_ptr definition_event); + absl::Span> + definition_events); ~SharedDeviceBuffer(); private: @@ -155,7 +161,8 @@ class SharedDeviceBuffer { // ready during multistream execution. May be nullptr, which is used in the // single-stream execution case where events are not necessary for buffer // event sequencing. - std::shared_ptr definition_event_; + absl::InlinedVector, 2> + definition_events_; // A callback to call when the SharedDeviceBuffer is about to be destroyed. std::function on_delete_callback_; diff --git a/tensorflow/compiler/xla/python/shared_device_buffer_test.cc b/tensorflow/compiler/xla/python/shared_device_buffer_test.cc index b39767a0d46..05842c52a0c 100644 --- a/tensorflow/compiler/xla/python/shared_device_buffer_test.cc +++ b/tensorflow/compiler/xla/python/shared_device_buffer_test.cc @@ -28,10 +28,10 @@ TEST(SharedDeviceBufferTest, MakeArray) { LocalClient* client = ClientLibrary::LocalClientOrDie(); Shape shape = ShapeUtil::MakeShape(F32, {3, 101, 4}); - TF_ASSERT_OK_AND_ASSIGN( - auto buffer, SharedDeviceBuffer::MakeArray( - shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, nullptr)); + TF_ASSERT_OK_AND_ASSIGN(auto buffer, + SharedDeviceBuffer::MakeArray( + shape, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); EXPECT_EQ(buffer->children().size(), 0); EXPECT_EQ(buffer->device_ordinal(), 0); EXPECT_EQ(buffer->allocator(), client->backend().memory_allocator()); @@ -45,19 +45,19 @@ TEST(SharedDeviceBufferTest, MakeTuple) { Shape a_shape = ShapeUtil::MakeShape(F32, {3, 101, 4}); Shape b_shape = ShapeUtil::MakeShape(S8, {77}); Shape tuple_shape = ShapeUtil::MakeTupleShape({a_shape, b_shape}); - TF_ASSERT_OK_AND_ASSIGN( - auto a_buffer, SharedDeviceBuffer::MakeArray( - a_shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, nullptr)); - TF_ASSERT_OK_AND_ASSIGN( - auto b_buffer, SharedDeviceBuffer::MakeArray( - b_shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, nullptr)); - TF_ASSERT_OK_AND_ASSIGN( - auto tuple_buffer, SharedDeviceBuffer::MakeTuple( - {a_buffer, b_buffer}, tuple_shape, - client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, nullptr)); + TF_ASSERT_OK_AND_ASSIGN(auto a_buffer, + SharedDeviceBuffer::MakeArray( + a_shape, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto b_buffer, + SharedDeviceBuffer::MakeArray( + b_shape, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto tuple_buffer, + SharedDeviceBuffer::MakeTuple( + {a_buffer, b_buffer}, tuple_shape, + client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); 
ASSERT_EQ(tuple_buffer->children().size(), 2); EXPECT_EQ(tuple_buffer->children()[0], a_buffer); EXPECT_EQ(tuple_buffer->children()[1], b_buffer); @@ -75,30 +75,28 @@ TEST(SharedDeviceBufferTest, AsShapedBuffer) { Shape ab_tuple_shape = ShapeUtil::MakeTupleShape({a_shape, b_shape}); Shape c_shape = ShapeUtil::MakeShape(S64, {}); Shape abc_tuple_shape = ShapeUtil::MakeTupleShape({c_shape, ab_tuple_shape}); - TF_ASSERT_OK_AND_ASSIGN( - auto a_buffer, SharedDeviceBuffer::MakeArray( - a_shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, nullptr)); - TF_ASSERT_OK_AND_ASSIGN( - auto b_buffer, SharedDeviceBuffer::MakeArray( - b_shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, nullptr)); - TF_ASSERT_OK_AND_ASSIGN( - auto ab_tuple_buffer, - SharedDeviceBuffer::MakeTuple({a_buffer, b_buffer}, ab_tuple_shape, - client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, - nullptr)); - TF_ASSERT_OK_AND_ASSIGN( - auto c_buffer, SharedDeviceBuffer::MakeArray( - c_shape, client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, nullptr)); - TF_ASSERT_OK_AND_ASSIGN( - auto abc_tuple_buffer, - SharedDeviceBuffer::MakeTuple( - {c_buffer, ab_tuple_buffer}, abc_tuple_shape, - client->backend().transfer_manager(), - client->backend().memory_allocator(), 0, nullptr)); + TF_ASSERT_OK_AND_ASSIGN(auto a_buffer, + SharedDeviceBuffer::MakeArray( + a_shape, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto b_buffer, + SharedDeviceBuffer::MakeArray( + b_shape, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto ab_tuple_buffer, + SharedDeviceBuffer::MakeTuple( + {a_buffer, b_buffer}, ab_tuple_shape, + client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto c_buffer, + SharedDeviceBuffer::MakeArray( + c_shape, client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto abc_tuple_buffer, + SharedDeviceBuffer::MakeTuple( + {c_buffer, ab_tuple_buffer}, abc_tuple_shape, + client->backend().transfer_manager(), + client->backend().memory_allocator(), 0, {})); Shape abc_tuple_device_shape = client->backend().transfer_manager()->HostShapeToDeviceShape( abc_tuple_shape); @@ -140,7 +138,7 @@ TEST(SharedDeviceBufferTest, FromScopedShapedBuffer) { ScopedShapedBuffer shaped_buffer, client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0)); std::shared_ptr device_buffer = - SharedDeviceBuffer::FromScopedShapedBuffer(&shaped_buffer, nullptr); + SharedDeviceBuffer::FromScopedShapedBuffer(&shaped_buffer, {}); ASSERT_EQ(device_buffer->device_memory().size(), 1); ASSERT_EQ(device_buffer->children().size(), 2);
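Both hooks added to PyLocalClient above are stubs that return Unimplemented, so this change compiles on every backend while leaving room for platform support: the base class invokes EnqueueCrossHostReceive from MakeCrossHostReceiveBuffers and CopyToRemoteDevice from PyLocalBuffer::CopyToRemoteDevice, so a subclass only needs to supply the transport. The sketch below is an assumption about how those hooks are meant to be filled in; CrossHostCapableClient, StartReceive, and StartSend are hypothetical names, not existing XLA APIs.

    class CrossHostCapableClient : public PyLocalClient {
     public:
      using PyLocalClient::PyLocalClient;  // Reuse the base constructors.

     protected:
      void EnqueueCrossHostReceive(
          std::vector<std::unique_ptr<PyLocalBuffer>>&& buffers,
          PyLocalCrossHostRecvNotifier&& notifier) const override {
        // Register each receive buffer with the platform transport, then hand
        // the descriptors and buffers back to the caller via the notifier.
        std::vector<PyLocalCrossHostRecvBuffer> recv_buffers;
        recv_buffers.reserve(buffers.size());
        for (std::unique_ptr<PyLocalBuffer>& buffer : buffers) {
          PyLocalCrossHostRecvBuffer recv;
          recv.serialized_descriptor = StartReceive(buffer.get());
          recv.buffer = std::move(buffer);
          recv_buffers.push_back(std::move(recv));
        }
        notifier(std::move(recv_buffers));
      }

      Status CopyToRemoteDevice(PyLocalBuffer* buffer,
                                absl::string_view serialized_descriptor,
                                std::shared_ptr<Device> device) const override {
        // Hand the source buffer and the receiver's descriptor to the
        // transport; the receiver's definition events fire once the payload
        // has landed on the destination device.
        return StartSend(buffer, serialized_descriptor, device);
      }

     private:
      // Assumed platform-specific transport hooks; not part of any existing API.
      std::string StartReceive(PyLocalBuffer* buffer) const;
      Status StartSend(PyLocalBuffer* buffer, absl::string_view descriptor,
                       const std::shared_ptr<Device>& device) const;
    };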