Adds optional definition event argument to CreateUninitializedBuffer.
PiperOrigin-RevId: 342929825 Change-Id: Idcfd83ef836eaa6c4663117c511c3c33f04cba35
This commit is contained in:
parent
690553e690
commit
9cec6e4e5b
@ -334,10 +334,14 @@ void RecordUsage(PjRtBuffer::ScopedHold device_buffer,
|
|||||||
//
|
//
|
||||||
// It is safe to delete the returned PjRtBuffer without further
|
// It is safe to delete the returned PjRtBuffer without further
|
||||||
// synchronization if an error occurs before the buffer is used.
|
// synchronization if an error occurs before the buffer is used.
|
||||||
|
//
|
||||||
|
// The caller may optionally provide a definition event to be recorded in
|
||||||
|
// the buffer.
|
||||||
StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
|
StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
|
||||||
const Shape& on_host_shape, PjRtDevice* device,
|
const Shape& on_host_shape, PjRtDevice* device,
|
||||||
LocalDeviceState* local_device, se::Stream* copy_stream,
|
LocalDeviceState* local_device, se::Stream* copy_stream,
|
||||||
bool is_uninitialized_create, PjRtClient* client) {
|
bool is_uninitialized_create, PjRtClient* client,
|
||||||
|
std::shared_ptr<BufferSequencingEvent> definition_event = nullptr) {
|
||||||
if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) {
|
if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) {
|
||||||
return InvalidArgument("Can't make a buffer from an empty tuple");
|
return InvalidArgument("Can't make a buffer from an empty tuple");
|
||||||
}
|
}
|
||||||
@ -377,10 +381,18 @@ StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
|
|||||||
definition_events.back()->SetSequencingEvent(
|
definition_events.back()->SetSequencingEvent(
|
||||||
std::move(event), local_device->compute_stream());
|
std::move(event), local_device->compute_stream());
|
||||||
}
|
}
|
||||||
|
// If the caller provided a definition event, then we record that too.
|
||||||
|
if (definition_event) {
|
||||||
|
definition_events.emplace_back(definition_event);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// We have at least one definition event, for the copy completing to
|
// We have at least one definition event, for the copy completing to
|
||||||
// the device buffers.
|
// the device buffers.
|
||||||
definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
|
if (definition_event) {
|
||||||
|
definition_events.emplace_back(definition_event);
|
||||||
|
} else {
|
||||||
|
definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
se::Stream* tuple_table_stream = local_device->host_to_device_stream();
|
se::Stream* tuple_table_stream = local_device->host_to_device_stream();
|
||||||
if (on_device_shape.IsTuple()) {
|
if (on_device_shape.IsTuple()) {
|
||||||
@ -698,6 +710,12 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
|
|||||||
|
|
||||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
|
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
|
||||||
const Shape& shape, PjRtDevice* device) {
|
const Shape& shape, PjRtDevice* device) {
|
||||||
|
return CreateUninitializedBuffer(shape, device, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
|
||||||
|
const Shape& shape, PjRtDevice* device,
|
||||||
|
std::shared_ptr<BufferSequencingEvent> definition_event) {
|
||||||
tensorflow::profiler::TraceMe traceme(
|
tensorflow::profiler::TraceMe traceme(
|
||||||
"PjRtClient::CreateUninitializedBuffer");
|
"PjRtClient::CreateUninitializedBuffer");
|
||||||
VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
|
VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
|
||||||
@ -711,7 +729,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
|
|||||||
|
|
||||||
return AllocateDestinationBuffer(compact_shape, device, local_device,
|
return AllocateDestinationBuffer(compact_shape, device, local_device,
|
||||||
/*copy_stream=*/nullptr,
|
/*copy_stream=*/nullptr,
|
||||||
/*is_uninitialized_create=*/true, this);
|
/*is_uninitialized_create=*/true, this,
|
||||||
|
definition_event);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
|
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
|
||||||
@ -788,14 +807,15 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
LocalDeviceState* local_device = local_device_or.ConsumeValueOrDie();
|
LocalDeviceState* local_device = local_device_or.ConsumeValueOrDie();
|
||||||
|
std::shared_ptr<BufferSequencingEvent> definition_event =
|
||||||
|
std::make_shared<BufferSequencingEvent>();
|
||||||
std::vector<std::unique_ptr<PjRtBuffer>> buffers;
|
std::vector<std::unique_ptr<PjRtBuffer>> buffers;
|
||||||
buffers.reserve(shapes.size());
|
buffers.reserve(shapes.size());
|
||||||
for (const auto& shape : shapes) {
|
for (const auto& shape : shapes) {
|
||||||
StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or =
|
StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or = AllocateDestinationBuffer(
|
||||||
AllocateDestinationBuffer(shape, device, local_device,
|
shape, device, local_device,
|
||||||
/*copy_stream=*/nullptr,
|
/*copy_stream=*/nullptr,
|
||||||
/*is_uninitialized_create=*/false, this);
|
/*is_uninitialized_create=*/false, this, definition_event);
|
||||||
if (!buffer_or.ok()) {
|
if (!buffer_or.ok()) {
|
||||||
notifier(buffer_or.status());
|
notifier(buffer_or.status());
|
||||||
return;
|
return;
|
||||||
@ -803,7 +823,8 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
|
|||||||
buffers.push_back(buffer_or.ConsumeValueOrDie());
|
buffers.push_back(buffer_or.ConsumeValueOrDie());
|
||||||
}
|
}
|
||||||
|
|
||||||
EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
|
EnqueueCrossHostReceive(std::move(buffers), std::move(definition_event),
|
||||||
|
std::move(notifier));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transfer the given literal to the infeed queue of the given local device.
|
// Transfer the given literal to the infeed queue of the given local device.
|
||||||
|
@ -248,8 +248,17 @@ class PjRtClient {
|
|||||||
virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
|
virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
|
||||||
const XlaComputation& computation, CompileOptions options);
|
const XlaComputation& computation, CompileOptions options);
|
||||||
|
|
||||||
|
// Creates a buffer on the device without initializing or copying any data.
|
||||||
|
// An optional `definition_event` may be specified that can be used to
|
||||||
|
// ensure the buffer isn't referenced until some external mechanism has
|
||||||
|
// initialized the data.
|
||||||
|
// NOTE: The sequencing mechanism is not guaranteed to be supported by all
|
||||||
|
// future backends and so callers should avoid it wherever possible.
|
||||||
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
||||||
const Shape& shape, PjRtDevice* device);
|
const Shape& shape, PjRtDevice* device);
|
||||||
|
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
||||||
|
const Shape& shape, PjRtDevice* device,
|
||||||
|
std::shared_ptr<BufferSequencingEvent> definition_event);
|
||||||
|
|
||||||
// Describes the semantics the caller to BufferFromHostBuffer expects from the
|
// Describes the semantics the caller to BufferFromHostBuffer expects from the
|
||||||
// runtime, in a total order from most restrictive to least restrictive.
|
// runtime, in a total order from most restrictive to least restrictive.
|
||||||
@ -316,6 +325,7 @@ class PjRtClient {
|
|||||||
friend class PjRtBuffer;
|
friend class PjRtBuffer;
|
||||||
virtual void EnqueueCrossHostReceive(
|
virtual void EnqueueCrossHostReceive(
|
||||||
std::vector<std::unique_ptr<PjRtBuffer>>&& buffers,
|
std::vector<std::unique_ptr<PjRtBuffer>>&& buffers,
|
||||||
|
std::shared_ptr<BufferSequencingEvent> definition_event,
|
||||||
PjRtCrossHostRecvNotifier&& notifier) const {
|
PjRtCrossHostRecvNotifier&& notifier) const {
|
||||||
notifier(Unimplemented("Cross host receives not implemented."));
|
notifier(Unimplemented("Cross host receives not implemented."));
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user