diff --git a/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc
index c56b41861b0..f43ec5a9216 100644
--- a/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc
+++ b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc
@@ -54,9 +54,9 @@ TEST(GpuMultiStream, Basics) {
   device_assignment(0, 0) = device->id();
   compile_options.executable_build_options.set_device_assignment(
       device_assignment);
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtExecutable> executable,
-                          PjRtExecutable::Compile(computation, client.get(),
-                                                  std::move(compile_options)));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<PjRtExecutable> executable,
+      client->Compile(computation, std::move(compile_options)));
 
   int64 dummy_size = 1 << 20;
   std::vector<int32> dummy_inputs(dummy_size);
@@ -71,22 +71,22 @@ TEST(GpuMultiStream, Basics) {
   // must wait.
   TF_ASSERT_OK_AND_ASSIGN(
       auto dummy_buffer,
-      PjRtBuffer::FromHostBuffer(
+      client->BufferFromHostBuffer(
          dummy_inputs.data(), dummy_shape,
-          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
-          /*buffer_reference=*/nullptr, client.get(), device));
+          PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
+          /*buffer_reference=*/nullptr, device));
   TF_ASSERT_OK_AND_ASSIGN(
       auto in_buffer0,
-      PjRtBuffer::FromHostBuffer(
+      client->BufferFromHostBuffer(
          inputs.data(), shape,
-          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
-          /*buffer_reference=*/nullptr, client.get(), device));
+          PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
+          /*buffer_reference=*/nullptr, device));
   TF_ASSERT_OK_AND_ASSIGN(
       auto in_buffer1,
-      PjRtBuffer::FromHostBuffer(
+      client->BufferFromHostBuffer(
          inputs.data(), shape,
-          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
-          /*buffer_reference=*/nullptr, client.get(), device));
+          PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
+          /*buffer_reference=*/nullptr, device));
   // The execution may be enqueued before the transfers complete, requiring
   // adequate device-side synchronization.
   ExecuteOptions options;
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc
index 02ae37b71db..41afcb01511 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc
+++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc
@@ -576,24 +576,21 @@ void PjRtBuffer::ScopedHold::AddToInput(
   }
 }
 
-/* static */
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
+StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
     const void* data, const Shape& shape,
     HostBufferSemantics host_buffer_semantics,
-    std::shared_ptr<void> buffer_reference, PjRtClient* client,
-    PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer");
-  VLOG(2) << "PjRtBuffer::FromHostBuffer: shape: " << shape.ToString()
+    std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
+  tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostBuffer");
+  VLOG(2) << "PjRtClient::BufferFromHostBuffer: shape: " << shape.ToString()
           << " device: " << device->DebugString();
   if (shape.IsTuple()) {
-    return InvalidArgument("Use FromHostLiteral to transfer a tuple");
+    return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
   }
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       device->GetLocalDeviceState());
   int64 size = ShapeUtil::ByteSizeOf(shape);
 
-  TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
+  TransferManager* transfer_manager = client()->backend().transfer_manager();
   TF_ASSIGN_OR_RETURN(Shape compact_shape,
                       transfer_manager->ChooseCompactLayoutForShape(shape));
 
@@ -628,10 +625,11 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
     };
     buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
   } else {
-    void* staging_buffer = client->host_memory_allocator()->AllocateRaw(
+    void* staging_buffer = host_memory_allocator()->AllocateRaw(
        cpu_function_runtime::kMinAlign, size);
-    on_delete_callback = [staging_buffer, client]() {
-      client->host_memory_allocator()->DeallocateRaw(staging_buffer);
+    on_delete_callback = [staging_buffer, host_memory_allocator =
+                                              host_memory_allocator()]() {
+      host_memory_allocator->DeallocateRaw(staging_buffer);
     };
     buffer = se::DeviceMemoryBase(staging_buffer, size);
     std::memcpy(staging_buffer, data, size);
@@ -643,7 +641,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
         std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
         std::move(on_delete_callback));
     return absl::make_unique<PjRtBuffer>(
-        shape, shape, std::move(device_buffer), client, device);
+        shape, shape, std::move(device_buffer), this, device);
    }
  }
 
@@ -651,21 +649,22 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
       std::unique_ptr<PjRtBuffer> py_buffer,
       AllocateDestinationBuffer(compact_shape, device, local_device,
                                 local_device->host_to_device_stream(),
-                                /*is_uninitialized_create=*/false, client));
+                                /*is_uninitialized_create=*/false, this));
 
-  ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
+  PjRtBuffer::ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
   CHECK(device_buffer.ok());
 
   // If necessary, allocate a host-side buffer for staging host-to-device
   // transfers. On GPU this is a buffer in pinned memory.
   std::shared_ptr<void> staging_buffer;
   if (host_buffer_semantics ==
           HostBufferSemantics::kImmutableOnlyDuringCall ||
-      client->should_stage_host_to_device_transfers()) {
-    void* ptr = client->host_memory_allocator()->AllocateRaw(
+      should_stage_host_to_device_transfers()) {
+    void* ptr = host_memory_allocator()->AllocateRaw(
        tensorflow::Allocator::kAllocatorAlignment, size);
-    staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
-      client->host_memory_allocator()->DeallocateRaw(ptr);
-    });
+    staging_buffer = std::shared_ptr<void>(
+        ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) {
+          host_memory_allocator->DeallocateRaw(ptr);
+        });
   }
 
   // Copy the buffer into a staging buffer before returning control to the
@@ -684,14 +683,15 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
   // usage holds have gone away.
   // TODO(misard) assess if it would be preferable to introduce a heuristic to
   // put the transfer into the calling thread for small literals.
-  auto transfer_h2d = [client, transfer_manager, local_device, data, size,
+  auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
+                       data, size,
                        movable_device_buffer{device_buffer.ToClosure()}, shape,
                        py_buffer{py_buffer.get()}, compact_shape,
                        on_device_shape{py_buffer->on_device_shape()},
                        staging_buffer{std::move(staging_buffer)},
                        buffer_reference{std::move(buffer_reference)},
                        host_buffer_semantics]() {
-    ScopedHold device_buffer(movable_device_buffer);
+    PjRtBuffer::ScopedHold device_buffer(movable_device_buffer);
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
     // to report failures from a callback. However, the operations here are
     // unlikely to fail and not recoverable even if we were to fail: DMAs to
@@ -699,7 +699,7 @@
     // allocation.
 
     ShapedBuffer buffer = device_buffer->AsShapedBuffer(
-        compact_shape, on_device_shape, client->client()->platform());
+        compact_shape, on_device_shape, local_client->platform());
     // If applicable on the backend, stage the transfer via host memory
     // allocated via the host_memory_allocator. On GPU, this is pinned
     // memory.
@@ -736,41 +736,38 @@
     // already defers its work onto a stream (= thread on CPU).
     transfer_h2d();
   } else {
-    client->h2d_transfer_pool()->Schedule(transfer_h2d);
+    h2d_transfer_pool()->Schedule(transfer_h2d);
   }
   return py_buffer;
 }
 
-/* static */
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CreateUninitialized(
-    const Shape& shape, PjRtClient* client, PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtBuffer::CreateUninitialized");
-  VLOG(2) << "PjRtBuffer::CreateUninitialized: shape: " << shape.ToString()
-          << " device: " << device->DebugString();
+StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
+    const Shape& shape, PjRtDevice* device) {
+  tensorflow::profiler::TraceMe traceme(
+      "PjRtClient::CreateUninitializedBuffer");
+  VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
+          << shape.ToString() << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       device->GetLocalDeviceState());
 
-  TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
+  TransferManager* transfer_manager = client()->backend().transfer_manager();
   TF_ASSIGN_OR_RETURN(Shape compact_shape,
                       transfer_manager->ChooseCompactLayoutForShape(shape));
 
   return AllocateDestinationBuffer(compact_shape, device, local_device,
                                    /*copy_stream=*/nullptr,
-                                   /*is_uninitialized_create=*/true, client);
+                                   /*is_uninitialized_create=*/true, this);
 }
 
-/* static */
-StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
-    const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device) {
-  tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostLiteral");
-  VLOG(2) << "PjRtBuffer::FromHostLiteral: shape: "
+StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
+    const LiteralSlice& literal, PjRtDevice* device) {
+  tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostLiteral");
+  VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: "
           << literal.shape().ToString() << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       device->GetLocalDeviceState());
-  TransferManager* transfer_manager =
-      client->client()->backend().transfer_manager();
+  TransferManager* transfer_manager = client()->backend().transfer_manager();
   TF_ASSIGN_OR_RETURN(
       Shape compact_shape,
       transfer_manager->ChooseCompactLayoutForShape(literal.shape()));
@@ -778,9 +775,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
       std::unique_ptr<PjRtBuffer> py_buffer,
       AllocateDestinationBuffer(compact_shape, device, local_device,
                                 local_device->host_to_device_stream(),
-                                /*is_uninitialized_create=*/false, client));
+                                /*is_uninitialized_create=*/false, this));
 
-  ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
+  PjRtBuffer::ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
   CHECK(device_buffer.ok());
 
   // The host to device transfer is performed on a thread pool, mostly because
@@ -789,11 +786,11 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
   // usage holds have gone away.
   // TODO(misard) assess if it would be preferable to introduce a heuristic to
   // put the transfer into the calling thread for small literals.
-  auto transfer_h2d = [client, transfer_manager, local_device,
+  auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
                        movable_device_buffer{device_buffer.ToClosure()},
                        literal, py_buffer{py_buffer.get()}, compact_shape,
                        on_device_shape{py_buffer->on_device_shape()}]() {
-    ScopedHold device_buffer(movable_device_buffer);
+    PjRtBuffer::ScopedHold device_buffer(movable_device_buffer);
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
     // to report failures from a callback. However, the operations here are
     // unlikely to fail and not recoverable even if we were to fail: DMAs to
@@ -802,7 +799,7 @@
     se::Stream* h2d_stream = local_device->host_to_device_stream();
 
     ShapedBuffer buffer = device_buffer->AsShapedBuffer(
-        compact_shape, on_device_shape, client->client()->platform());
+        compact_shape, on_device_shape, local_client->platform());
     TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
         h2d_stream, literal, buffer));
 
@@ -817,12 +814,12 @@
         .IgnoreError();  // Can return error::Unimplemented
     QCHECK(h2d_stream->ok());
   };
-  client->h2d_transfer_pool()->Schedule(transfer_h2d);
+  h2d_transfer_pool()->Schedule(transfer_h2d);
   return py_buffer;
 }
 
-/*static*/ void PjRtBuffer::MakeCrossHostReceiveBuffers(
-    absl::Span<const Shape> shapes, PjRtClient* client, PjRtDevice* device,
+void PjRtClient::MakeCrossHostReceiveBuffers(
+    absl::Span<const Shape> shapes, PjRtDevice* device,
     PjRtCrossHostRecvNotifier&& notifier) {
   if (shapes.empty()) {
     notifier(InvalidArgument(
@@ -843,7 +840,7 @@
     StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or =
         AllocateDestinationBuffer(shape, device, local_device,
                                   /*copy_stream=*/nullptr,
-                                  /*is_uninitialized_create=*/false, client);
+                                  /*is_uninitialized_create=*/false, this);
     if (!buffer_or.ok()) {
       notifier(buffer_or.status());
       return;
@@ -851,7 +848,7 @@
     buffers.push_back(buffer_or.ConsumeValueOrDie());
   }
 
-  client->EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
+  EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
 }
 
 PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
@@ -1159,7 +1156,7 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
 
 StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
     const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
-  tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral");
+  tensorflow::profiler::TraceMe traceme("PjRtClient::ToLiteral");
   TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
                       CopyToHostAsyncInternal(discard_cached_copy, layout));
   if (host_value == nullptr) {
@@ -1267,9 +1264,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
   // Copying across PjRtClients involves a copy through the host.
   if (dst_device->client() != client_) {
     TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
-    return FromHostBuffer(literal->untyped_data(), literal->shape(),
-                          HostBufferSemantics::kZeroCopy, nullptr,
-                          dst_device->client(), dst_device);
+    return dst_device->client()->BufferFromHostBuffer(
+        literal->untyped_data(), literal->shape(),
+        PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device);
   }
 
   TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device,
@@ -2061,14 +2058,13 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
 
 }  // namespace
 
-/*static*/ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtExecutable::Compile(
-    const XlaComputation& computation, PjRtClient* client,
-    CompileOptions options) {
-  tensorflow::profiler::TraceMe traceme("LocalExecutable::Compile");
+StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
+    const XlaComputation& computation, CompileOptions options) {
+  tensorflow::profiler::TraceMe traceme("PjRtClient::Compile");
 
   ExecutableBuildOptions& build_options = options.executable_build_options;
   if (!build_options.device_allocator()) {
-    build_options.set_device_allocator(client->allocator());
+    build_options.set_device_allocator(allocator());
   }
 
   int num_replicas;
@@ -2084,14 +2080,14 @@
     num_partitions = 1;
   } else {
     if (!build_options.has_device_assignment()) {
-      VLOG(2) << "PjRtExecutable::Compile using default device_assignment.";
+      VLOG(2) << "PjRtClient::Compile using default device_assignment.";
       TF_ASSIGN_OR_RETURN(
           DeviceAssignment device_assignment,
-          client->GetDefaultDeviceAssignment(build_options.num_replicas(),
-                                             build_options.num_partitions()));
+          GetDefaultDeviceAssignment(build_options.num_replicas(),
+                                     build_options.num_partitions()));
       build_options.set_device_assignment(device_assignment);
     }
-    VLOG(2) << "PjRtExecutable::Compile device_assignment:\n"
+    VLOG(2) << "PjRtClient::Compile device_assignment:\n"
            << build_options.device_assignment().ToString();
     num_replicas = build_options.device_assignment().replica_count();
     num_partitions = build_options.device_assignment().computation_count();
@@ -2118,7 +2114,8 @@
 
   // Assign a default layout based on `sharded_shape` to any array subshapes in
   // `dst_shape` that are missing layouts.
-  auto assign_layouts = [client](const Shape& sharded_shape, Shape* dst_shape) {
+  auto assign_layouts = [local_client = client()](const Shape& sharded_shape,
+                                                  Shape* dst_shape) {
     return ShapeUtil::ForEachMutableSubshapeWithStatus(
         dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
           if (subshape->IsArray() && !subshape->has_layout()) {
@@ -2126,8 +2123,7 @@
             const Shape& sharded_subshape =
                 ShapeUtil::GetSubshape(sharded_shape, idx);
             LayoutUtil::SetToDefaultLayout(subshape);
-            TF_ASSIGN_OR_RETURN(Shape layout, client->client()
-                                                  ->backend()
+            TF_ASSIGN_OR_RETURN(Shape layout, local_client->backend()
                                                   .transfer_manager()
                                                   ->ChooseCompactLayoutForShape(
                                                       sharded_subshape));
@@ -2162,8 +2158,8 @@
   for (int replica = 0; replica < num_replicas; ++replica) {
     for (int partition = 0; partition < num_partitions; ++partition) {
       int device_id = (*device_assignment)(replica, partition);
-      PjRtDevice* device = LookupDevice(*client, device_id);
-      if (device->host_id() != client->host_id()) {
+      PjRtDevice* device = LookupDevice(*this, device_id);
+      if (device->host_id() != host_id()) {
        VLOG(3) << "Non-local device: " << device_id;
        continue;
      }
@@ -2185,15 +2181,14 @@
 
   TF_ASSIGN_OR_RETURN(
       std::vector<std::unique_ptr<LocalExecutable>> local_executables,
-      client->client()->Compile(computation, argument_layout_pointers,
-                                build_options));
+      client()->Compile(computation, argument_layout_pointers, build_options));
 
   auto executable = absl::make_unique<PjRtExecutable>(
       std::move(local_executables), options.parameter_is_tupled_arguments,
       std::move(device_assignment), std::move(local_logical_device_ids),
-      std::move(local_devices), client);
+      std::move(local_devices), this);
   TF_RETURN_IF_ERROR(
-      executable->SetUpDonation(client, options.parameter_is_tupled_arguments));
+      executable->SetUpDonation(this, options.parameter_is_tupled_arguments));
 
   return executable;
 }
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h
index cb4ef9da85b..c10470f7d60 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_client.h
+++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h
@@ -120,6 +120,24 @@ struct PjRtCrossHostRecvBuffer {
 using PjRtCrossHostRecvNotifier =
     std::function<void(StatusOr<std::vector<PjRtCrossHostRecvBuffer>>&&)>;
 
+struct CompileOptions {
+  // The layouts of the arguments that the computation should expect.
+  absl::optional<std::vector<Shape>> argument_layouts;
+
+  // If true, the supplied computation expects its arguments to be wrapped in a
+  // tuple and passed as a single parameter.
+  bool parameter_is_tupled_arguments = false;
+
+  // XLA's compilation time options.
+  ExecutableBuildOptions executable_build_options;
+
+  // If true, the executable can be run on any device. May only be true if
+  // !executable_build_options.has_device_assignment(), so only applies to
+  // single-device executables. Beware: on GPUs, sometimes an executable
+  // compiled for one device doesn't run on another.
+  bool compile_portable_executable = false;
+};
+
 class PjRtExecutable;
 
 // Encapsulates the state of Python session with XLA.
@@ -198,6 +216,63 @@ class PjRtClient {
   // Returns a backend-specific HLO cost analysis visitor.
   virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis();
 
+  virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+      const XlaComputation& computation, CompileOptions options);
+
+  virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
+      const Shape& shape, PjRtDevice* device);
+
+  // Describes the semantics the caller to BufferFromHostBuffer expects from the
+  // runtime, in a total order from most restrictive to least restrictive.
+  enum class HostBufferSemantics {
+    // The runtime may not hold references to `data` after the call to
+    // `BufferFromHostBuffer` completes. The caller promises that `data` is
+    // immutable and will not be freed only for the duration of the
+    // BufferFromHostBuffer call. `buffer_reference` will be freed by the time
+    // `BufferFromHostBuffer` returns.
+    kImmutableOnlyDuringCall,
+
+    // The runtime may hold onto `data` after the call to `BufferFromHostBuffer`
+    // returns while the runtime completes a transfer to the device. The caller
+    // promises not to mutate or free `data` until the transfer completes, at
+    // which point the runtime will release `buffer_reference`. It is also
+    // correct to wait on the host (directly or indirectly) for the buffer's
+    // definition event to complete.
+    kImmutableUntilTransferCompletes,
+
+    // The PjRtBuffer may alias `data` internally and the runtime may use the
+    // `data` contents as long as the buffer is alive. The caller promises to
+    // keep `data` alive and not to mutate its contents as long as the buffer is
+    // alive; to notify the caller that the buffer may be freed, the runtime
+    // will release its `buffer_reference` when the PjRtBuffer is freed. On
+    // non-CPU platforms this acts identically to
+    // kImmutableUntilTransferCompletes.
+    kZeroCopy,
+  };
+  virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
+      const void* data, const Shape& shape,
+      HostBufferSemantics host_buffer_semantics,
+      std::shared_ptr<void> buffer_reference, PjRtDevice* device);
+
+  // Note that literal must remain in scope until the transfer has completed, so
+  // the caller should, for example, wait for BlockHostUntilReady() completes on
+  // the return value before letting literal go out of scope.
+  virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
+      const LiteralSlice& literal, PjRtDevice* device);
+
+  // Asynchronously makes a vector of PjRtBuffers that can be used to receive
+  // cross host transfers using `client` on `device'. `shapes` must be the exact
+  // shapes, with identical layouts, corresponding to the buffers that will be
+  // sent. When resources for the transfer are available, notifier will be
+  // called with a vector of PjRtCrossHostRecvBuffer structs, one for each
+  // shape in `shapes`. Each struct contains a buffer that will contain the
+  // received value, and an opaque string that should be transmitted to the
+  // sending host and used in a call to CopyToRemoteDevice. None of the recv
+  // buffers will become ready until *all* of the sends have completed.
+  virtual void MakeCrossHostReceiveBuffers(
+      absl::Span<const Shape> shapes, PjRtDevice* device,
+      PjRtCrossHostRecvNotifier&& notifier);
+
  protected:
   friend class PjRtBuffer;
   virtual void EnqueueCrossHostReceive(
@@ -385,6 +460,7 @@ class PjRtBuffer {
 
    private:
     friend class PjRtBuffer;
+    friend class PjRtClient;
 
     // Helper struct that makes it possible to move a ScopedHold through a
     // closure.
@@ -423,62 +499,6 @@ class PjRtBuffer {
     StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
   };
 
-  // Returns a buffer with uninitialized contents.
-  static StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitialized(
-      const Shape& shape, PjRtClient* client, PjRtDevice* device);
-
-  // Describes the semantics the caller to FromHostBuffer expects from the
-  // runtime, in a total order from most restrictive to least restrictive.
-  enum class HostBufferSemantics {
-    // The runtime may not hold references to `data` after the call to
-    // `FromHostBuffer` completes. The caller promises that `data` is immutable
-    // and will not be freed only for the duration of the FromHostBuffer call.
-    // `buffer_reference` will be freed by the time `FromHostBuffer` returns.
-    kImmutableOnlyDuringCall,
-
-    // The runtime may hold onto `data` after the call to `FromHostBuffer`
-    // returns while the runtime completes a transfer to the device. The caller
-    // promises not to mutate or free `data` until the transfer completes, at
-    // which point the runtime will release `buffer_reference`. It is also
-    // correct to wait on the host (directly or indirectly) for the buffer's
-    // definition event to complete.
-    kImmutableUntilTransferCompletes,
-
-    // The PjRtBuffer may alias `data` internally and the runtime may use the
-    // `data` contents as long as the buffer is alive.
-    // The caller promises to keep `data` alive and not to mutate its contents
-    // as long as the buffer is alive; to notify the caller that the buffer may
-    // be freed, the runtime will release its `buffer_reference` when the
-    // PjRtBuffer is freed. On non-CPU platforms this acts identically to
-    // kImmutableUntilTransferCompletes.
-    kZeroCopy,
-  };
-  static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostBuffer(
-      const void* data, const Shape& shape,
-      HostBufferSemantics host_buffer_semantics,
-      std::shared_ptr<void> buffer_reference, PjRtClient* client,
-      PjRtDevice* device);
-
-  // Note that literal must remain in scope until the transfer has completed, so
-  // the caller should, for example, wait for BlockHostUntilReady() completes on
-  // the return value before letting literal go out of scope.
-  static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostLiteral(
-      const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device);
-
-  // Asynchronously makes a vector of PjRtBuffers that can be used to receive
-  // cross host transfers using `client` on `device'. `shapes` must be the exact
-  // shapes, with identical layouts, corresponding to the buffers that will be
-  // sent. When resources for the transfer are available, notifier will be
-  // called with a vector of PjRtCrossHostRecvBuffer structs, one for each
-  // shape in `shapes`. Each struct contains a buffer that will contain the
-  // received value, and an opaque string that should be transmitted to the
-  // sending host and used in a call to CopyToRemoteDevice. None of the recv
-  // buffers will become ready until *all* of the sends have completed.
-  static void MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
-                                          PjRtClient* client,
-                                          PjRtDevice* device,
-                                          PjRtCrossHostRecvNotifier&& notifier);
-
   PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
              std::shared_ptr<TrackedDeviceBuffer> device_buffer,
              PjRtClient* client, PjRtDevice* device);
@@ -661,24 +681,6 @@ class PjRtBuffer {
   Semaphore donation_semaphore_;
 };
 
-struct CompileOptions {
-  // The layouts of the arguments that the computation should expect.
-  absl::optional<std::vector<Shape>> argument_layouts;
-
-  // If true, the supplied computation expects its arguments to be wrapped in a
-  // tuple and passed as a single parameter.
-  bool parameter_is_tupled_arguments = false;
-
-  // XLA's compilation time options.
-  ExecutableBuildOptions executable_build_options;
-
-  // If true, the executable can be run on any device. May only be true if
-  // !executable_build_options.has_device_assignment(), so only applies to
-  // single-device executables. Beware: on GPUs, sometimes an executable
-  // compiled for one device doesn't run on another.
-  bool compile_portable_executable = false;
-};
-
 class ExecuteContext {
  public:
   virtual ~ExecuteContext() = default;
@@ -710,10 +712,6 @@ struct ExecuteOptions {
 // buffer will be donated when passed to the execution.
 class PjRtExecutable {
  public:
-  static StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
-      const XlaComputation& computation, PjRtClient* client,
-      CompileOptions options);
-
   PjRtExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,
                  bool parameter_is_tupled_arguments,
                  std::shared_ptr<DeviceAssignment> device_assignment,
@@ -783,6 +781,7 @@ class PjRtExecutable {
   }
 
  private:
+  friend class PjRtClient;
   // Initializes information about which arguments to which executables must be
   // donated due to aliases that were specified by the computation.
   Status SetUpDonation(PjRtClient* client, bool tuple_inputs);
diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc
index 944b4c20a8a..f4202045a66 100644
--- a/tensorflow/compiler/xla/python/jax_jit.cc
+++ b/tensorflow/compiler/xla/python/jax_jit.cc
@@ -465,10 +465,10 @@ std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
     xla::PjRtDevice* device) {
   CppType data = py::cast<Pybind11Type>(scalar);
   xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<CppType>({});
-  return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
+  return ValueOrThrow(client->BufferFromHostBuffer(
       &data, shape,
-      xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
-      client, device));
+      xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
+      device));
 }
 
 // Convert a scalar to the associated PjRtBuffer or raises an error if it is
@@ -502,17 +502,17 @@ StatusOr<std::unique_ptr<xla::PjRtBuffer>> ScalarToBuffer(
     if (jax_enable_x64) {
       xla::complex128 data(result.real, result.imag);
       xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex128>({});
-      return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
+      return ValueOrThrow(client->BufferFromHostBuffer(
          &data, shape,
-          xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall,
-          nullptr, client, device));
+          xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
+          nullptr, device));
     } else {
       xla::complex64 data(result.real, result.imag);
       xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex64>({});
-      return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
+      return ValueOrThrow(client->BufferFromHostBuffer(
          &data, shape,
-          xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall,
-          nullptr, client, device));
+          xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
+          nullptr, device));
     }
   }
   return InvalidArgument(
@@ -678,7 +678,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
           ValueOrThrow(pyclient.BufferFromPyval(
               numpy_array, data_device,
              /*force_copy=*/false, /*host_buffer_semantics=*/
-              xla::PjRtBuffer::HostBufferSemantics::kZeroCopy));
+              xla::PjRtClient::HostBufferSemantics::kZeroCopy));
       arg_buffers.push_back(buffer->buffer());
 
       ArgSignature sig;
diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.cc b/tensorflow/compiler/xla/python/outfeed_receiver.cc
index f6067e650c0..2535d62ee7e 100644
--- a/tensorflow/compiler/xla/python/outfeed_receiver.cc
+++ b/tensorflow/compiler/xla/python/outfeed_receiver.cc
@@ -409,10 +409,9 @@ Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
   compile_options.executable_build_options.set_device_assignment(
       device_assignment);
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<PjRtExecutable> executable,
-      PjRtExecutable::Compile(computation, devices_[device_idx]->client(),
-                              std::move(compile_options)));
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
+                      devices_[device_idx]->client()->Compile(
+                          computation, std::move(compile_options)));
 
   ExecuteOptions execute_options;
   TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
                       executable->Execute({}, execute_options));
diff --git a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc
index 919dafe2e0b..5422a4b3056 100644
--- a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc
+++ b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc
@@ -40,9 +40,8 @@ Status CompileAndExecute(XlaBuilder* builder, XlaOp root, int device_id,
   compile_options.executable_build_options.set_device_assignment(
       device_assignment);
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<PjRtExecutable> executable,
-      PjRtExecutable::Compile(computation, client, std::move(compile_options)));
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
+                      client->Compile(computation, std::move(compile_options)));
 
   ExecuteOptions execute_options;
   TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
                       executable->Execute({}, execute_options));
diff --git a/tensorflow/compiler/xla/python/py_client.cc b/tensorflow/compiler/xla/python/py_client.cc
index 07b915c640c..d42bbdca154 100644
--- a/tensorflow/compiler/xla/python/py_client.cc
+++ b/tensorflow/compiler/xla/python/py_client.cc
@@ -89,7 +89,7 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
 
 StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
     const pybind11::object& argument, PjRtDevice* device, bool force_copy,
-    PjRtBuffer::HostBufferSemantics host_buffer_semantics) {
+    PjRtClient::HostBufferSemantics host_buffer_semantics) {
   if (device == nullptr) {
     TF_RET_CHECK(!pjrt_client_->local_devices().empty());
     device = pjrt_client_->local_devices().front();
@@ -114,10 +114,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
   std::unique_ptr<PjRtBuffer> buffer;
   {
     py::gil_scoped_release gil_release;
-    TF_ASSIGN_OR_RETURN(
-        buffer, PjRtBuffer::FromHostBuffer(
-                    c->buf_ptr, c->shape, host_buffer_semantics,
-                    std::move(py_buffer_ref), pjrt_client_.get(), device));
+    TF_ASSIGN_OR_RETURN(buffer, pjrt_client_->BufferFromHostBuffer(
+                                    c->buf_ptr, c->shape, host_buffer_semantics,
+                                    std::move(py_buffer_ref), device));
   }
   auto traceback = Traceback::Get();
   return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
@@ -131,8 +130,7 @@ StatusOr<std::shared_ptr<PyExecutable>> PyClient::Compile(
   {
     py::gil_scoped_release gil_release;
     TF_ASSIGN_OR_RETURN(executable,
-                        PjRtExecutable::Compile(computation, pjrt_client_.get(),
-                                                std::move(options)));
+                        pjrt_client_->Compile(computation, std::move(options)));
     TF_ASSIGN_OR_RETURN(fingerprint,
                         pjrt_client_->ExecutableFingerprint(*executable));
   }
diff --git a/tensorflow/compiler/xla/python/py_client.h b/tensorflow/compiler/xla/python/py_client.h
index 08249722d6c..224f8278bb1 100644
--- a/tensorflow/compiler/xla/python/py_client.h
+++ b/tensorflow/compiler/xla/python/py_client.h
@@ -123,7 +123,7 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
 
   StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
       const pybind11::object& argument, PjRtDevice* device, bool force_copy,
-      PjRtBuffer::HostBufferSemantics host_buffer_semantics);
+      PjRtClient::HostBufferSemantics host_buffer_semantics);
 
   StatusOr<std::shared_ptr<PyExecutable>> Compile(
       const XlaComputation& computation, CompileOptions options);
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index b84dfa92e47..b0948fab2b7 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -535,12 +535,12 @@ PYBIND11_MODULE(xla_extension, m) {
       .value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
       .value("BFC", GpuAllocatorConfig::Kind::kBFC);
 
-  py::enum_<PjRtBuffer::HostBufferSemantics>(m, "HostBufferSemantics")
+  py::enum_<PjRtClient::HostBufferSemantics>(m, "HostBufferSemantics")
       .value("IMMUTABLE_ONLY_DURING_CALL",
-             PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall)
+             PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall)
       .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
-             PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes)
-      .value("ZERO_COPY", PjRtBuffer::HostBufferSemantics::kZeroCopy);
+             PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes)
+      .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kZeroCopy);
 
   py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
   py_local_client.def_property_readonly("platform", &PyClient::platform_name)
@@ -562,7 +562,7 @@ PYBIND11_MODULE(xla_extension, m) {
       .def("buffer_from_pyval", &PyClient::BufferFromPyval, py::arg("argument"),
           py::arg("device") = nullptr, py::arg("force_copy") = false,
           py::arg("host_buffer_semantics") =
-               PjRtBuffer::HostBufferSemantics::kZeroCopy)
+               PjRtClient::HostBufferSemantics::kZeroCopy)
       .def("compile", &PyClient::Compile, py::arg("computation"),
           py::arg("compile_options") = CompileOptions())
       .def("heap_profile", &PyClient::HeapProfile);
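
Usage note (illustrative, not part of the patch): the sketch below shows the call pattern after this change, with buffer creation and compilation going through PjRtClient instead of the old static helpers on PjRtBuffer and PjRtExecutable. It assumes the CPU client factory GetCpuClient() from tensorflow/compiler/xla/pjrt/cpu_device.h; the file name, RunExample, and the exact set of includes are hypothetical and may need adjusting to your build target.

// usage_sketch.cc -- hypothetical example; mirrors the new PjRtClient API.
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

Status RunExample() {
  // All factory entry points now hang off PjRtClient.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
                      GetCpuClient(/*asynchronous=*/true));
  PjRtDevice* device = client->local_devices().front();

  // client->BufferFromHostBuffer replaces PjRtBuffer::FromHostBuffer. With
  // kImmutableUntilTransferCompletes the host data must stay valid until the
  // transfer finishes.
  std::vector<float> host_data = {1.0f, 2.0f, 3.0f, 4.0f};
  Shape shape = ShapeUtil::MakeShape(F32, {4});
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<PjRtBuffer> buffer,
      client->BufferFromHostBuffer(
          host_data.data(), shape,
          PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
          /*buffer_reference=*/nullptr, device));

  // client->Compile replaces PjRtExecutable::Compile.
  XlaBuilder builder("add_one");
  XlaOp param = Parameter(&builder, 0, shape, "x");
  Add(param, ConstantR0<float>(&builder, 1.0f));  // scalar broadcasts
  TF_ASSIGN_OR_RETURN(XlaComputation computation, builder.Build());
  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
                      client->Compile(computation, CompileOptions()));

  ExecuteOptions execute_options;
  TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> outputs,
                      executable->Execute({buffer.get()}, execute_options));
  TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> result,
                      outputs[0]->ToLiteral());
  return Status::OK();
}

}  // namespace xla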