Adds optional definition event argument to CreateUninitializedBuffer.

PiperOrigin-RevId: 342929825
Change-Id: Idcfd83ef836eaa6c4663117c511c3c33f04cba35
This commit is contained in:
A. Unique TensorFlower 2020-11-17 13:09:49 -08:00 committed by TensorFlower Gardener
parent 690553e690
commit 9cec6e4e5b
2 changed files with 40 additions and 9 deletions

View File

@ -334,10 +334,14 @@ void RecordUsage(PjRtBuffer::ScopedHold device_buffer,
//
// It is safe to delete the returned PjRtBuffer without further
// synchronization if an error occurs before the buffer is used.
//
// The caller may optionally provide a definition event to be recorded in
// the buffer.
StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
const Shape& on_host_shape, PjRtDevice* device,
LocalDeviceState* local_device, se::Stream* copy_stream,
bool is_uninitialized_create, PjRtClient* client) {
bool is_uninitialized_create, PjRtClient* client,
std::shared_ptr<BufferSequencingEvent> definition_event = nullptr) {
if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) {
return InvalidArgument("Can't make a buffer from an empty tuple");
}
@ -377,10 +381,18 @@ StatusOr<std::unique_ptr<PjRtBuffer>> AllocateDestinationBuffer(
definition_events.back()->SetSequencingEvent(
std::move(event), local_device->compute_stream());
}
// if the caller provided a definition event then we record that.
if (definition_event) {
definition_events.emplace_back(definition_event);
}
} else {
// We have at least one definition event, for the copy completing to
// the device buffers.
definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
if (definition_event) {
definition_events.emplace_back(definition_event);
} else {
definition_events.emplace_back(std::make_shared<BufferSequencingEvent>());
}
}
se::Stream* tuple_table_stream = local_device->host_to_device_stream();
if (on_device_shape.IsTuple()) {
@ -698,6 +710,12 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device) {
return CreateUninitializedBuffer(shape, device, nullptr);
}
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device,
std::shared_ptr<BufferSequencingEvent> definition_event) {
tensorflow::profiler::TraceMe traceme(
"PjRtClient::CreateUninitializedBuffer");
VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
@ -711,7 +729,8 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
return AllocateDestinationBuffer(compact_shape, device, local_device,
/*copy_stream=*/nullptr,
/*is_uninitialized_create=*/true, this);
/*is_uninitialized_create=*/true, this,
definition_event);
}
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
@ -788,14 +807,15 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
return;
}
LocalDeviceState* local_device = local_device_or.ConsumeValueOrDie();
std::shared_ptr<BufferSequencingEvent> definition_event =
std::make_shared<BufferSequencingEvent>();
std::vector<std::unique_ptr<PjRtBuffer>> buffers;
buffers.reserve(shapes.size());
for (const auto& shape : shapes) {
StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or =
AllocateDestinationBuffer(shape, device, local_device,
/*copy_stream=*/nullptr,
/*is_uninitialized_create=*/false, this);
StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or = AllocateDestinationBuffer(
shape, device, local_device,
/*copy_stream=*/nullptr,
/*is_uninitialized_create=*/false, this, definition_event);
if (!buffer_or.ok()) {
notifier(buffer_or.status());
return;
@ -803,7 +823,8 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
buffers.push_back(buffer_or.ConsumeValueOrDie());
}
EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
EnqueueCrossHostReceive(std::move(buffers), std::move(definition_event),
std::move(notifier));
}
// Transfer the given literal to the infeed queue of the given local device.

View File

@ -248,8 +248,17 @@ class PjRtClient {
virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
const XlaComputation& computation, CompileOptions options);
// Creates a buffer on the device without initializing or copying any data.
// An optional `definition_event` may be speficied that can be used to
// ensure the buffer isn't referenced until some external mechanism has
// initialized the data.
// NOTE: The sequencing mechanism is not guaranteed to be supported by all
// future backends and so callers should avoid wherever possible.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device);
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device,
std::shared_ptr<BufferSequencingEvent> definition_event);
// Describes the semantics the caller to BufferFromHostBuffer expects from the
// runtime, in a total order from most restrictive to least restrictive.
@ -316,6 +325,7 @@ class PjRtClient {
friend class PjRtBuffer;
virtual void EnqueueCrossHostReceive(
std::vector<std::unique_ptr<PjRtBuffer>>&& buffers,
std::shared_ptr<BufferSequencingEvent> definition_event,
PjRtCrossHostRecvNotifier&& notifier) const {
notifier(Unimplemented("Cross host receives not implemented."));
}