Add cross host send/recv to PyLocalClient. The APIs are plumbed through, but the default client implementations return Unimplemented for now.
PiperOrigin-RevId: 299858816 Change-Id: Id8c14e9ce491281532ff9795b05e10582db6be00
parent 60207a00d4
commit add27c7db6
@@ -329,11 +329,12 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> DLPackManagedTensorToBuffer(
   if (dlmt->deleter) {
     on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
   }
+  absl::Span<const std::shared_ptr<BufferDefinitionEvent>> definition_events;
   auto device_buffer = std::make_shared<SharedDeviceBuffer>(
       /*allocator=*/nullptr, dlmt->dl_tensor.ctx.device_id,
       std::initializer_list<se::DeviceMemoryBase>{buffer},
       /*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
-      /*definition_event=*/nullptr, std::move(on_delete_callback));
+      definition_events, std::move(on_delete_callback));
 
   // We have taken ownership of the array inside the capsule; make sure the
   // capsule it cannot be used again.
@@ -95,6 +95,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
@@ -182,11 +183,12 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
     };
     se::DeviceMemoryBase buffer(const_cast<void*>(data),
                                 ShapeUtil::ByteSizeOf(shape));
+    absl::Span<const std::shared_ptr<BufferDefinitionEvent>> definition_events;
     auto device_buffer = std::make_shared<SharedDeviceBuffer>(
         /*allocator=*/nullptr, local_device->device_ordinal(),
         std::initializer_list<se::DeviceMemoryBase>{buffer},
         /*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
-        /*definition_event=*/nullptr, std::move(on_delete_callback));
+        definition_events, std::move(on_delete_callback));
     return absl::make_unique<PyLocalBuffer>(
         shape, shape, std::move(device_buffer), std::move(client),
         std::move(device));
@@ -218,7 +220,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
       std::make_shared<BufferDefinitionEvent>();
   std::shared_ptr<SharedDeviceBuffer> device_buffer =
       SharedDeviceBuffer::FromScopedShapedBuffer(&scoped_buffer,
-                                                 definition_event);
+                                                 {definition_event});
   Shape on_device_shape = scoped_buffer.on_device_shape();
 
   auto transfer_h2d = [client, transfer_manager, local_device, device_buffer,
@@ -263,7 +265,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
     // Sets the buffer definition event. Note: this has the side effect of
     // unblocking any host threads that may have been waiting to consume the
     // buffer.
-    device_buffer->definition_event()->SetDefinitionEvent(
+    device_buffer->definition_events()[0]->SetDefinitionEvent(
         std::move(event), local_device->host_to_device_stream());
 
     if (local_device->synchronous_deallocation()) {
@@ -318,7 +320,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
       std::shared_ptr<SharedDeviceBuffer> tuple_buffer,
       SharedDeviceBuffer::MakeTuple(
           device_buffers, on_host_shape, transfer_manager, allocator,
-          local_device->device_ordinal(), definition_event));
+          local_device->device_ordinal(), {definition_event}));
   auto buffer = absl::make_unique<PyLocalBuffer>(
       std::move(on_host_shape), ShapeUtil::MakeTupleShape(device_shapes),
       tuple_buffer, std::move(client), std::move(device));
@@ -348,6 +350,80 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
   return buffer;
 }
 
