Add cross-host send/recv to PyLocalClient. The default implementations return Unimplemented for now.
PiperOrigin-RevId: 299858816 Change-Id: Id8c14e9ce491281532ff9795b05e10582db6be00
Parent: 60207a00d4
Commit: add27c7db6
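For orientation, here is a minimal receiver-side sketch of the API this commit adds. It is not part of the commit; `client`, `device`, and `SendDescriptorToSender` are hypothetical stand-ins for the caller's client handle, target device, and out-of-band transport.

    // Receiver host: allocate buffers for the incoming transfer, then ship
    // each opaque descriptor to the sending host (hypothetical transport).
    std::vector<Shape> shapes = {ShapeUtil::MakeShape(F32, {1024})};
    PyLocalBuffer::MakeCrossHostReceiveBuffers(
        shapes, client, device,
        [&](StatusOr<std::vector<PyLocalCrossHostRecvBuffer>>&& buffers_or) {
          if (!buffers_or.ok()) return;  // Resources were not available.
          std::vector<PyLocalCrossHostRecvBuffer> recv =
              buffers_or.ConsumeValueOrDie();
          for (const PyLocalCrossHostRecvBuffer& rb : recv) {
            // rb.buffer will hold the received value; it only becomes ready
            // once *all* of the matching sends have completed.
            SendDescriptorToSender(rb.serialized_descriptor);  // hypothetical
          }
        });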
--- a/tensorflow/compiler/xla/python/dlpack.cc
+++ b/tensorflow/compiler/xla/python/dlpack.cc
@@ -329,11 +329,12 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> DLPackManagedTensorToBuffer(
   if (dlmt->deleter) {
     on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
   }
+  absl::Span<const std::shared_ptr<BufferDefinitionEvent>> definition_events;
   auto device_buffer = std::make_shared<SharedDeviceBuffer>(
       /*allocator=*/nullptr, dlmt->dl_tensor.ctx.device_id,
       std::initializer_list<se::DeviceMemoryBase>{buffer},
       /*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
-      /*definition_event=*/nullptr, std::move(on_delete_callback));
+      definition_events, std::move(on_delete_callback));

   // We have taken ownership of the array inside the capsule; make sure the
   // capsule cannot be used again.
--- a/tensorflow/compiler/xla/python/local_client.cc
+++ b/tensorflow/compiler/xla/python/local_client.cc
@@ -95,6 +95,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/lib/traceme.h"

@@ -182,11 +183,12 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
   };
   se::DeviceMemoryBase buffer(const_cast<void*>(data),
                               ShapeUtil::ByteSizeOf(shape));
+  absl::Span<const std::shared_ptr<BufferDefinitionEvent>> definition_events;
   auto device_buffer = std::make_shared<SharedDeviceBuffer>(
       /*allocator=*/nullptr, local_device->device_ordinal(),
       std::initializer_list<se::DeviceMemoryBase>{buffer},
       /*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
-      /*definition_event=*/nullptr, std::move(on_delete_callback));
+      definition_events, std::move(on_delete_callback));
   return absl::make_unique<PyLocalBuffer>(
       shape, shape, std::move(device_buffer), std::move(client),
       std::move(device));
@@ -218,7 +220,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
       std::make_shared<BufferDefinitionEvent>();
   std::shared_ptr<SharedDeviceBuffer> device_buffer =
       SharedDeviceBuffer::FromScopedShapedBuffer(&scoped_buffer,
-                                                 definition_event);
+                                                 {definition_event});
   Shape on_device_shape = scoped_buffer.on_device_shape();

   auto transfer_h2d = [client, transfer_manager, local_device, device_buffer,
@@ -263,7 +265,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
     // Sets the buffer definition event. Note: this has the side effect of
     // unblocking any host threads that may have been waiting to consume the
     // buffer.
-    device_buffer->definition_event()->SetDefinitionEvent(
+    device_buffer->definition_events()[0]->SetDefinitionEvent(
         std::move(event), local_device->host_to_device_stream());

     if (local_device->synchronous_deallocation()) {
@@ -318,7 +320,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
       std::shared_ptr<SharedDeviceBuffer> tuple_buffer,
       SharedDeviceBuffer::MakeTuple(
           device_buffers, on_host_shape, transfer_manager, allocator,
-          local_device->device_ordinal(), definition_event));
+          local_device->device_ordinal(), {definition_event}));
   auto buffer = absl::make_unique<PyLocalBuffer>(
       std::move(on_host_shape), ShapeUtil::MakeTupleShape(device_shapes),
       tuple_buffer, std::move(client), std::move(device));
@@ -348,6 +350,80 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
   return buffer;
 }

+StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>
+MakeCrossHostReceiveBuffersHelper(absl::Span<const Shape> shapes,
+                                  std::shared_ptr<PyLocalClient> client,
+                                  std::shared_ptr<Device> device) {
+  TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
+                      device->GetLocalDeviceState());
+  TransferManager* transfer_manager =
+      client->client()->backend().transfer_manager();
+  std::vector<std::unique_ptr<PyLocalBuffer>> buffers;
+  buffers.reserve(shapes.size());
+  se::Stream* host_to_device_stream = local_device->host_to_device_stream();
+  for (const auto& shape : shapes) {
+    TF_ASSIGN_OR_RETURN(
+        ScopedShapedBuffer scoped_buffer,
+        transfer_manager->AllocateScopedShapedBuffer(
+            shape, client->allocator(), local_device->device_ordinal()));
+
+    if (!transfer_manager->CanShapedBufferBeAccessedNow(
+            local_device->compute_stream()->parent(), scoped_buffer)) {
+      return Unimplemented(
+          "Cross host receive not enabled unless deallocations are deferred");
+    }
+
+    absl::InlinedVector<std::shared_ptr<BufferDefinitionEvent>, 2>
+        definition_events;
+
+    if (scoped_buffer.on_device_shape().IsTuple()) {
+      TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync(
+          host_to_device_stream, scoped_buffer));
+      definition_events = {std::make_shared<BufferDefinitionEvent>(),
+                           std::make_shared<BufferDefinitionEvent>()};
+      TF_ASSIGN_OR_RETURN(EventPool::Handle event,
+                          local_device->event_pool().ThenAllocateAndRecordEvent(
+                              host_to_device_stream));
+      definition_events[1]->SetDefinitionEvent(std::move(event),
+                                               host_to_device_stream);
+    } else {
+      definition_events = {std::make_shared<BufferDefinitionEvent>()};
+    }
+
+    std::shared_ptr<SharedDeviceBuffer> device_buffer =
+        SharedDeviceBuffer::FromScopedShapedBuffer(&scoped_buffer,
+                                                   definition_events);
+    Shape on_device_shape = scoped_buffer.on_device_shape();
+
+    auto buffer = absl::make_unique<PyLocalBuffer>(
+        shape, std::move(on_device_shape), std::move(device_buffer), client,
+        device);
+
+    buffers.push_back(std::move(buffer));
+  }
+  return buffers;
+}
+
+/*static*/ void PyLocalBuffer::MakeCrossHostReceiveBuffers(
+    absl::Span<const Shape> shapes, std::shared_ptr<PyLocalClient> client,
+    std::shared_ptr<Device> device, PyLocalCrossHostRecvNotifier&& notifier) {
+  if (shapes.empty()) {
+    notifier(InvalidArgument(
+        "shapes parameter empty in MakeCrossHostReceiveBuffers"));
+    return;
+  }
+  PyLocalClient* client_ptr = client.get();
+  auto buffer_or = MakeCrossHostReceiveBuffersHelper(shapes, std::move(client),
+                                                     std::move(device));
+  if (!buffer_or.ok()) {
+    notifier(buffer_or.status());
+    return;
+  }
+
+  client_ptr->EnqueueCrossHostReceive(buffer_or.ConsumeValueOrDie(),
+                                      std::move(notifier));
+}
+
 PyLocalBuffer::PyLocalBuffer(Shape on_host_shape, Shape on_device_shape,
                              std::shared_ptr<SharedDeviceBuffer> device_buffer,
                              std::shared_ptr<PyLocalClient> client,
@@ -519,12 +595,19 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
   definition_event->SetDefinitionEvent(std::move(event), transfer_stream);

   std::shared_ptr<SharedDeviceBuffer> dst_device_buffer =
-      SharedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer, definition_event);
+      SharedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer,
+                                                 {definition_event});
   return absl::make_unique<PyLocalBuffer>(
       dst_buffer.on_host_shape(), dst_buffer.on_device_shape(),
       std::move(dst_device_buffer), client_, dst_device);
 }

+Status PyLocalBuffer::CopyToRemoteDevice(
+    absl::string_view serialized_descriptor,
+    std::shared_ptr<Device> dst_device) {
+  return client_->CopyToRemoteDevice(this, serialized_descriptor, dst_device);
+}
+
 Status PyLocalBuffer::BlockHostUntilReady() {
   tensorflow::profiler::TraceMe traceme("PyLocalBuffer::BlockHostUntilReady");
   std::shared_ptr<SharedDeviceBuffer> device_buffer = DeviceBuffer();
@@ -693,7 +776,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(

   std::shared_ptr<SharedDeviceBuffer> out_buffer =
       SharedDeviceBuffer::FromScopedShapedBuffer(&result_buffer,
-                                                 definition_event);
+                                                 {definition_event});

   if (device_state->synchronous_deallocation()) {
     device_buffers.push_back(out_buffer);
--- a/tensorflow/compiler/xla/python/local_client.h
+++ b/tensorflow/compiler/xla/python/local_client.h
@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/lib/core/status.h"

@@ -81,6 +82,19 @@ class Device {
   const std::string platform_name_;
 };

+class PyLocalBuffer;
+// Helper struct for cross host transfers, returned by the callback from a call
+// to PyLocalBuffer::MakeCrossHostReceiveBuffers.
+struct PyLocalCrossHostRecvBuffer {
+  // serialized_descriptor should be transmitted to the sender and passed to a
+  // call to src_buffer->CopyToRemoteDevice.
+  std::string serialized_descriptor;
+  // The buffer that will hold the result of the transfer.
+  std::unique_ptr<PyLocalBuffer> buffer;
+};
+using PyLocalCrossHostRecvNotifier =
+    std::function<void(StatusOr<std::vector<PyLocalCrossHostRecvBuffer>>&&)>;
+
 // Encapsulates the state of a Python session with XLA.
 class PyLocalClient {
  public:
@@ -134,6 +148,19 @@ class PyLocalClient {
   virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }

  protected:
+  friend class PyLocalBuffer;
+  virtual void EnqueueCrossHostReceive(
+      std::vector<std::unique_ptr<PyLocalBuffer>>&& buffers,
+      PyLocalCrossHostRecvNotifier&& notifier) const {
+    notifier(Unimplemented("Cross host receives not implemented."));
+  }
+
+  virtual Status CopyToRemoteDevice(PyLocalBuffer* buffer,
+                                    absl::string_view serialized_descriptor,
+                                    std::shared_ptr<Device> device) const {
+    return Unimplemented("Cross host sends not implemented.");
+  }
+
   std::string platform_name_;
   LocalClient* client_;

@@ -181,6 +208,19 @@ class PyLocalBuffer {
       const std::vector<PyLocalBuffer*> buffers,
       std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device);

+  // Asynchronously makes a vector of PyLocalBuffers that can be used to receive
+  // cross host transfers using `client` on `device`. `shapes` must be the exact
+  // shapes, with identical layouts, corresponding to the buffers that will be
+  // sent. When resources for the transfer are available, notifier will be
+  // called with a vector of PyLocalCrossHostRecvBuffer structs, one for each
+  // shape in `shapes`. Each struct contains a buffer that will contain the
+  // received value, and an opaque string that should be transmitted to the
+  // sending host and used in a call to CopyToRemoteDevice. None of the recv
+  // buffers will become ready until *all* of the sends have completed.
+  static void MakeCrossHostReceiveBuffers(
+      absl::Span<const Shape> shapes, std::shared_ptr<PyLocalClient> client,
+      std::shared_ptr<Device> device, PyLocalCrossHostRecvNotifier&& notifier);
+
   PyLocalBuffer(Shape on_host_shape, Shape on_device_shape,
                 std::shared_ptr<SharedDeviceBuffer> device_buffer,
                 std::shared_ptr<PyLocalClient> client,
@@ -227,6 +267,18 @@ class PyLocalBuffer {
   StatusOr<std::unique_ptr<PyLocalBuffer>> CopyToDevice(
       std::shared_ptr<Device> dst_device);

+  // Copies the buffer to remote device `dst_device`. This call must be preceded
+  // by a call to MakeCrossHostReceiveBuffers on the remote host's
+  // dst_device. MakeCrossHostReceiveBuffers takes an array of shapes to
+  // construct the destination buffers, and a callback supplies an array
+  // containing both the destination buffers and a serialized descriptor for
+  // each buffer. For each destination buffer there should be a matching call to
+  // src->CopyToRemoteDevice on a remote host for a src buffer of the
+  // corresponding shape. serialized_descriptor is the string returned by the
+  // callback along with the corresponding destination buffer.
+  Status CopyToRemoteDevice(absl::string_view serialized_descriptor,
+                            std::shared_ptr<Device> dst_device);
+
   // Blocks the host until the buffer's value has been computed and is ready for
   // immediate use on the device. Useful in particular for timing benchmarks.
   Status BlockHostUntilReady();
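A matching sender-side sketch of the declaration above; `src_buffer`, `dst_device`, and `serialized_descriptor` are hypothetical stand-ins (the descriptor is the string produced by the receiver's callback):

    // Sending host: src_buffer is a PyLocalBuffer whose shape and layout
    // match the receive buffer allocated on the remote host.
    Status s =
        src_buffer->CopyToRemoteDevice(serialized_descriptor, dst_device);
    // With the base PyLocalClient this returns
    // Unimplemented("Cross host sends not implemented."); a backend must
    // override PyLocalClient::CopyToRemoteDevice to enable the transfer.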
--- a/tensorflow/compiler/xla/python/shared_device_buffer.cc
+++ b/tensorflow/compiler/xla/python/shared_device_buffer.cc
@@ -15,6 +15,7 @@ limitations under the License.

 #include "tensorflow/compiler/xla/python/shared_device_buffer.h"

+#include <iterator>
 #include <memory>

 #include "tensorflow/stream_executor/device_memory.h"
@@ -60,7 +61,8 @@ static std::shared_ptr<SharedDeviceBuffer> BufferFromScopedShapedBufferIterator(
     int device_ordinal, se::DeviceMemoryAllocator* allocator,
     ShapeTree<se::DeviceMemoryBase>::iterator* iterator,
     const ShapeTree<se::DeviceMemoryBase>::iterator& end,
-    const std::shared_ptr<BufferDefinitionEvent>& definition_event) {
+    absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+        definition_events) {
   std::vector<se::OwningDeviceMemory> buffers;
   buffers.reserve(1);
   std::vector<std::shared_ptr<SharedDeviceBuffer>> children;
@@ -78,7 +80,7 @@ static std::shared_ptr<SharedDeviceBuffer> BufferFromScopedShapedBufferIterator(
     for (int i = 0; i < num_children; ++i) {
       children.push_back(BufferFromScopedShapedBufferIterator(
           on_host_shape.tuple_shapes(i), on_device_shape.tuple_shapes(i),
-          device_ordinal, allocator, iterator, end, definition_event));
+          device_ordinal, allocator, iterator, end, definition_events));
     }
   } else {
     // An on-host array may be an on-device tuple. For example, a complex tensor
@@ -88,20 +90,21 @@ static std::shared_ptr<SharedDeviceBuffer> BufferFromScopedShapedBufferIterator(
         [&](const Shape&, const ShapeIndex&) { consume_buffer(); });
   }
   return std::make_shared<SharedDeviceBuffer>(
-      absl::Span<se::OwningDeviceMemory>(buffers), children, definition_event);
+      absl::Span<se::OwningDeviceMemory>(buffers), children, definition_events);
 }

 /* static */ std::shared_ptr<SharedDeviceBuffer>
 SharedDeviceBuffer::FromScopedShapedBuffer(
     ScopedShapedBuffer* shaped_buffer,
-    const std::shared_ptr<BufferDefinitionEvent>& definition_event) {
+    absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+        definition_events) {
   ShapeTree<se::DeviceMemoryBase>::iterator iterator =
       shaped_buffer->buffers().begin();
   std::shared_ptr<SharedDeviceBuffer> output =
       BufferFromScopedShapedBufferIterator(
           shaped_buffer->on_host_shape(), shaped_buffer->on_device_shape(),
           shaped_buffer->device_ordinal(), shaped_buffer->memory_allocator(),
-          &iterator, shaped_buffer->buffers().end(), definition_event);
+          &iterator, shaped_buffer->buffers().end(), definition_events);
   CHECK(iterator == shaped_buffer->buffers().end());
   return output;
 }
@@ -111,7 +114,8 @@ SharedDeviceBuffer::MakeTuple(
     std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
     const Shape& on_host_shape, TransferManager* transfer_manager,
     se::DeviceMemoryAllocator* allocator, int device_ordinal,
-    std::shared_ptr<BufferDefinitionEvent> definition_event) {
+    absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+        definition_events) {
   CHECK(on_host_shape.IsTuple() &&
         on_host_shape.tuple_shapes_size() == children.size());
   TF_ASSIGN_OR_RETURN(
@@ -122,7 +126,7 @@ SharedDeviceBuffer::MakeTuple(
   return std::make_shared<SharedDeviceBuffer>(
       allocator, device_ordinal,
       std::initializer_list<se::DeviceMemoryBase>{device_memory.Release()},
-      std::move(children), std::move(definition_event),
+      std::move(children), definition_events,
      /*on_delete_callback=*/nullptr);
 }

@@ -130,7 +134,8 @@ SharedDeviceBuffer::MakeTuple(
 SharedDeviceBuffer::MakeArray(
     Shape on_device_shape, TransferManager* transfer_manager,
     se::DeviceMemoryAllocator* allocator, int device_ordinal,
-    std::shared_ptr<BufferDefinitionEvent> definition_event) {
+    absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+        definition_events) {
   std::vector<se::OwningDeviceMemory> device_buffers;
   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
       on_device_shape, [&](const Shape& subshape, const ShapeIndex&) -> Status {
@@ -145,7 +150,7 @@ SharedDeviceBuffer::MakeArray(
   return std::make_shared<SharedDeviceBuffer>(
       absl::Span<se::OwningDeviceMemory>(device_buffers),
       /*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
-      std::move(definition_event));
+      definition_events);
 }

 // Populates a buffer tree from a ShapeTree iterator.
@@ -176,25 +181,36 @@ ShapedBuffer SharedDeviceBuffer::AsShapedBuffer(const Shape& on_host_shape,
   return shaped_buffer;
 }

+namespace {
+
+using MoveIterator =
+    absl::Span<const std::shared_ptr<BufferDefinitionEvent>>::iterator;
+
+}  // namespace
+
 SharedDeviceBuffer::SharedDeviceBuffer(
     se::DeviceMemoryAllocator* allocator, int device_ordinal,
     absl::Span<se::DeviceMemoryBase const> device_memory,
     std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
-    std::shared_ptr<BufferDefinitionEvent> definition_event,
+    absl::Span<const std::shared_ptr<BufferDefinitionEvent>> definition_events,
     std::function<void()> on_delete_callback)
     : allocator_(allocator),
       device_ordinal_(device_ordinal),
       device_memory_(device_memory.begin(), device_memory.end()),
       children_(std::move(children)),
-      definition_event_(std::move(definition_event)),
+      definition_events_(
+          std::move_iterator<MoveIterator>(definition_events.begin()),
+          std::move_iterator<MoveIterator>(definition_events.end())),
       on_delete_callback_(std::move(on_delete_callback)) {}

 SharedDeviceBuffer::SharedDeviceBuffer(
     absl::Span<se::OwningDeviceMemory> device_memory,
     std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
-    std::shared_ptr<BufferDefinitionEvent> definition_event)
+    absl::Span<const std::shared_ptr<BufferDefinitionEvent>> definition_events)
     : children_(std::move(children)),
-      definition_event_(std::move(definition_event)) {
+      definition_events_(
+          std::move_iterator<MoveIterator>(definition_events.begin()),
+          std::move_iterator<MoveIterator>(definition_events.end())) {
   CHECK(!device_memory.empty());
   allocator_ = device_memory.front().allocator();
   device_ordinal_ = device_memory.front().device_ordinal();
@@ -222,8 +238,8 @@ SharedDeviceBuffer::~SharedDeviceBuffer() {
 void GetDeviceBufferDefinitionEvents(
     const SharedDeviceBuffer& buffer,
     absl::flat_hash_set<BufferDefinitionEvent*>* events) {
-  if (buffer.definition_event()) {
-    events->insert(buffer.definition_event().get());
+  for (const auto& e : buffer.definition_events()) {
+    events->insert(e.get());
   }
   for (const auto& child : buffer.children()) {
     GetDeviceBufferDefinitionEvents(*child, events);
--- a/tensorflow/compiler/xla/python/shared_device_buffer.h
+++ b/tensorflow/compiler/xla/python/shared_device_buffer.h
@@ -93,20 +93,23 @@ class SharedDeviceBuffer {
   // buffers of the shaped_buffer.
   static std::shared_ptr<SharedDeviceBuffer> FromScopedShapedBuffer(
       ScopedShapedBuffer* shaped_buffer,
-      const std::shared_ptr<BufferDefinitionEvent>& definition_event);
+      absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+          definition_events);

   // Makes a tuple buffer. Does not initialize the tuple table.
   static StatusOr<std::shared_ptr<SharedDeviceBuffer>> MakeTuple(
       std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
       const Shape& on_host_shape, TransferManager* transfer_manager,
       se::DeviceMemoryAllocator* allocator, int device_ordinal,
-      std::shared_ptr<BufferDefinitionEvent> definition_event);
+      absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+          definition_events);

   // Makes an uninitialized array buffer.
   static StatusOr<std::shared_ptr<SharedDeviceBuffer>> MakeArray(
       Shape on_device_shape, TransferManager* transfer_manager,
       se::DeviceMemoryAllocator* allocator, int device_ordinal,
-      std::shared_ptr<BufferDefinitionEvent> definition_event);
+      absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+          definition_events);

   // Builds a ShapedBuffer view onto the buffers of 'tree'. We require but do
   // not verify that TransferManager::HostShapeToDeviceShape(on_host_shape) ==
@@ -126,19 +129,22 @@ class SharedDeviceBuffer {
   const absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() const {
     return device_memory_;
   }
-  const std::shared_ptr<BufferDefinitionEvent> definition_event() const {
-    return definition_event_;
+  absl::Span<const std::shared_ptr<BufferDefinitionEvent>> definition_events()
+      const {
+    return definition_events_;
   }

   SharedDeviceBuffer() = default;
   SharedDeviceBuffer(se::DeviceMemoryAllocator* allocator, int device_ordinal,
                      absl::Span<se::DeviceMemoryBase const> device_memory,
                      std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
-                     std::shared_ptr<BufferDefinitionEvent> definition_event,
+                     absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+                         definition_events,
                      std::function<void()> on_delete_callback);
   SharedDeviceBuffer(absl::Span<se::OwningDeviceMemory> device_memory,
                      std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
-                     std::shared_ptr<BufferDefinitionEvent> definition_event);
+                     absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+                         definition_events);
   ~SharedDeviceBuffer();

  private:
@@ -155,7 +161,8 @@ class SharedDeviceBuffer {
   // ready during multistream execution. May be nullptr, which is used in the
   // single-stream execution case where events are not necessary for buffer
   // event sequencing.
-  std::shared_ptr<BufferDefinitionEvent> definition_event_;
+  absl::InlinedVector<std::shared_ptr<BufferDefinitionEvent>, 2>
+      definition_events_;

   // A callback to call when the SharedDeviceBuffer is about to be destroyed.
   std::function<void()> on_delete_callback_;
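The singular-to-plural migration above changes every call site from one std::shared_ptr<BufferDefinitionEvent> to a span of zero or more events. A sketch of the two common patterns, assuming a ScopedShapedBuffer named `scoped_buffer` is in scope:

    // One event, as in FromHostBuffer and CopyToDevice: wrap it in a braced
    // list, which converts to the span parameter.
    auto definition_event = std::make_shared<BufferDefinitionEvent>();
    auto array_buffer = SharedDeviceBuffer::FromScopedShapedBuffer(
        &scoped_buffer, {definition_event});

    // Two events, as in the cross-host receive path for tuples: one for the
    // tuple index table write and one for the data transfer itself.
    absl::InlinedVector<std::shared_ptr<BufferDefinitionEvent>, 2> events = {
        std::make_shared<BufferDefinitionEvent>(),
        std::make_shared<BufferDefinitionEvent>()};
    auto tuple_buffer =
        SharedDeviceBuffer::FromScopedShapedBuffer(&scoped_buffer, events);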
--- a/tensorflow/compiler/xla/python/shared_device_buffer_test.cc
+++ b/tensorflow/compiler/xla/python/shared_device_buffer_test.cc
@@ -28,10 +28,10 @@ TEST(SharedDeviceBufferTest, MakeArray) {
   LocalClient* client = ClientLibrary::LocalClientOrDie();

   Shape shape = ShapeUtil::MakeShape(F32, {3, 101, 4});
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto buffer, SharedDeviceBuffer::MakeArray(
-                       shape, client->backend().transfer_manager(),
-                       client->backend().memory_allocator(), 0, nullptr));
+  TF_ASSERT_OK_AND_ASSIGN(auto buffer,
+                          SharedDeviceBuffer::MakeArray(
+                              shape, client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
   EXPECT_EQ(buffer->children().size(), 0);
   EXPECT_EQ(buffer->device_ordinal(), 0);
   EXPECT_EQ(buffer->allocator(), client->backend().memory_allocator());
@@ -45,19 +45,19 @@ TEST(SharedDeviceBufferTest, MakeTuple) {
   Shape a_shape = ShapeUtil::MakeShape(F32, {3, 101, 4});
   Shape b_shape = ShapeUtil::MakeShape(S8, {77});
   Shape tuple_shape = ShapeUtil::MakeTupleShape({a_shape, b_shape});
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto a_buffer, SharedDeviceBuffer::MakeArray(
-                         a_shape, client->backend().transfer_manager(),
-                         client->backend().memory_allocator(), 0, nullptr));
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto b_buffer, SharedDeviceBuffer::MakeArray(
-                         b_shape, client->backend().transfer_manager(),
-                         client->backend().memory_allocator(), 0, nullptr));
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto tuple_buffer, SharedDeviceBuffer::MakeTuple(
-                             {a_buffer, b_buffer}, tuple_shape,
-                             client->backend().transfer_manager(),
-                             client->backend().memory_allocator(), 0, nullptr));
+  TF_ASSERT_OK_AND_ASSIGN(auto a_buffer,
+                          SharedDeviceBuffer::MakeArray(
+                              a_shape, client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
+  TF_ASSERT_OK_AND_ASSIGN(auto b_buffer,
+                          SharedDeviceBuffer::MakeArray(
+                              b_shape, client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
+  TF_ASSERT_OK_AND_ASSIGN(auto tuple_buffer,
+                          SharedDeviceBuffer::MakeTuple(
+                              {a_buffer, b_buffer}, tuple_shape,
+                              client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
   ASSERT_EQ(tuple_buffer->children().size(), 2);
   EXPECT_EQ(tuple_buffer->children()[0], a_buffer);
   EXPECT_EQ(tuple_buffer->children()[1], b_buffer);
@@ -75,30 +75,28 @@ TEST(SharedDeviceBufferTest, AsShapedBuffer) {
   Shape ab_tuple_shape = ShapeUtil::MakeTupleShape({a_shape, b_shape});
   Shape c_shape = ShapeUtil::MakeShape(S64, {});
   Shape abc_tuple_shape = ShapeUtil::MakeTupleShape({c_shape, ab_tuple_shape});
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto a_buffer, SharedDeviceBuffer::MakeArray(
-                         a_shape, client->backend().transfer_manager(),
-                         client->backend().memory_allocator(), 0, nullptr));
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto b_buffer, SharedDeviceBuffer::MakeArray(
-                         b_shape, client->backend().transfer_manager(),
-                         client->backend().memory_allocator(), 0, nullptr));
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto ab_tuple_buffer,
-      SharedDeviceBuffer::MakeTuple({a_buffer, b_buffer}, ab_tuple_shape,
-                                    client->backend().transfer_manager(),
-                                    client->backend().memory_allocator(), 0,
-                                    nullptr));
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto c_buffer, SharedDeviceBuffer::MakeArray(
-                         c_shape, client->backend().transfer_manager(),
-                         client->backend().memory_allocator(), 0, nullptr));
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto abc_tuple_buffer,
-      SharedDeviceBuffer::MakeTuple(
-          {c_buffer, ab_tuple_buffer}, abc_tuple_shape,
-          client->backend().transfer_manager(),
-          client->backend().memory_allocator(), 0, nullptr));
+  TF_ASSERT_OK_AND_ASSIGN(auto a_buffer,
+                          SharedDeviceBuffer::MakeArray(
+                              a_shape, client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
+  TF_ASSERT_OK_AND_ASSIGN(auto b_buffer,
+                          SharedDeviceBuffer::MakeArray(
+                              b_shape, client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
+  TF_ASSERT_OK_AND_ASSIGN(auto ab_tuple_buffer,
+                          SharedDeviceBuffer::MakeTuple(
+                              {a_buffer, b_buffer}, ab_tuple_shape,
+                              client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
+  TF_ASSERT_OK_AND_ASSIGN(auto c_buffer,
+                          SharedDeviceBuffer::MakeArray(
+                              c_shape, client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
+  TF_ASSERT_OK_AND_ASSIGN(auto abc_tuple_buffer,
+                          SharedDeviceBuffer::MakeTuple(
+                              {c_buffer, ab_tuple_buffer}, abc_tuple_shape,
+                              client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
   Shape abc_tuple_device_shape =
       client->backend().transfer_manager()->HostShapeToDeviceShape(
           abc_tuple_shape);
@@ -140,7 +138,7 @@ TEST(SharedDeviceBufferTest, FromScopedShapedBuffer) {
       ScopedShapedBuffer shaped_buffer,
       client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0));
   std::shared_ptr<SharedDeviceBuffer> device_buffer =
-      SharedDeviceBuffer::FromScopedShapedBuffer(&shaped_buffer, nullptr);
+      SharedDeviceBuffer::FromScopedShapedBuffer(&shaped_buffer, {});

   ASSERT_EQ(device_buffer->device_memory().size(), 1);
   ASSERT_EQ(device_buffer->children().size(), 2);