diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index 868ba991b71..a5fdd00fb05 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -334,10 +334,14 @@ void RecordUsage(PjRtBuffer::ScopedHold device_buffer, // // It is safe to delete the returned PjRtBuffer without further // synchronization if an error occurs before the buffer is used. +// +// The caller may optionally provide a definition event to be recorded in +// the buffer. StatusOr> AllocateDestinationBuffer( const Shape& on_host_shape, PjRtDevice* device, LocalDeviceState* local_device, se::Stream* copy_stream, - bool is_uninitialized_create, PjRtClient* client) { + bool is_uninitialized_create, PjRtClient* client, + std::shared_ptr definition_event = nullptr) { if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) { return InvalidArgument("Can't make a buffer from an empty tuple"); } @@ -377,10 +381,18 @@ StatusOr> AllocateDestinationBuffer( definition_events.back()->SetSequencingEvent( std::move(event), local_device->compute_stream()); } + // if the caller provided a definition event then we record that. + if (definition_event) { + definition_events.emplace_back(definition_event); + } } else { // We have at least one definition event, for the copy completing to // the device buffers. 
- definition_events.emplace_back(std::make_shared()); + if (definition_event) { + definition_events.emplace_back(definition_event); + } else { + definition_events.emplace_back(std::make_shared()); + } } se::Stream* tuple_table_stream = local_device->host_to_device_stream(); if (on_device_shape.IsTuple()) { @@ -698,6 +710,12 @@ StatusOr> PjRtClient::BufferFromHostBuffer( StatusOr> PjRtClient::CreateUninitializedBuffer( const Shape& shape, PjRtDevice* device) { + return CreateUninitializedBuffer(shape, device, nullptr); +} + +StatusOr> PjRtClient::CreateUninitializedBuffer( + const Shape& shape, PjRtDevice* device, + std::shared_ptr definition_event) { tensorflow::profiler::TraceMe traceme( "PjRtClient::CreateUninitializedBuffer"); VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: " @@ -711,7 +729,8 @@ StatusOr> PjRtClient::CreateUninitializedBuffer( return AllocateDestinationBuffer(compact_shape, device, local_device, /*copy_stream=*/nullptr, - /*is_uninitialized_create=*/true, this); + /*is_uninitialized_create=*/true, this, + definition_event); } StatusOr> PjRtClient::BufferFromHostLiteral( @@ -788,14 +807,15 @@ void PjRtClient::MakeCrossHostReceiveBuffers( return; } LocalDeviceState* local_device = local_device_or.ConsumeValueOrDie(); - + std::shared_ptr definition_event = + std::make_shared(); std::vector> buffers; buffers.reserve(shapes.size()); for (const auto& shape : shapes) { - StatusOr> buffer_or = - AllocateDestinationBuffer(shape, device, local_device, - /*copy_stream=*/nullptr, - /*is_uninitialized_create=*/false, this); + StatusOr> buffer_or = AllocateDestinationBuffer( + shape, device, local_device, + /*copy_stream=*/nullptr, + /*is_uninitialized_create=*/false, this, definition_event); if (!buffer_or.ok()) { notifier(buffer_or.status()); return; @@ -803,7 +823,8 @@ void PjRtClient::MakeCrossHostReceiveBuffers( buffers.push_back(buffer_or.ConsumeValueOrDie()); } - EnqueueCrossHostReceive(std::move(buffers), std::move(notifier)); + 
EnqueueCrossHostReceive(std::move(buffers), std::move(definition_event), + std::move(notifier)); } // Transfer the given literal to the infeed queue of the given local device. diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index 38d2610ff93..ec695e05c44 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -248,8 +248,17 @@ class PjRtClient { virtual StatusOr> Compile( const XlaComputation& computation, CompileOptions options); + // Creates a buffer on the device without initializing or copying any data. + // An optional `definition_event` may be specified that can be used to + // ensure the buffer isn't referenced until some external mechanism has + // initialized the data. + // NOTE: The sequencing mechanism is not guaranteed to be supported by all + // future backends and so callers should avoid it wherever possible. virtual StatusOr> CreateUninitializedBuffer( const Shape& shape, PjRtDevice* device); + virtual StatusOr> CreateUninitializedBuffer( + const Shape& shape, PjRtDevice* device, + std::shared_ptr definition_event); // Describes the semantics the caller to BufferFromHostBuffer expects from the // runtime, in a total order from most restrictive to least restrictive. @@ -316,6 +325,7 @@ class PjRtClient { friend class PjRtBuffer; virtual void EnqueueCrossHostReceive( std::vector>&& buffers, + std::shared_ptr definition_event, PjRtCrossHostRecvNotifier&& notifier) const { notifier(Unimplemented("Cross host receives not implemented.")); }