+StatusOr<std::vector<std::unique_ptr<PyLocalBuffer>>>
+MakeCrossHostReceiveBuffersHelper(absl::Span<const Shape> shapes,
+                                  std::shared_ptr<PyLocalClient> client,
+                                  std::shared_ptr<Device> device) {
+  TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
+                      device->GetLocalDeviceState());
+  TransferManager* transfer_manager =
+      client->client()->backend().transfer_manager();
+  std::vector<std::unique_ptr<PyLocalBuffer>> buffers;
+  buffers.reserve(shapes.size());
+  se::Stream* host_to_device_stream = local_device->host_to_device_stream();
+  for (const auto& shape : shapes) {
+    TF_ASSIGN_OR_RETURN(
+        ScopedShapedBuffer scoped_buffer,
+        transfer_manager->AllocateScopedShapedBuffer(
+            shape, client->allocator(), local_device->device_ordinal()));
+
+    if (!transfer_manager->CanShapedBufferBeAccessedNow(
+            local_device->compute_stream()->parent(), scoped_buffer)) {
+      return Unimplemented(
+          "Cross host receive not enabled unless deallocations are deferred");
+    }
+
+    absl::InlinedVector<std::shared_ptr<BufferDefinitionEvent>, 2>
+        definition_events;
+
+    if (scoped_buffer.on_device_shape().IsTuple()) {
+      TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync(
+          host_to_device_stream, scoped_buffer));
+      definition_events = {std::make_shared<BufferDefinitionEvent>(),
+                           std::make_shared<BufferDefinitionEvent>()};
+      TF_ASSIGN_OR_RETURN(EventPool::Handle event,
+                          local_device->event_pool().ThenAllocateAndRecordEvent(
+                              host_to_device_stream));
+      definition_events[1]->SetDefinitionEvent(std::move(event),
+                                               host_to_device_stream);
+    } else {
+      definition_events = {std::make_shared<BufferDefinitionEvent>()};
+    }
+
+    std::shared_ptr<SharedDeviceBuffer> device_buffer =
+        SharedDeviceBuffer::FromScopedShapedBuffer(&scoped_buffer,
+                                                   definition_events);
+    Shape on_device_shape = scoped_buffer.on_device_shape();
+
+    auto buffer = absl::make_unique<PyLocalBuffer>(
+        shape, std::move(on_device_shape), std::move(device_buffer), client,
+        device);
+
+    buffers.push_back(std::move(buffer));
+  }
+  return buffers;
+}
+
+/*static*/ void PyLocalBuffer::MakeCrossHostReceiveBuffers(
+    absl::Span<const Shape> shapes, std::shared_ptr<PyLocalClient> client,
+    std::shared_ptr<Device> device, PyLocalCrossHostRecvNotifier&& notifier) {
+  if (shapes.empty()) {
+    notifier(InvalidArgument(
+        "shapes parameter empty in MakeCrossHostReceiveBuffers"));
+    return;
+  }
+  PyLocalClient* client_ptr = client.get();
+  auto buffer_or = MakeCrossHostReceiveBuffersHelper(shapes, std::move(client),
+                                                     std::move(device));
+  if (!buffer_or.ok()) {
+    notifier(buffer_or.status());
+    return;
+  }
+
+  client_ptr->EnqueueCrossHostReceive(buffer_or.ConsumeValueOrDie(),
+                                      std::move(notifier));
+}
+
 PyLocalBuffer::PyLocalBuffer(Shape on_host_shape, Shape on_device_shape,
                              std::shared_ptr<SharedDeviceBuffer> device_buffer,
                              std::shared_ptr<PyLocalClient> client,
@@ -519,12 +595,19 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::CopyToDevice(
   definition_event->SetDefinitionEvent(std::move(event), transfer_stream);
 
   std::shared_ptr<SharedDeviceBuffer> dst_device_buffer =
-      SharedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer, definition_event);
+      SharedDeviceBuffer::FromScopedShapedBuffer(&dst_buffer,
+                                                 {definition_event});
   return absl::make_unique<PyLocalBuffer>(
       dst_buffer.on_host_shape(), dst_buffer.on_device_shape(),
       std::move(dst_device_buffer), client_, dst_device);
 }
 
