Adds optional definition event argument to CreateUninitializedBuffer.
PiperOrigin-RevId: 342929825 Change-Id: Idcfd83ef836eaa6c4663117c511c3c33f04cba35
This commit is contained in:
parent
690553e690
commit
9cec6e4e5b
@ -334,10 +334,14 @@ void RecordUsage(PjRtBuffer::ScopedHold device_buffer,
|
||||
//
|
||||
// It is safe to delete the returned PjRtBuffer without further
|
||||
// synchronization if an error occurs before the buffer is used.
|
||||
//
|
||||
// The caller may optionally provide a definition event to be recorded in
|
||||
// the buffer.
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
|
||||
const Shape& on_host_shape, PjRtDevice* device,
|
||||
LocalDeviceState* local_device, se::Stream* copy_stream,
|
||||
bool is_uninitialized_create, PjRtClient* client) {
|
||||
bool is_uninitialized_create, PjRtClient* client,
|
||||
std::shared_ptr<BufferSequencingEvent> definition_event = nullptr) {
|
||||
if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) {
|
||||
return InvalidArgument("Can't make a buffer from an empty tuple");
|
||||
}
|
||||
@ -377,10 +381,18 @@ StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
|
||||
definition_events.back()->SetSequencingEvent(
|
||||
std::move(event), local_device->compute_stream());
|
||||
}
|
||||
// If the caller provided a definition event then we record that.
|
||||
if (definition_event) {
|
||||
definition_events.emplace_back(definition_event);
|
||||
}
|
||||
} else {
|
||||
// We have at least one definition event, for the copy completing to
|
||||
// the device buffers.
|
||||
definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
|
||||
if (definition_event) {
|
||||
definition_events.emplace_back(definition_event);
|
||||
} else {
|
||||
definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
|
||||
}
|
||||
}
|
||||
se::Stream* tuple_table_stream = local_device->host_to_device_stream();
|
||||
if (on_device_shape.IsTuple()) {
|
||||
@ -698,6 +710,12 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
|
||||
|
||||
// Two-argument overload: forwards to the overload that takes a
// `definition_event`, passing nullptr so that no caller-provided event is
// supplied.
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
|
||||
const Shape& shape, PjRtDevice* device) {
|
||||
return CreateUninitializedBuffer(shape, device, nullptr);
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
|
||||
const Shape& shape, PjRtDevice* device,
|
||||
std::shared_ptr<BufferSequencingEvent> definition_event) {
|
||||
tensorflow::profiler::TraceMe traceme(
|
||||
"PjRtClient::CreateUninitializedBuffer");
|
||||
VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
|
||||
@ -711,7 +729,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
|
||||
|
||||
return AllocateDestinationBuffer(compact_shape, device, local_device,
|
||||
/*copy_stream=*/nullptr,
|
||||
/*is_uninitialized_create=*/true, this);
|
||||
/*is_uninitialized_create=*/true, this,
|
||||
definition_event);
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
|
||||
@ -788,14 +807,15 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
|
||||
return;
|
||||
}
|
||||
LocalDeviceState* local_device = local_device_or.ConsumeValueOrDie();
|
||||
|
||||
std::shared_ptr<BufferSequencingEvent> definition_event =
|
||||
std::make_shared<BufferSequencingEvent>();
|
||||
std::vector<std::unique_ptr<PjRtBuffer>> buffers;
|
||||
buffers.reserve(shapes.size());
|
||||
for (const auto& shape : shapes) {
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or =
|
||||
AllocateDestinationBuffer(shape, device, local_device,
|
||||
/*copy_stream=*/nullptr,
|
||||
/*is_uninitialized_create=*/false, this);
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or = AllocateDestinationBuffer(
|
||||
shape, device, local_device,
|
||||
/*copy_stream=*/nullptr,
|
||||
/*is_uninitialized_create=*/false, this, definition_event);
|
||||
if (!buffer_or.ok()) {
|
||||
notifier(buffer_or.status());
|
||||
return;
|
||||
@ -803,7 +823,8 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
|
||||
buffers.push_back(buffer_or.ConsumeValueOrDie());
|
||||
}
|
||||
|
||||
EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
|
||||
EnqueueCrossHostReceive(std::move(buffers), std::move(definition_event),
|
||||
std::move(notifier));
|
||||
}
|
||||
|
||||
// Transfer the given literal to the infeed queue of the given local device.
|
||||
|
@ -248,8 +248,17 @@ class PjRtClient {
|
||||
virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
|
||||
const XlaComputation& computation, CompileOptions options);
|
||||
|
||||
// Creates a buffer on the device without initializing or copying any data.
|
||||
// An optional `definition_event` may be specified that can be used to
|
||||
// ensure the buffer isn't referenced until some external mechanism has
|
||||
// initialized the data.
|
||||
// NOTE: The sequencing mechanism is not guaranteed to be supported by all
|
||||
// future backends and so callers should avoid it wherever possible.
|
||||
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
||||
const Shape& shape, PjRtDevice* device);
|
||||
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
||||
const Shape& shape, PjRtDevice* device,
|
||||
std::shared_ptr<BufferSequencingEvent> definition_event);
|
||||
|
||||
// Describes the semantics the caller to BufferFromHostBuffer expects from the
|
||||
// runtime, in a total order from most restrictive to least restrictive.
|
||||
@ -316,6 +325,7 @@ class PjRtClient {
|
||||
friend class PjRtBuffer;
|
||||
// This implementation does not perform a cross-host receive: it ignores
// `buffers` and `definition_event` and immediately invokes `notifier` with an
// Unimplemented status. The method is virtual, so subclasses can override it.
virtual void EnqueueCrossHostReceive(
|
||||
std::vector<std::unique_ptr<PjRtBuffer>>&& buffers,
|
||||
std::shared_ptr<BufferSequencingEvent> definition_event,
|
||||
PjRtCrossHostRecvNotifier&& notifier) const {
|
||||
notifier(Unimplemented("Cross host receives not implemented."));
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user