+Status PyLocalBuffer::CopyToRemoteDevice(
+    absl::string_view serialized_descriptor,
+    std::shared_ptr<Device> dst_device) {
+  return client_->CopyToRemoteDevice(this, serialized_descriptor, dst_device);
+}
+
 Status PyLocalBuffer::BlockHostUntilReady() {
   tensorflow::profiler::TraceMe traceme("PyLocalBuffer::BlockHostUntilReady");
   std::shared_ptr<SharedDeviceBuffer> device_buffer = DeviceBuffer();
@@ -693,7 +776,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalExecutable::ExecuteHelper(
 
   std::shared_ptr<SharedDeviceBuffer> out_buffer =
       SharedDeviceBuffer::FromScopedShapedBuffer(&result_buffer,
-                                                 definition_event);
+                                                 {definition_event});
 
   if (device_state->synchronous_deallocation()) {
     device_buffers.push_back(out_buffer);
@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -81,6 +82,19 @@ class Device {
   const std::string platform_name_;
 };
 
+class PyLocalBuffer;
+
+// Helper struct for cross host transfers, returned by the callback from a call
+// to PyLocalBuffer::MakeCrossHostReceiveBuffers.
+struct PyLocalCrossHostRecvBuffer {
+  // serialized_descriptor should be transmitted to the sender and passed to a
+  // call to src_buffer->CopyToRemoteDevice.
+  std::string serialized_descriptor;
+  // The buffer that will hold the result of the transfer.
+  std::unique_ptr<PyLocalBuffer> buffer;
+};
+using PyLocalCrossHostRecvNotifier =
+    std::function<void(StatusOr<std::vector<PyLocalCrossHostRecvBuffer>>&&)>;
+
 // Encapsulates the state of Python session with XLA.
 class PyLocalClient {
  public:
@@ -134,6 +148,19 @@ class PyLocalClient {
   virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
 
  protected:
+  friend class PyLocalBuffer;
+  virtual void EnqueueCrossHostReceive(
+      std::vector<std::unique_ptr<PyLocalBuffer>>&& buffers,
+      PyLocalCrossHostRecvNotifier&& notifier) const {
+    notifier(Unimplemented("Cross host receives not implemented."));
+  }
+
+  virtual Status CopyToRemoteDevice(PyLocalBuffer* buffer,
+                                    absl::string_view serialized_descriptor,
+                                    std::shared_ptr<Device> device) const {
+    return Unimplemented("Cross host sends not implemented.");
+  }
+
   std::string platform_name_;
   LocalClient* client_;
@@ -181,6 +208,19 @@ class PyLocalBuffer {
       const std::vector<PyLocalBuffer*> buffers,
       std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device);
 
+  // Asynchronously makes a vector of PyLocalBuffers that can be used to receive
+  // cross host transfers using `client` on `device`. `shapes` must be the exact
+  // shapes, with identical layouts, corresponding to the buffers that will be
+  // sent. When resources for the transfer are available, notifier will be
+  // called with a vector of PyLocalCrossHostRecvBuffer structs, one for each
+  // shape in `shapes`. Each struct contains a buffer that will contain the
+  // received value, and an opaque string that should be transmitted to the
+  // sending host and used in a call to CopyToRemoteDevice. None of the recv
+  // buffers will become ready until *all* of the sends have completed.
+  static void MakeCrossHostReceiveBuffers(
+      absl::Span<const Shape> shapes, std::shared_ptr<PyLocalClient> client,
+      std::shared_ptr<Device> device, PyLocalCrossHostRecvNotifier&& notifier);
+
   PyLocalBuffer(Shape on_host_shape, Shape on_device_shape,
                 std::shared_ptr<SharedDeviceBuffer> device_buffer,
                 std::shared_ptr<PyLocalClient> client,
@@ -227,6 +267,18 @@ class PyLocalBuffer {
   StatusOr<std::unique_ptr<PyLocalBuffer>> CopyToDevice(
       std::shared_ptr<Device> dst_device);
 
+  // Copies the buffer to remote device `dst_device`. This call must be preceded
+  // by a call to MakeCrossHostReceiveBuffers on the remote host's
+  // dst_device. MakeCrossHostReceiveBuffers takes an array of shapes to
+  // construct the destination buffers, and a callback supplies an array
+  // containing both the destination buffers, and a serialized descriptor for
+  // each buffer. For each destination buffer there should be a matching call to
+  // src->CopyToRemoteDevice on a remote host for a src buffer of the
+  // corresponding shape. serialized_descriptor is the string returned by the
+  // callback along with the corresponding destination buffer.
+  Status CopyToRemoteDevice(absl::string_view serialized_descriptor,
+                            std::shared_ptr<Device> dst_device);
+
   // Blocks the host until the buffer's value has been computed and is ready for
   // immediate use on the device. Useful in particular for timing benchmarks.
   Status BlockHostUntilReady();
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/python/shared_device_buffer.h"
 
+#include <iterator>
 #include <memory>
 
 #include "tensorflow/stream_executor/device_memory.h"
@@ -60,7 +61,8 @@ static std::shared_ptr<SharedDeviceBuffer> BufferFromScopedShapedBufferIterator(
     int device_ordinal, se::DeviceMemoryAllocator* allocator,
     ShapeTree<se::DeviceMemoryBase>::iterator* iterator,
    const ShapeTree<se::DeviceMemoryBase>::iterator& end,
-    const std::shared_ptr<BufferDefinitionEvent>& definition_event) {
+    absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+        definition_events) {
   std::vector<se::OwningDeviceMemory> buffers;
   buffers.reserve(1);
   std::vector<std::shared_ptr<SharedDeviceBuffer>> children;
@@ -78,7 +80,7 @@ static std::shared_ptr<SharedDeviceBuffer> BufferFromScopedShapedBufferIterator(
     for (int i = 0; i < num_children; ++i) {
       children.push_back(BufferFromScopedShapedBufferIterator(
           on_host_shape.tuple_shapes(i), on_device_shape.tuple_shapes(i),
-          device_ordinal, allocator, iterator, end, definition_event));
+          device_ordinal, allocator, iterator, end, definition_events));
     }
   } else {
     // An on-host array may be an on-device tuple. For example, a complex tensor
@@ -88,20 +90,21 @@ static std::shared_ptr<SharedDeviceBuffer> BufferFromScopedShapedBufferIterator(
         [&](const Shape&, const ShapeIndex&) { consume_buffer(); });
   }
   return std::make_shared<SharedDeviceBuffer>(
-      absl::Span<se::OwningDeviceMemory>(buffers), children, definition_event);
+      absl::Span<se::OwningDeviceMemory>(buffers), children, definition_events);
 }
 
 /* static */ std::shared_ptr<SharedDeviceBuffer>
 SharedDeviceBuffer::FromScopedShapedBuffer(
     ScopedShapedBuffer* shaped_buffer,
-    const std::shared_ptr<BufferDefinitionEvent>& definition_event) {
+    absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+        definition_events) {
   ShapeTree<se::DeviceMemoryBase>::iterator iterator =
       shaped_buffer->buffers().begin();
   std::shared_ptr<SharedDeviceBuffer> output =
       BufferFromScopedShapedBufferIterator(
           shaped_buffer->on_host_shape(), shaped_buffer->on_device_shape(),
           shaped_buffer->device_ordinal(), shaped_buffer->memory_allocator(),
-          &iterator, shaped_buffer->buffers().end(), definition_event);
+          &iterator, shaped_buffer->buffers().end(), definition_events);
   CHECK(iterator == shaped_buffer->buffers().end());
   return output;
 }
@@ -111,7 +114,8 @@ SharedDeviceBuffer::MakeTuple(
     std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
     const Shape& on_host_shape, TransferManager* transfer_manager,
     se::DeviceMemoryAllocator* allocator, int device_ordinal,
-    std::shared_ptr<BufferDefinitionEvent> definition_event) {
+    absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+        definition_events) {
   CHECK(on_host_shape.IsTuple() &&
         on_host_shape.tuple_shapes_size() == children.size());
   TF_ASSIGN_OR_RETURN(
|
|||||||
return std::make_shared<SharedDeviceBuffer>(
|
return std::make_shared<SharedDeviceBuffer>(
|
||||||
allocator, device_ordinal,
|
allocator, device_ordinal,
|
||||||
std::initializer_list<se::DeviceMemoryBase>{device_memory.Release()},
|
std::initializer_list<se::DeviceMemoryBase>{device_memory.Release()},
|
||||||
std::move(children), std::move(definition_event),
|
std::move(children), definition_events,
|
||||||
/*on_delete_callback=*/nullptr);
|
/*on_delete_callback=*/nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,7 +134,8 @@ SharedDeviceBuffer::MakeTuple(
 SharedDeviceBuffer::MakeArray(
     Shape on_device_shape, TransferManager* transfer_manager,
     se::DeviceMemoryAllocator* allocator, int device_ordinal,
-    std::shared_ptr<BufferDefinitionEvent> definition_event) {
+    absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+        definition_events) {
   std::vector<se::OwningDeviceMemory> device_buffers;
   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
       on_device_shape, [&](const Shape& subshape, const ShapeIndex&) -> Status {
@@ -145,7 +150,7 @@ SharedDeviceBuffer::MakeArray(
   return std::make_shared<SharedDeviceBuffer>(
       absl::Span<se::OwningDeviceMemory>(device_buffers),
       /*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
-      std::move(definition_event));
+      definition_events);
 }
 
 // Populates a buffer tree from a ShapeTree iterator.
@@ -176,25 +181,36 @@ ShapedBuffer SharedDeviceBuffer::AsShapedBuffer(const Shape& on_host_shape,
   return shaped_buffer;
 }
 
+namespace {
+
+using MoveIterator =
+    absl::Span<const std::shared_ptr<BufferDefinitionEvent>>::iterator;
+
+}  // namespace
+
 SharedDeviceBuffer::SharedDeviceBuffer(
     se::DeviceMemoryAllocator* allocator, int device_ordinal,
     absl::Span<se::DeviceMemoryBase const> device_memory,
     std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
-    std::shared_ptr<BufferDefinitionEvent> definition_event,
+    absl::Span<const std::shared_ptr<BufferDefinitionEvent>> definition_events,
     std::function<void()> on_delete_callback)
     : allocator_(allocator),
       device_ordinal_(device_ordinal),
      device_memory_(device_memory.begin(), device_memory.end()),
       children_(std::move(children)),
-      definition_event_(std::move(definition_event)),
+      definition_events_(
+          std::move_iterator<MoveIterator>(definition_events.begin()),
+          std::move_iterator<MoveIterator>(definition_events.end())),
       on_delete_callback_(std::move(on_delete_callback)) {}
 
 SharedDeviceBuffer::SharedDeviceBuffer(
     absl::Span<se::OwningDeviceMemory> device_memory,
     std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
-    std::shared_ptr<BufferDefinitionEvent> definition_event)
+    absl::Span<const std::shared_ptr<BufferDefinitionEvent>> definition_events)
     : children_(std::move(children)),
-      definition_event_(std::move(definition_event)) {
+      definition_events_(
+          std::move_iterator<MoveIterator>(definition_events.begin()),
+          std::move_iterator<MoveIterator>(definition_events.end())) {
   CHECK(!device_memory.empty());
   allocator_ = device_memory.front().allocator();
   device_ordinal_ = device_memory.front().device_ordinal();
@@ -222,8 +238,8 @@ SharedDeviceBuffer::~SharedDeviceBuffer() {
 void GetDeviceBufferDefinitionEvents(
     const SharedDeviceBuffer& buffer,
     absl::flat_hash_set<BufferDefinitionEvent*>* events) {
-  if (buffer.definition_event()) {
-    events->insert(buffer.definition_event().get());
+  for (const auto& e : buffer.definition_events()) {
+    events->insert(e.get());
   }
   for (const auto& child : buffer.children()) {
     GetDeviceBufferDefinitionEvents(*child, events);
@@ -93,20 +93,23 @@ class SharedDeviceBuffer {
   // buffers of the shaped_buffer.
   static std::shared_ptr<SharedDeviceBuffer> FromScopedShapedBuffer(
       ScopedShapedBuffer* shaped_buffer,
-      const std::shared_ptr<BufferDefinitionEvent>& definition_event);
+      absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+          definition_events);
 
   // Makes a tuple buffer. Does not initialize the tuple table.
   static StatusOr<std::shared_ptr<SharedDeviceBuffer>> MakeTuple(
       std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
       const Shape& on_host_shape, TransferManager* transfer_manager,
       se::DeviceMemoryAllocator* allocator, int device_ordinal,
-      std::shared_ptr<BufferDefinitionEvent> definition_event);
+      absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+          definition_events);
 
   // Makes an uninitialized array buffer.
   static StatusOr<std::shared_ptr<SharedDeviceBuffer>> MakeArray(
       Shape on_device_shape, TransferManager* transfer_manager,
       se::DeviceMemoryAllocator* allocator, int device_ordinal,
-      std::shared_ptr<BufferDefinitionEvent> definition_event);
+      absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+          definition_events);
 
   // Builds a ShapedBuffer view onto the buffers of 'tree'. We require but do
   // not verify that TransferManager::HostShapeToDeviceShape(on_host_shape) ==
@@ -126,19 +129,22 @@ class SharedDeviceBuffer {
   const absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() const {
     return device_memory_;
   }
-  const std::shared_ptr<BufferDefinitionEvent> definition_event() const {
-    return definition_event_;
+  absl::Span<const std::shared_ptr<BufferDefinitionEvent>> definition_events()
+      const {
+    return definition_events_;
   }
 
   SharedDeviceBuffer() = default;
   SharedDeviceBuffer(se::DeviceMemoryAllocator* allocator, int device_ordinal,
                      absl::Span<se::DeviceMemoryBase const> device_memory,
                      std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
-                     std::shared_ptr<BufferDefinitionEvent> definition_event,
+                     absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+                         definition_events,
                      std::function<void()> on_delete_callback);
   SharedDeviceBuffer(absl::Span<se::OwningDeviceMemory> device_memory,
                      std::vector<std::shared_ptr<SharedDeviceBuffer>> children,
-                     std::shared_ptr<BufferDefinitionEvent> definition_event);
+                     absl::Span<const std::shared_ptr<BufferDefinitionEvent>>
+                         definition_events);
   ~SharedDeviceBuffer();
 
  private:
@@ -155,7 +161,8 @@ class SharedDeviceBuffer {
   // ready during multistream execution. May be nullptr, which is used in the
   // single-stream execution case where events are not necessary for buffer
   // event sequencing.
-  std::shared_ptr<BufferDefinitionEvent> definition_event_;
+  absl::InlinedVector<std::shared_ptr<BufferDefinitionEvent>, 2>
+      definition_events_;
 
   // A callback to call when the SharedDeviceBuffer is about to be destroyed.
   std::function<void()> on_delete_callback_;
@@ -28,10 +28,10 @@ TEST(SharedDeviceBufferTest, MakeArray) {
   LocalClient* client = ClientLibrary::LocalClientOrDie();
 
   Shape shape = ShapeUtil::MakeShape(F32, {3, 101, 4});
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto buffer, SharedDeviceBuffer::MakeArray(
-                       shape, client->backend().transfer_manager(),
-                       client->backend().memory_allocator(), 0, nullptr));
+  TF_ASSERT_OK_AND_ASSIGN(auto buffer,
+                          SharedDeviceBuffer::MakeArray(
+                              shape, client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
   EXPECT_EQ(buffer->children().size(), 0);
   EXPECT_EQ(buffer->device_ordinal(), 0);
   EXPECT_EQ(buffer->allocator(), client->backend().memory_allocator());
@@ -45,19 +45,19 @@ TEST(SharedDeviceBufferTest, MakeTuple) {
   Shape a_shape = ShapeUtil::MakeShape(F32, {3, 101, 4});
   Shape b_shape = ShapeUtil::MakeShape(S8, {77});
   Shape tuple_shape = ShapeUtil::MakeTupleShape({a_shape, b_shape});
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto a_buffer, SharedDeviceBuffer::MakeArray(
-                         a_shape, client->backend().transfer_manager(),
-                         client->backend().memory_allocator(), 0, nullptr));
+  TF_ASSERT_OK_AND_ASSIGN(auto a_buffer,
+                          SharedDeviceBuffer::MakeArray(
+                              a_shape, client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto b_buffer, SharedDeviceBuffer::MakeArray(
-                         b_shape, client->backend().transfer_manager(),
-                         client->backend().memory_allocator(), 0, nullptr));
+  TF_ASSERT_OK_AND_ASSIGN(auto b_buffer,
+                          SharedDeviceBuffer::MakeArray(
+                              b_shape, client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto tuple_buffer, SharedDeviceBuffer::MakeTuple(
-                             {a_buffer, b_buffer}, tuple_shape,
-                             client->backend().transfer_manager(),
-                             client->backend().memory_allocator(), 0, nullptr));
+  TF_ASSERT_OK_AND_ASSIGN(auto tuple_buffer,
+                          SharedDeviceBuffer::MakeTuple(
+                              {a_buffer, b_buffer}, tuple_shape,
+                              client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
   ASSERT_EQ(tuple_buffer->children().size(), 2);
   EXPECT_EQ(tuple_buffer->children()[0], a_buffer);
   EXPECT_EQ(tuple_buffer->children()[1], b_buffer);
@@ -75,30 +75,28 @@ TEST(SharedDeviceBufferTest, AsShapedBuffer) {
   Shape ab_tuple_shape = ShapeUtil::MakeTupleShape({a_shape, b_shape});
   Shape c_shape = ShapeUtil::MakeShape(S64, {});
   Shape abc_tuple_shape = ShapeUtil::MakeTupleShape({c_shape, ab_tuple_shape});
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto a_buffer, SharedDeviceBuffer::MakeArray(
-                         a_shape, client->backend().transfer_manager(),
-                         client->backend().memory_allocator(), 0, nullptr));
+  TF_ASSERT_OK_AND_ASSIGN(auto a_buffer,
+                          SharedDeviceBuffer::MakeArray(
+                              a_shape, client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto b_buffer, SharedDeviceBuffer::MakeArray(
-                         b_shape, client->backend().transfer_manager(),
-                         client->backend().memory_allocator(), 0, nullptr));
+  TF_ASSERT_OK_AND_ASSIGN(auto b_buffer,
+                          SharedDeviceBuffer::MakeArray(
+                              b_shape, client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto ab_tuple_buffer,
-      SharedDeviceBuffer::MakeTuple({a_buffer, b_buffer}, ab_tuple_shape,
-                                    client->backend().transfer_manager(),
-                                    client->backend().memory_allocator(), 0,
-                                    nullptr));
+  TF_ASSERT_OK_AND_ASSIGN(auto ab_tuple_buffer,
+                          SharedDeviceBuffer::MakeTuple(
+                              {a_buffer, b_buffer}, ab_tuple_shape,
+                              client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto c_buffer, SharedDeviceBuffer::MakeArray(
-                         c_shape, client->backend().transfer_manager(),
-                         client->backend().memory_allocator(), 0, nullptr));
+  TF_ASSERT_OK_AND_ASSIGN(auto c_buffer,
+                          SharedDeviceBuffer::MakeArray(
+                              c_shape, client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto abc_tuple_buffer,
-      SharedDeviceBuffer::MakeTuple(
-          {c_buffer, ab_tuple_buffer}, abc_tuple_shape,
-          client->backend().transfer_manager(),
-          client->backend().memory_allocator(), 0, nullptr));
+  TF_ASSERT_OK_AND_ASSIGN(auto abc_tuple_buffer,
+                          SharedDeviceBuffer::MakeTuple(
+                              {c_buffer, ab_tuple_buffer}, abc_tuple_shape,
+                              client->backend().transfer_manager(),
+                              client->backend().memory_allocator(), 0, {}));
 
   Shape abc_tuple_device_shape =
       client->backend().transfer_manager()->HostShapeToDeviceShape(
           abc_tuple_shape);
@@ -140,7 +138,7 @@ TEST(SharedDeviceBufferTest, FromScopedShapedBuffer) {
       ScopedShapedBuffer shaped_buffer,
       client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0));
   std::shared_ptr<SharedDeviceBuffer> device_buffer =
-      SharedDeviceBuffer::FromScopedShapedBuffer(&shaped_buffer, nullptr);
+      SharedDeviceBuffer::FromScopedShapedBuffer(&shaped_buffer, {});
 
   ASSERT_EQ(device_buffer->device_memory().size(), 1);
   ASSERT_EQ(device_buffer->children().size(), 2);