diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index 02fc84c8b49..0c3e7461743 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -119,12 +119,39 @@ cc_library( cc_library( name = "pjrt_client", - srcs = ["pjrt_client.cc"], hdrs = ["pjrt_client.h"], visibility = ["//tensorflow/compiler/xla:friends"], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/pjrt/distributed:protocol_proto_cc", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_cost_analysis", + "//tensorflow/core:lib", + "@com_google_absl//absl/base", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "pjrt_stream_executor_client", + srcs = ["pjrt_stream_executor_client.cc"], + hdrs = ["pjrt_stream_executor_client.h"], + visibility = ["//tensorflow/compiler/xla:friends"], deps = [ ":event_pool", ":local_device_state", + ":pjrt_client", ":tracked_device_buffer", "//tensorflow/compiler/xla:cpu_function_runtime", "//tensorflow/compiler/xla:executable_run_options", @@ -181,7 +208,7 @@ cc_library( ], deps = [ ":local_device_state", - ":pjrt_client", + ":pjrt_stream_executor_client", ":tracked_device_buffer", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -215,7 +242,7 @@ cc_library( srcs = ["interpreter_device.cc"], hdrs = ["interpreter_device.h"], deps = [ - ":pjrt_client", + ":pjrt_stream_executor_client", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/service:interpreter_plugin", @@ -229,7 +256,7 @@ cc_library( srcs = ["cpu_device.cc"], hdrs = ["cpu_device.h"], deps = [ - ":pjrt_client", + ":pjrt_stream_executor_client", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/service:platform_util", @@ -242,7 +269,7 @@ cc_library( srcs = ["gpu_device.cc"], hdrs = ["gpu_device.h"], deps = [ - ":pjrt_client", + ":pjrt_stream_executor_client", "@com_google_absl//absl/container:flat_hash_map", "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", "//tensorflow/compiler/xla:statusor", @@ -279,6 +306,7 @@ tf_cc_test( deps = [ ":gpu_device", ":pjrt_client", + ":pjrt_stream_executor_client", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:xla_builder", diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc index 5241efbd2de..72da2d2b0dd 100644 --- a/tensorflow/compiler/xla/pjrt/cpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc @@ -17,7 +17,7 @@ limitations under the License. 
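The BUILD changes above split the monolithic pjrt_client target into an interface target and a StreamExecutor implementation target. As a hedged illustration (hypothetical file names, not part of this change), downstream code now picks its dependency by what it actually uses:

```c++
// my_backend.cc (hypothetical): constructs the StreamExecutor-based
// implementation, so it depends on ":pjrt_stream_executor_client".
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"

// my_runtime.cc (hypothetical): programs only against the abstract
// PjRtClient/PjRtBuffer interfaces, so the lighter ":pjrt_client"
// dependency suffices.
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
```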
#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h" #include "tensorflow/compiler/xla/service/platform_util.h" namespace xla { diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.h b/tensorflow/compiler/xla/pjrt/cpu_device.h index 0aab55e6493..e0106fdd179 100644 --- a/tensorflow/compiler/xla/pjrt/cpu_device.h +++ b/tensorflow/compiler/xla/pjrt/cpu_device.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { diff --git a/tensorflow/compiler/xla/pjrt/gpu_device.cc b/tensorflow/compiler/xla/pjrt/gpu_device.cc index 302c7734d73..8c860d56863 100644 --- a/tensorflow/compiler/xla/pjrt/gpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/gpu_device.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/gpu_device.h" #include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h" #ifdef NCCL_ENABLED #include "third_party/nccl/nccl.h" diff --git a/tensorflow/compiler/xla/pjrt/gpu_device.h b/tensorflow/compiler/xla/pjrt/gpu_device.h index 142a263d959..3e11c31d51d 100644 --- a/tensorflow/compiler/xla/pjrt/gpu_device.h +++ b/tensorflow/compiler/xla/pjrt/gpu_device.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/pjrt/distributed/client.h" -#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/bfc_allocator.h" diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.cc b/tensorflow/compiler/xla/pjrt/interpreter_device.cc index 3b3daba5906..818740ca105 100644 --- a/tensorflow/compiler/xla/pjrt/interpreter_device.cc +++ b/tensorflow/compiler/xla/pjrt/interpreter_device.cc @@ -17,7 +17,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h" #include "tensorflow/compiler/xla/service/platform_util.h" namespace xla { diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.h b/tensorflow/compiler/xla/pjrt/interpreter_device.h index a23ddcb5bb9..4a4477a7b86 100644 --- a/tensorflow/compiler/xla/pjrt/interpreter_device.h +++ b/tensorflow/compiler/xla/pjrt/interpreter_device.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index 9a5f8665d5f..9545dbdb031 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -20,30 +20,19 @@ limitations under the License. 
#include #include -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "absl/synchronization/notification.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" -#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout.h" -#include "tensorflow/compiler/xla/pjrt/local_device_state.h" -#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" -#include "tensorflow/compiler/xla/service/computation_placer.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/fingerprint.h" @@ -106,78 +95,6 @@ class PjRtDevice { virtual StatusOr TransferFromOutfeed(const Shape& shape) const = 0; }; -class PjRtStreamExecutorDevice : public PjRtDevice { - public: - explicit PjRtStreamExecutorDevice( - int id, std::unique_ptr local_device_state, - std::string device_kind, int host_id = 0) - : id_(id), - device_ordinal_( - local_device_state ? local_device_state->device_ordinal() : -1), - local_device_state_(std::move(local_device_state)), - host_id_(host_id), - device_kind_(std::move(device_kind)) {} - ~PjRtStreamExecutorDevice() override {} - - // Must set client exactly once. - void SetClient(PjRtClient* client) { - CHECK(client_ == nullptr); - client_ = client; - } - - // Task ID. This is always 0 on single-task setup. - int host_id() const override { return host_id_; } - - // Return `platform_id` from client. - PjRtPlatformId platform_id() const; - - // Return `platform_name` from client. - const std::string& platform_name() const; - - PjRtClient* client() const override { return client_; } - - // The ID of this device. IDs are unique among devices of this type - // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all - // hosts' devices. This is the ID that should be used in a DeviceAssignment. - int id() const override { return id_; } - - bool IsAddressable() const override { return device_ordinal_ != -1; } - - int local_hardware_id() const override { return device_ordinal_; } - - // If this is a device local to this host, returns a LocalDeviceState object - // that can be used to manipulate the device. Returns nullptr if the device is - // not local to this host. - LocalDeviceState* local_device_state() const { - return local_device_state_.get(); - } - - // If this is a device local to this host, returns a LocalDeviceState object - // that can be used to manipulate the device. Returns an error if the device - // is not local to this host. - StatusOr GetLocalDeviceState() const; - - // A vendor-dependent string that uniquely identifies the kind of device. 
- const std::string& device_kind() const override { return device_kind_; } - - std::string DebugString() const override; - - // Transfer the given literal to the infeed queue of the given localdevice. - Status TransferToInfeed(const LiteralSlice& literal) const override; - - // Transfer and return a value of the given shape from the outfeed of the - // given device. - StatusOr TransferFromOutfeed(const Shape& shape) const override; - - private: - const int id_; - const int device_ordinal_; // -1 means not local. - const std::unique_ptr local_device_state_; - const int host_id_; - const std::string device_kind_; - PjRtClient* client_ = nullptr; -}; - // Forward declaration. class PjRtBuffer; @@ -333,181 +250,6 @@ class PjRtClient { virtual StatusOr CreateHostToDeviceChannelHandle() = 0; }; -class PjRtStreamExecutorClient : public PjRtClient { - public: - // `allocator` may null, in which case the platform default allocator is used. - explicit PjRtStreamExecutorClient( - std::string platform_name, LocalClient* client, - std::vector> devices, - int host_id, std::unique_ptr allocator, - std::unique_ptr host_memory_allocator, - bool should_stage_host_to_device_transfers, - std::unique_ptr gpu_run_options); - ~PjRtStreamExecutorClient() override = default; - - int host_id() const override { return host_id_; } - - int device_count() const override { return devices_.size(); } - int addressable_device_count() const override { - return local_devices_.size(); - } - absl::Span devices() const override { return devices_; } - absl::Span local_devices() const override { - return local_devices_; - } - - StatusOr LookupDevice(int device_id) const override { - auto it = id_to_device_.find(device_id); - if (it != id_to_device_.end()) { - return it->second; - } - return InvalidArgument("No matching device found for device_id %d", - device_id); - } - - StatusOr LookupAddressableDevice( - int local_hardware_id) const override; - - PjRtPlatformId platform_id() const override { return platform_id_; } - const std::string& platform_name() const override { return platform_name_; } - - // Most platforms expect device-to-device transfers to be enqueued on the - // source d2d stream, but some platforms use the destination d2d stream. This - // function specifies which one the platform expects. - virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; } - - StatusOr GetDefaultDeviceAssignment( - int num_replicas, int num_partitions) const override; - - StatusOr> Compile( - const XlaComputation& computation, CompileOptions options) override; - - // Generates a unique fingerprint for `executable`. - StatusOr> ExecutableFingerprint( - const PjRtExecutable& executable) const override { - return absl::optional(); - } - - // Returns a backend-specific HLO cost analysis visitor. - std::unique_ptr GetHloCostAnalysis() override; - - // Creates a buffer on the device without initializing or copying any data. - // An optional `definition_event` may be speficied that can be used to - // ensure the buffer isn't referenced until some external mechanism has - // initialized the data. - // NOTE: The sequencing mechanism is not guaranteed to be supported by all - // future backends and so callers should avoid wherever possible. 
- StatusOr> CreateUninitializedBuffer( - const Shape& shape, PjRtDevice* device) override; - StatusOr> CreateUninitializedBuffer( - const Shape& shape, PjRtDevice* device, - std::shared_ptr definition_event); - - StatusOr> BufferFromHostBuffer( - const void* data, const Shape& shape, - HostBufferSemantics host_buffer_semantics, - std::shared_ptr buffer_reference, PjRtDevice* device) override; - - // Note that literal must remain in scope until the transfer has completed, so - // the caller should, for example, wait for BlockHostUntilReady() completes on - // the return value before letting literal go out of scope. - StatusOr> BufferFromHostLiteral( - const LiteralSlice& literal, PjRtDevice* device) override; - - // Asynchronously makes a vector of PjRtBuffers that can be used to receive - // cross host transfers using `client` on `device'. `shapes` must be the exact - // shapes, with identical layouts, corresponding to the buffers that will be - // sent. When resources for the transfer are available, notifier will be - // called with a vector of PjRtCrossHostRecvBuffer structs, one for each - // shape in `shapes`. Each struct contains a buffer that will contain the - // received value, and an opaque string that should be transmitted to the - // sending host and used in a call to CopyToRemoteDevice. None of the recv - // buffers will become ready until *all* of the sends have completed. - void MakeCrossHostReceiveBuffers( - absl::Span shapes, PjRtDevice* device, - PjRtCrossHostRecvNotifier&& notifier) override; - - StatusOr CreateChannelHandle() override { - return client()->CreateChannelHandle(); - } - StatusOr CreateDeviceToHostChannelHandle() override { - return client()->CreateDeviceToHostChannelHandle(); - } - StatusOr CreateHostToDeviceChannelHandle() override { - return client()->CreateHostToDeviceChannelHandle(); - } - - LocalDeviceState& device_state(int device_ordinal) const { - return *tensorflow::down_cast( - local_devices_.at(device_ordinal)) - ->local_device_state(); - } - LocalClient* client() const { return client_; } - se::DeviceMemoryAllocator* allocator() const { return allocator_; } - tensorflow::Allocator* host_memory_allocator() const { - return host_memory_allocator_.get(); - } - bool should_stage_host_to_device_transfers() const { - return should_stage_host_to_device_transfers_; - } - - gpu::GpuExecutableRunOptions* gpu_run_options() const { - return gpu_run_options_.get(); - } - - tensorflow::thread::ThreadPool* h2d_transfer_pool() { - return &h2d_transfer_pool_; - } - - protected: - friend class PjRtStreamExecutorBuffer; - virtual void EnqueueCrossHostReceive( - std::vector>&& buffers, - std::shared_ptr definition_event, - PjRtCrossHostRecvNotifier&& notifier) const { - notifier(Unimplemented("Cross host receives not implemented.")); - } - - virtual Status CopyToRemoteDevice( - PjRtBuffer* buffer, absl::string_view serialized_descriptor) const { - return Unimplemented("Cross host sends not implemented."); - } - - const PjRtPlatformId platform_id_; - const std::string platform_name_; - LocalClient* client_; - - // Allocator to be used for staging memory transfers to devices. - std::unique_ptr host_memory_allocator_; - - // Includes all devices, including non-local devices on multi-host platforms. - std::vector> owned_devices_; - // Pointers to `owned_devices_`. - std::vector devices_; - // Maps Device::id() to the corresponding Device. Includes all devices. - std::map id_to_device_; - // Local devices indexed by local device ordinal. 
- std::vector local_devices_; - int host_id_; - - se::DeviceMemoryAllocator* allocator_; - std::unique_ptr owned_allocator_; - - // Should we always prefer to stage host-to-device transfers via memory - // allocated on host_memory_allocator_? True only on GPU, where we prefer to - // transfer via pinned memory. - bool should_stage_host_to_device_transfers_; - - std::unique_ptr gpu_run_options_; - - tensorflow::thread::ThreadPool h2d_transfer_pool_; -}; - -// Converts a 2D set of Device objects indexed by [replica][partition] into an -// xla::DeviceAssignment. -StatusOr DevicesToDeviceAssignment( - absl::Span> devices); - // Holds a reference from Python to a tuple of device buffers. A PjRtBuffer // can be either valid or invalid. An invalid buffer is one that has never been // initialized, or a buffer that has been deleted (e.g., by calling Delete, or @@ -625,393 +367,6 @@ class PjRtBuffer { virtual bool IsOnCpu() const = 0; }; -class PjRtStreamExecutorBuffer : public PjRtBuffer { - public: - // Helper class to retain a "hold" on a PjRtBuffer. A ScopedHold may not - // outlive its parent PjRtBuffer. - // - // There are three types of hold, as follows: - // - // 1) Usage hold: a transient hold while an operation using the buffer is - // being enqueued onto a stream. - // A client acquires a usage hold by calling - // PjRtBuffer::GetBufferWithHold(kUsage) or the convenience wrapper - // GetBufferWithUsageHold(). If the enqueue completes successfully the hold - // should be released using a call to ConvertUsageHold. If the ScopedHold is - // deleted without ConvertUsageHold being called, e.g., on error, the hold is - // dropped. It is legal to drop a usage hold instead of calling - // ConvertUsageHold, even if the buffer was successfully enqueued, as long as - // the client ensures that all necessary synchronization has been done. - // - // 2) External hold: a potentially long-lived hold while the buffer is being - // shared by an external framework, e.g., NumPy. - // A client acquires an external hold by calling - // PjRtBuffer::GetBufferWithHold(kExternal) or the convenience wrapper - // GetBufferWithExternalReference and releases it by deleting the ScopedHold. - // The external framework should not modify the underlying buffer unless it is - // confident via its own synchronization that modifications do not race with - // reads from the PjRtBuffer. - // - // 3) Donation hold: a transient hold while an execution that donates the - // buffer is being enqueued onto the compute stream. - // A client acquires a donation hold by calling - // PjRtBuffer::GetBufferWithHold(kDonation). If the enqueue completes - // successfully the hold should be released using a call to ConfirmDonation - // after which the buffer is invalid. If the ScopedHold is deleted without - // ConfirmDonation being called, e.g., on error, the hold is dropped and the - // buffer remains valid. If the buffer is successfully enqueued the client - // *must* call ConfirmDonation. - // - // Donation holds behave like exclusive write locks: when a donation hold - // has been acquired, any attempt to acquire another hold of any type will - // block until the donation hold is dropped or confirmed. Acquiring a donation - // hold will fail with an error if there is any outstanding external hold, and - // will block if there are any outstanding usage holds until those holds are - // dropped or converted. 
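To make the hold protocol described above concrete, here is a minimal usage-hold sketch (an illustration only: the enqueued operation is elided, and `stream` and `event` are assumed to come from the caller's LocalDeviceState plumbing):

```c++
// Acquire a transient usage hold, enqueue work that reads the device
// buffer, then convert the hold into a usage event recorded on the stream.
PjRtStreamExecutorBuffer::ScopedHold hold = buffer->GetBufferWithUsageHold();
if (!hold.ok()) return hold.status();
// ... enqueue an operation that reads hold->device_memory() on `stream` ...
hold.ConvertUsageHold(stream, event, /*reference_held=*/true);
// On an enqueue failure we would instead let `hold` go out of scope,
// which drops the hold without recording a usage event.
```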
- // - // Calls to PjRtBuffer::Release (and transitively to - // PjRtBuffer::Delete() and ~PjRtBuffer()) will block until all usage - // and donation holds are either deleted or converted/confirmed. - class ScopedHold { - public: - enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue }; - // Use a State enum instead of encoding the state in an error Status to - // avoid creating Status values in non-error cases. Creating a Status - // entails several allocations and can add O(us) to every use of a hold. - enum State { - kUninitialized = 0, - kValid, - kMoved, - kConverted, - kReleased, - kDonated, - kError - }; - - ~ScopedHold(); - ScopedHold(ScopedHold&& other); - ScopedHold(const ScopedHold&) = delete; - ScopedHold& operator=(const ScopedHold&) = delete; - - Type type() const { return type_; } - - Status status() const { - // Lazily create Status values only when they are requested. - switch (state_) { - case kUninitialized: - return InvalidArgument("Buffer has not been initialized"); - case kValid: - return Status::OK(); - case kMoved: - return InvalidArgument("Buffer has been moved."); - case kConverted: - return InvalidArgument("Buffer has been converted"); - case kReleased: - return InvalidArgument("Buffer has been released"); - case kDonated: - return InvalidArgument("Buffer has been donated"); - case kError: - return buffer_or_.status(); - default: - CHECK(false) << "Unexpected state value " << state_; - } - } - bool ok() const { return state_ == kValid; } - - // Access to the underlying device buffer storage. Requires this->ok(). - const std::shared_ptr& buffer() const { - CHECK_EQ(state_, kValid); - CHECK_NE(buffer_or_.ValueOrDie(), nullptr); - return buffer_or_.ValueOrDie(); - } - TrackedDeviceBuffer* operator->() const { return buffer().get(); } - const TrackedDeviceBuffer& operator*() const { return *buffer(); } - - // Converts the hold into a usage event. Only valid for holds of type - // kUsage. - // - // usage_stream: the stream that the buffer was used on. - // event: an event that has been recorded on usage_stream after - // the buffer was used. - // reference_held: true if and only if the caller has caused a - // reference to this->buffer() to stay live until after - // the host is sure that the usage (transfer or execution) - // has completed. - void ConvertUsageHold(se::Stream* usage_stream, - std::shared_ptr event, - bool reference_held); - - // Confirms that the buffer was successfully donated to an execution. - // Only valid for holds of type kDonation. Causes the buffer to become - // invalid. - void ConfirmDonation(); - - // Adds the held device buffers in order to 'iterator'. Used to add the - // buffers to an ExecutionInput. We require but do not verify that - // 'iterator' when passed in is pointing to a sub-tuple of the - // ExecutionInput whose on_device_shape matches that of the - // TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run - // out of bounds. Donates the device buffers if the hold type is kDonation, - // otherwise retains ownership of the device buffers. - void AddToInput(ShapeTree::iterator* iterator, - const ShapeTree::iterator& end, - ExecutionInput* execution_input, - se::DeviceMemoryAllocator* allocator) const; - - private: - friend class PjRtStreamExecutorBuffer; - friend class PjRtStreamExecutorClient; - - // Helper struct that makes it possible to move a ScopedHold through a - // closure. 
- using ForClosure = - std::tuple>>; - - ScopedHold(PjRtStreamExecutorBuffer* parent, Type type) - : parent_(parent), type_(type), state_(kUninitialized) {} - explicit ScopedHold(const ForClosure& closure_helper) - : parent_(std::get<0>(closure_helper)), - type_(std::get<1>(closure_helper)), - state_(std::get<2>(closure_helper)), - buffer_or_(std::get<3>(closure_helper)) { - // Check the buffer is not in an error state. - CHECK(buffer_or_.ValueOrDie() != nullptr); - } - - // Sets buffer state. - void SetState(State state) { state_ = state; } - - // Sets buffer_or_. Called by parent_ to initialize the hold. - void Acquire(StatusOr>&& buffer_or); - // Releases the contents of *this, so *this can subsequently be - // deleted without releasing the parent's hold. Should be passed to the - // appropriate constructor of another ScopedHold, e.g., when a hold must be - // passed through a closure that is incompatible with std::move. - ForClosure ToClosure(); - - PjRtStreamExecutorBuffer* const parent_; - const Type type_; - - // There is an invariant that if ok() then - // buffer_or_.ValueOrDie() != nullptr. - State state_; - StatusOr> buffer_or_; - }; - - PjRtStreamExecutorBuffer(Shape on_host_shape, Shape on_device_shape, - std::shared_ptr device_buffer, - PjRtClient* client, PjRtDevice* device); - ~PjRtStreamExecutorBuffer() override; - - PjRtStreamExecutorBuffer(const PjRtStreamExecutorBuffer&) = delete; - PjRtStreamExecutorBuffer(PjRtStreamExecutorBuffer&&) = delete; - PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete; - PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete; - - const Shape& on_host_shape() const override { return on_host_shape_; } - const Shape& on_device_shape() const override { return on_device_shape_; } - PjRtStreamExecutorDevice* device() const override { return device_; } - PjRtPlatformId platform_id() const { return client_->platform_id(); } - const std::string& platform_name() const { return client_->platform_name(); } - PjRtStreamExecutorClient* client() const override { return client_; } - bool IsEmptyTuple() const { - return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0; - } - - // Returns the size of the on-device representation of this buffer in bytes. - int64 OnDeviceSizeInBytes() const override; - - // Implement PjRtBuffer::ExternalReferenceHold a wrapped - // ScopedHold::kExternalReference. - class ScopedHoldAsExternalReference - : public PjRtBuffer::ExternalReferenceHold { - public: - explicit ScopedHoldAsExternalReference(ScopedHold hold) - : external_reference_(std::move(hold)) { - CHECK(hold.type() == ScopedHold::kExternalReference); - } - - ~ScopedHoldAsExternalReference() override = default; - - void* OpaqueDeviceMemoryDataPointer() const override { - return external_reference_->device_memory().front().opaque(); - } - - private: - ScopedHold external_reference_; - }; - StatusOr> AcquireExternalReference() - override; - - StatusOr>> ReleaseDeviceMemoryOwnership( - bool wait_for_operations_to_complete) override; - - // Returns the buffer's value as an XLA Literal. If the value has previously - // been prefetched to the host, then returns the prefetched version, otherwise - // copies the buffer to the host. Blocks until the value is ready. If - // `discard_cached_copy` is true then buffer will no longer keep hold of a - // cached copy of the literal (i.e. The reference to the host value will be - // removed.) If a layout is passed than a literal with this layout will be - // returned. 
- using PjRtBuffer::ToLiteral; - StatusOr> ToLiteral( - bool discard_cached_copy, absl::optional layout) override; - - // Initiates a copy of the buffer to the host. Does not block waiting for - // the transfer to complete. The value can be retrieved by a later call to - // ToLiteral(). If a layout is passed then a cached copy with this layout will - // be created. - using PjRtBuffer::CopyToHostAsync; - Status CopyToHostAsync(absl::optional layout) override; - - // Drops the buffer's reference to its associated device memory, leaving the - // buffer in an invalid state. The memory will be freed lazily when all async - // operations using the buffer have completed, according to the allocation - // semantics of the underlying platform. Delete may briefly block if another - // thread is in the process of enqueuing an operation on this buffer, but it - // will never block for a stream operation to complete. If an external - // framework holds a reference to the TrackedDeviceBuffer via - // GetBufferWithExternalReference, the memory will not be freed until the - // external framework drops the reference. - void Delete() override; - - // True if and only if Delete or Release has previously been called. - bool IsDeleted() override; - - // Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The - // PjRtBuffer retains ownership of the device buffers. - StatusOr AsShapedBuffer() const; - - // Returns a hold on the TrackedDeviceBuffer holding the device - // buffers. See comment on ScopedHold. - ScopedHold GetBufferWithHold(ScopedHold::Type type); - ScopedHold GetBufferWithUsageHold() { - return GetBufferWithHold(ScopedHold::kUsage); - } - ScopedHold GetBufferWithExternalReference() { - return GetBufferWithHold(ScopedHold::kExternalReference); - } - - // Copies the buffer to device `dst_device`, performing a d2d transfer when - // `dst_device` is sharing the same Client, and performing a d2h and h2d copy - // if `dst_device` lives on a different Client. - // Returns an error if the buffer is already on dst_device. - StatusOr> CopyToDevice( - PjRtDevice* dst_device) override; - - // Copies the buffer to the remote device encoded in serialized_descriptor. - // This call must be preceded by a call to MakeCrossHostReceiveBuffers on the - // remote host's destination device. MakeCrossHostReceiveBuffers takes an - // array of shapes to construct the destination buffers, and a callback - // supplies an array containing both the destination buffers, and a serialized - // descriptor for each buffer. For each destination buffer there should be a - // matching call to src->CopyToRemoteDevice on a remote host for a src buffer - // of the corresponding shape. serialized_descriptor is the string returned by - // the callback along with the corresponding destination buffer. - Status CopyToRemoteDevice(absl::string_view serialized_descriptor) override; - - // Blocks the host until the buffer's value has been computed and is ready for - // immediate use on the device. Useful in particular for timing benchmarks. - Status BlockHostUntilReady() override; - - // Whether this buffer is on CPU and thus allows for certain optimizations. - bool IsOnCpu() const override; - - // Similar to Delete, drops the buffer's reference to its associated device - // memory, leaving the buffer in an invalid state, but returns the - // TrackedDeviceBuffer rather than freeing the device memory, so that another - // framework can take ownership of it. 
The buffer returned from Release may - // be safely dropped at any time even if it still has pending async - // operations. The client should call BlockHostUntilReady before calling - // Release with wait_for_operations_to_complete=false, to ensure that the host - // has synchronized past any outstanding write operations to the buffer. If - // wait_for_operations_to_complete=true the host will block until any - // potentially outstanding asynchronous operations have completed before - // returning, in which case it is safe to read or mutate the returned buffer. - // If the buffer was shared via an external reference it is the client's - // responsibility that accesses via that reference do not interfere with - // accesses via the buffer returned from Release. - StatusOr> Release( - bool wait_for_operations_to_complete); - - private: - friend class PjRtClient; - // The cached value of the buffer on the host, produced either from a call to - // CopyToHost or from a call to ToLiteral. Once a value has been fetched to - // the host, it persists Delete() is called or the PjRtBuffer is destroyed. - struct HostValue { - absl::Notification ready; - // status and value are valid for reading only after `ready` has been - // notified. - Status status; - std::shared_ptr value; - }; - - // Blocks in mu_.Await until there are no more usage holds. - void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Blocks in mu_.Await until there is no donation hold. - void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Adds a hold of 'type' and returns device_buffer_. Returns an error if - // device_buffer_ is null, or if a donation hold was requested when there is - // an outstanding external hold. - StatusOr> GetBufferForHoldLocked( - ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Adds a hold of hold->type() and initializes `hold` with device_buffer_. - // Initializes hold with an error if device_buffer_ is null, or if a donation - // hold was requested when there is an outstanding external hold. - void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a sanity - // check that buffer==device_buffer_ or device_buffer_==nullptr. Called after - // device_buffer_ was successfully enqueued on a stream. - void ConvertUsageHold(TrackedDeviceBuffer* buffer, se::Stream* usage_stream, - std::shared_ptr event, - bool reference_held); - - // Drops a donation hold and makes *this invalid for further use. Does a - // sanity check that buffer==device_buffer_. Called after device_buffer_ was - // successfully donated to an execution. - void ConfirmDonation(TrackedDeviceBuffer* device_buffer); - - // Initiates a copy of the buffer to the host. Does not block waiting for - // the transfer to complete. A host value is returned and if - // `discard_cached_copy` is false stored in an internal buffer so that future - // transfers don't have to transfer the data from host again. If a layout is - // passed then a literal of this layout will be returned and possibly cached. - StatusOr> CopyToHostAsyncInternal( - bool discard_cached_copy, absl::optional layout); - - // Drops a hold without taking any other action. Does a sanity check that - // buffer==device_buffer_ or device_buffer_==nullptr. 
- void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer); - - StatusOr, - std::shared_ptr>> - CopyToDeviceHelper(PjRtDevice* dst_device, LocalDeviceState* dst_local_device, - LocalDeviceState* transfer_local_device, - se::Stream* transfer_stream, - std::shared_ptr src_device_buffer); - - PjRtStreamExecutorClient* const client_; - const Shape on_host_shape_; - const Shape on_device_shape_; - PjRtStreamExecutorDevice* const device_; - - mutable absl::Mutex mu_; - std::shared_ptr device_buffer_ TF_GUARDED_BY(mu_); - absl::flat_hash_map> host_values_ - TF_GUARDED_BY(mu_); - std::shared_ptr host_value_ TF_GUARDED_BY(mu_); - // Count of holds on the buffer. - std::array holds_ TF_GUARDED_BY(mu_); - // Semaphore used to ensure there is only one outstanding donation hold. - Semaphore donation_semaphore_; -}; - class ExecuteContext { public: virtual ~ExecuteContext() = default; @@ -1103,148 +458,6 @@ class PjRtExecutable { virtual void Delete() = 0; }; -// Wraps one or more XLA LocalExecutables (one per partition, as specified by -// the build options). -class PjRtStreamExecutorExecutable : public PjRtExecutable { - public: - PjRtStreamExecutorExecutable( - std::vector> executables, - bool parameter_is_tupled_arguments, - std::shared_ptr device_assignment, - std::vector addressable_device_logical_ids, - std::vector addressable_devices, - PjRtStreamExecutorClient* client); - - ~PjRtStreamExecutorExecutable() override = default; - - PjRtStreamExecutorClient* client() const override { return client_; } - - const std::string& name() const override; - - int num_replicas() const override { - return executables_[0]->build_options().num_replicas(); - } - - int num_partitions() const override { - return executables_[0]->build_options().num_partitions(); - } - - int64 SizeOfGeneratedCodeInBytes() const override { - int64 size = 0; - for (auto& executable : executables_) { - size += executable->executable()->SizeOfGeneratedCodeInBytes(); - } - return size; - } - - const DeviceAssignment& device_assignment() const override { - return *device_assignment_; - } - - absl::Span addressable_device_logical_ids() - const override { - return addressable_device_logical_ids_; - } - - absl::Span addressable_devices() const override { - return addressable_devices_; - } - - // Return an HloModule per partition. - StatusOr>> GetHloModules() - const override; - - StatusOr>>> Execute( - absl::Span> argument_handles, - const ExecuteOptions& options) const override; - - StatusOr>> ExecuteSharded( - absl::Span argument_handles, PjRtDevice* device, - const ExecuteOptions& options) const override; - - StatusOr>> ExecutePortable( - absl::Span argument_handles, PjRtDevice* device, - const ExecuteOptions& options) const override; - - void Delete() override { executables_.clear(); } - - absl::Span> executables() const { - return executables_; - } - - protected: - bool parameter_is_tupled_arguments() const { - return parameter_is_tupled_arguments_; - } - - private: - friend class PjRtStreamExecutorClient; - // Initializes information about which arguments to which executables must be - // donated due to aliases that were specified by the computation. 
- Status SetUpDonation(bool tuple_inputs); - - virtual bool MustDonateParameter(int executable_idx, int parameter) const; - - virtual StatusOr> - MakeExecutionInputsAndWaitForEvents( - int device_ordinal, const ExecuteOptions& options, - absl::Span argument_handles, - absl::Span device_buffers, - absl::flat_hash_set& events) const; - - StatusOr EnqueueExecution( - absl::Span argument_handles, int replica, - int partition, int executable_idx, const RunId& run_id, - const ExecuteOptions& options, PjRtDevice* device, - std::vector* device_buffers, - std::shared_ptr device_assignment) const; - - virtual std::vector> MakeOutputBuffers( - int device_ordinal, const ExecuteOptions& options, - ScopedShapedBuffer result_buffer, - std::shared_ptr definition_event, - PjRtDevice* device) const; - - StatusOr>> ExecuteHelper( - absl::Span argument_handles, int replica, - int partition, const RunId& run_id, const ExecuteOptions& options, - PjRtDevice* device = nullptr) const; - - // Create shared pointers so we can free them after the execution: with - // asynchronous execution, the process being executed can outlive the - // executable itself. - PjRtStreamExecutorClient* const client_; - // One executable per partition. - std::vector> executables_; - // Per-executable set of parameters that have any aliased buffers and thus - // must be donated when executing the computation. - std::vector> parameters_that_must_be_donated_; - std::shared_ptr device_assignment_; - - // True if the executables were compiled expecting arguments in a single - // tuple. - const bool parameter_is_tupled_arguments_; - - // The replica and partition indices of device_assignment_ to be run by this - // client. On single-host platforms without partitioning, this is all replicas - // (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the - // case on multi-host platforms. If there are 4 replicas and 2 partitions on a - // single host platform, size of addressable_device_logical_ids_ is 4*2 = 8. - std::vector addressable_device_logical_ids_; - - // addressable_devices_[i] is the Device to which - // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of - // unique_ptrs to play well with the Python bindings (see xla.cc). - std::vector addressable_devices_; -}; - -// Executables can donate buffers so that buffers can be aliased from inputs -// to outputs. This function returns the list of parameters that must be -// donated when executable is run. tuple_inputs reflects the option that -// executable was compiled with. -StatusOr> GetParametersThatMustBeDonated( - const HloModule& hlo_module, bool tuple_inputs); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_ diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc similarity index 99% rename from tensorflow/compiler/xla/pjrt/pjrt_client.cc rename to tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc index ae801de2105..bbba9df69c0 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc @@ -62,7 +62,7 @@ limitations under the License. // See the comment on LocalDeviceState::AllocationModel for a discussion of the // different allocation semantics on CPU, GPU, and TPU. 
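For orientation before the implementation file, a hedged sketch of how a platform backend might assemble the client this file implements (GetMyPlatformDevices() is a hypothetical helper; argument roles follow the PjRtStreamExecutorClient constructor comment):

```c++
// Sketch only; not part of this change.
auto client = std::make_unique<xla::PjRtStreamExecutorClient>(
    /*platform_name=*/"my_platform", local_client,
    /*devices=*/GetMyPlatformDevices(local_client),
    /*host_id=*/0,
    /*allocator=*/nullptr,  // null selects the platform default allocator
    /*host_memory_allocator=*/nullptr,
    /*should_stage_host_to_device_transfers=*/false,
    /*gpu_run_options=*/nullptr);
```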
-#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h" #include #include diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h new file mode 100644 index 00000000000..d7e0af21d4d --- /dev/null +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h @@ -0,0 +1,782 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/layout.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class PjRtStreamExecutorDevice : public PjRtDevice { + public: + explicit PjRtStreamExecutorDevice( + int id, std::unique_ptr local_device_state, + std::string device_kind, int host_id = 0) + : id_(id), + device_ordinal_( + local_device_state ? local_device_state->device_ordinal() : -1), + local_device_state_(std::move(local_device_state)), + host_id_(host_id), + device_kind_(std::move(device_kind)) {} + ~PjRtStreamExecutorDevice() override {} + + // Must set client exactly once. + void SetClient(PjRtClient* client) { + CHECK(client_ == nullptr); + client_ = client; + } + + int host_id() const override { return host_id_; } + + // Return `platform_id` from client. 
+  PjRtPlatformId platform_id() const;
+
+  // Return `platform_name` from client.
+  const std::string& platform_name() const;
+
+  PjRtClient* client() const override { return client_; }
+
+  int id() const override { return id_; }
+
+  bool IsAddressable() const override { return device_ordinal_ != -1; }
+
+  int local_hardware_id() const override { return device_ordinal_; }
+
+  // If this is a device local to this host, returns a LocalDeviceState object
+  // that can be used to manipulate the device. Returns nullptr if the device
+  // is not local to this host.
+  LocalDeviceState* local_device_state() const {
+    return local_device_state_.get();
+  }
+
+  // If this is a device local to this host, returns a LocalDeviceState object
+  // that can be used to manipulate the device. Returns an error if the device
+  // is not local to this host.
+  StatusOr<LocalDeviceState*> GetLocalDeviceState() const;
+
+  const std::string& device_kind() const override { return device_kind_; }
+
+  std::string DebugString() const override;
+
+  Status TransferToInfeed(const LiteralSlice& literal) const override;
+
+  StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override;
+
+ private:
+  const int id_;
+  const int device_ordinal_;  // -1 means not local.
+  const std::unique_ptr<LocalDeviceState> local_device_state_;
+  const int host_id_;
+  const std::string device_kind_;
+  PjRtClient* client_ = nullptr;
+};
+
+class PjRtStreamExecutorClient : public PjRtClient {
+ public:
+  // `allocator` may be null, in which case the platform default allocator is
+  // used.
+  explicit PjRtStreamExecutorClient(
+      std::string platform_name, LocalClient* client,
+      std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
+      int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
+      std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
+      bool should_stage_host_to_device_transfers,
+      std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
+  ~PjRtStreamExecutorClient() override = default;
+
+  int host_id() const override { return host_id_; }
+
+  int device_count() const override { return devices_.size(); }
+  int addressable_device_count() const override {
+    return local_devices_.size();
+  }
+  absl::Span<PjRtDevice* const> devices() const override { return devices_; }
+  absl::Span<PjRtDevice* const> local_devices() const override {
+    return local_devices_;
+  }
+
+  StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
+    auto it = id_to_device_.find(device_id);
+    if (it != id_to_device_.end()) {
+      return it->second;
+    }
+    return InvalidArgument("No matching device found for device_id %d",
+                           device_id);
+  }
+
+  StatusOr<PjRtDevice*> LookupAddressableDevice(
+      int local_hardware_id) const override;
+
+  PjRtPlatformId platform_id() const override { return platform_id_; }
+  const std::string& platform_name() const override { return platform_name_; }
+
+  // Most platforms expect device-to-device transfers to be enqueued on the
+  // source d2d stream, but some platforms use the destination d2d stream. This
+  // function specifies which one the platform expects.
+  virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
+
+  StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
+      int num_replicas, int num_partitions) const override;
+
+  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+      const XlaComputation& computation, CompileOptions options) override;
+
+  StatusOr<absl::optional<std::string>> ExecutableFingerprint(
+      const PjRtExecutable& executable) const override {
+    return absl::optional<std::string>();
+  }
+
+  std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() override;
+
+  // Creates a buffer on the device without initializing or copying any data.
+  // An optional `definition_event` may be specified that can be used to
+  // ensure the buffer isn't referenced until some external mechanism has
+  // initialized the data.
+  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
+      const Shape& shape, PjRtDevice* device) override;
+  StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
+      const Shape& shape, PjRtDevice* device,
+      std::shared_ptr<BufferSequencingEvent> definition_event);
+
+  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
+      const void* data, const Shape& shape,
+      HostBufferSemantics host_buffer_semantics,
+      std::shared_ptr<void> buffer_reference, PjRtDevice* device) override;
+
+  StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
+      const LiteralSlice& literal, PjRtDevice* device) override;
+
+  void MakeCrossHostReceiveBuffers(
+      absl::Span<const Shape> shapes, PjRtDevice* device,
+      PjRtCrossHostRecvNotifier&& notifier) override;
+
+  StatusOr<ChannelHandle> CreateChannelHandle() override {
+    return client()->CreateChannelHandle();
+  }
+  StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
+    return client()->CreateDeviceToHostChannelHandle();
+  }
+  StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
+    return client()->CreateHostToDeviceChannelHandle();
+  }
+
+  LocalDeviceState& device_state(int device_ordinal) const {
+    return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
+                local_devices_.at(device_ordinal))
+                ->local_device_state();
+  }
+  LocalClient* client() const { return client_; }
+  se::DeviceMemoryAllocator* allocator() const { return allocator_; }
+  tensorflow::Allocator* host_memory_allocator() const {
+    return host_memory_allocator_.get();
+  }
+  bool should_stage_host_to_device_transfers() const {
+    return should_stage_host_to_device_transfers_;
+  }
+
+  gpu::GpuExecutableRunOptions* gpu_run_options() const {
+    return gpu_run_options_.get();
+  }
+
+  tensorflow::thread::ThreadPool* h2d_transfer_pool() {
+    return &h2d_transfer_pool_;
+  }
+
+ protected:
+  friend class PjRtStreamExecutorBuffer;
+  virtual void EnqueueCrossHostReceive(
+      std::vector<std::unique_ptr<PjRtBuffer>>&& buffers,
+      std::shared_ptr<BufferSequencingEvent> definition_event,
+      PjRtCrossHostRecvNotifier&& notifier) const {
+    notifier(Unimplemented("Cross host receives not implemented."));
+  }
+
+  virtual Status CopyToRemoteDevice(
+      PjRtBuffer* buffer, absl::string_view serialized_descriptor) const {
+    return Unimplemented("Cross host sends not implemented.");
+  }
+
+  const PjRtPlatformId platform_id_;
+  const std::string platform_name_;
+  LocalClient* client_;
+
+  // Allocator to be used for staging memory transfers to devices.
+  std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
+
+  // Includes all devices, including non-local devices on multi-host platforms.
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
+  // Pointers to `owned_devices_`.
+  std::vector<PjRtDevice*> devices_;
+  // Maps Device::id() to the corresponding Device. Includes all devices.
+  std::map<int, PjRtDevice*> id_to_device_;
+  // Local devices indexed by local device ordinal.
+  std::vector<PjRtDevice*> local_devices_;
+  int host_id_;
+
+  se::DeviceMemoryAllocator* allocator_;
+  std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
+
+  // Should we always prefer to stage host-to-device transfers via memory
+  // allocated on host_memory_allocator_? True only on GPU, where we prefer to
+  // transfer via pinned memory.
+  bool should_stage_host_to_device_transfers_;
+
+  std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;
+
+  tensorflow::thread::ThreadPool h2d_transfer_pool_;
+};
+
+// Converts a 2D set of Device objects indexed by [replica][partition] into an
+// xla::DeviceAssignment.
+StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
+    absl::Span<const std::vector<PjRtDevice*>> devices);
+
+class PjRtStreamExecutorBuffer : public PjRtBuffer {
+ public:
+  // Helper class to retain a "hold" on a PjRtStreamExecutorBuffer.
+  // A ScopedHold may not outlive its parent PjRtStreamExecutorBuffer.
+  //
+  // There are three types of hold, as follows:
+  //
+  // 1) Usage hold: a transient hold while an operation using the buffer is
+  //    being enqueued onto a stream.
+  // A client acquires a usage hold by calling
+  // PjRtStreamExecutorBuffer::GetBufferWithHold(kUsage) or the convenience
+  // wrapper GetBufferWithUsageHold(). If the enqueue completes successfully
+  // the hold should be released using a call to ConvertUsageHold. If the
+  // ScopedHold is deleted without ConvertUsageHold being called, e.g., on
+  // error, the hold is dropped. It is legal to drop a usage hold instead of
+  // calling ConvertUsageHold, even if the buffer was successfully enqueued,
+  // as long as the client ensures that all necessary synchronization has been
+  // done.
+  //
+  // 2) External hold: a potentially long-lived hold while the buffer is being
+  //    shared by an external framework, e.g., NumPy.
+  // A client acquires an external hold by calling
+  // PjRtStreamExecutorBuffer::GetBufferWithHold(kExternalReference) or the
+  // convenience wrapper GetBufferWithExternalReference and releases it by
+  // deleting the ScopedHold. The external framework should not modify the
+  // underlying buffer unless it is confident via its own synchronization that
+  // modifications do not race with reads from the PjRtStreamExecutorBuffer.
+  //
+  // 3) Donation hold: a transient hold while an execution that donates the
+  //    buffer is being enqueued onto the compute stream.
+  // A client acquires a donation hold by calling
+  // PjRtStreamExecutorBuffer::GetBufferWithHold(kDonation). If the enqueue
+  // completes successfully the hold should be released using a call to
+  // ConfirmDonation after which the buffer is invalid. If the ScopedHold is
+  // deleted without ConfirmDonation being called, e.g., on error, the hold is
+  // dropped and the buffer remains valid. If the buffer is successfully
+  // enqueued the client *must* call ConfirmDonation.
+  //
+  // Donation holds behave like exclusive write locks: when a donation hold
+  // has been acquired, any attempt to acquire another hold of any type will
+  // block until the donation hold is dropped or confirmed. Acquiring a
+  // donation hold will fail with an error if there is any outstanding
+  // external hold, and will block if there are any outstanding usage holds
+  // until those holds are dropped or converted.
+  //
+  // Calls to PjRtStreamExecutorBuffer::Release (and transitively to
+  // PjRtStreamExecutorBuffer::Delete() and ~PjRtStreamExecutorBuffer()) will
+  // block until all usage and donation holds are either deleted or
+  // converted/confirmed.
+  class ScopedHold {
+   public:
+    enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue };
+    // Use a State enum instead of encoding the state in an error Status to
+    // avoid creating Status values in non-error cases. Creating a Status
+    // entails several allocations and can add O(us) to every use of a hold.
+    enum State {
+      kUninitialized = 0,
+      kValid,
+      kMoved,
+      kConverted,
+      kReleased,
+      kDonated,
+      kError
+    };
+
+    ~ScopedHold();
+    ScopedHold(ScopedHold&& other);
+    ScopedHold(const ScopedHold&) = delete;
+    ScopedHold& operator=(const ScopedHold&) = delete;
+
+    Type type() const { return type_; }
+
+    Status status() const {
+      // Lazily create Status values only when they are requested.
+      switch (state_) {
+        case kUninitialized:
+          return InvalidArgument("Buffer has not been initialized");
+        case kValid:
+          return Status::OK();
+        case kMoved:
+          return InvalidArgument("Buffer has been moved.");
+        case kConverted:
+          return InvalidArgument("Buffer has been converted");
+        case kReleased:
+          return InvalidArgument("Buffer has been released");
+        case kDonated:
+          return InvalidArgument("Buffer has been donated");
+        case kError:
+          return buffer_or_.status();
+        default:
+          CHECK(false) << "Unexpected state value " << state_;
+      }
+    }
+    bool ok() const { return state_ == kValid; }
+
+    // Access to the underlying device buffer storage. Requires this->ok().
+    const std::shared_ptr<TrackedDeviceBuffer>& buffer() const {
+      CHECK_EQ(state_, kValid);
+      CHECK_NE(buffer_or_.ValueOrDie(), nullptr);
+      return buffer_or_.ValueOrDie();
+    }
+    TrackedDeviceBuffer* operator->() const { return buffer().get(); }
+    const TrackedDeviceBuffer& operator*() const { return *buffer(); }
+
+    // Converts the hold into a usage event. Only valid for holds of type
+    // kUsage.
+    //
+    // usage_stream:   the stream that the buffer was used on.
+    // event:          an event that has been recorded on usage_stream after
+    //                 the buffer was used.
+    // reference_held: true if and only if the caller has caused a
+    //                 reference to this->buffer() to stay live until after
+    //                 the host is sure that the usage (transfer or execution)
+    //                 has completed.
+    void ConvertUsageHold(se::Stream* usage_stream,
+                          std::shared_ptr<BufferSequencingEvent> event,
+                          bool reference_held);
+
+    // Confirms that the buffer was successfully donated to an execution.
+    // Only valid for holds of type kDonation. Causes the buffer to become
+    // invalid.
+    void ConfirmDonation();
+
+    // Adds the held device buffers in order to 'iterator'. Used to add the
+    // buffers to an ExecutionInput. We require but do not verify that
+    // 'iterator' when passed in is pointing to a sub-tuple of the
+    // ExecutionInput whose on_device_shape matches that of the
+    // TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run
+    // out of bounds. Donates the device buffers if the hold type is kDonation,
+    // otherwise retains ownership of the device buffers.
+    void AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
+                    const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
+                    ExecutionInput* execution_input,
+                    se::DeviceMemoryAllocator* allocator) const;
+
+   private:
+    friend class PjRtStreamExecutorBuffer;
+    friend class PjRtStreamExecutorClient;
+
+    // Helper struct that makes it possible to move a ScopedHold through a
+    // closure.
+    using ForClosure =
+        std::tuple<PjRtStreamExecutorBuffer*, Type, State,
+                   StatusOr<std::shared_ptr<TrackedDeviceBuffer>>>;
+
+    ScopedHold(PjRtStreamExecutorBuffer* parent, Type type)
+        : parent_(parent), type_(type), state_(kUninitialized) {}
+    explicit ScopedHold(const ForClosure& closure_helper)
+        : parent_(std::get<0>(closure_helper)),
+          type_(std::get<1>(closure_helper)),
+          state_(std::get<2>(closure_helper)),
+          buffer_or_(std::get<3>(closure_helper)) {
+      // Check the buffer is not in an error state.
+      CHECK(buffer_or_.ValueOrDie() != nullptr);
+    }
+
+    // Sets buffer state.
+    void SetState(State state) { state_ = state; }
+
+    // Sets buffer_or_. Called by parent_ to initialize the hold.
+    void Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or);
+    // Releases the contents of *this, so *this can subsequently be
+    // deleted without releasing the parent's hold. Should be passed to the
+    // appropriate constructor of another ScopedHold, e.g., when a hold must be
+    // passed through a closure that is incompatible with std::move.
+    ForClosure ToClosure();
+
+    PjRtStreamExecutorBuffer* const parent_;
+    const Type type_;
+
+    // There is an invariant that if ok() then
+    // buffer_or_.ValueOrDie() != nullptr.
+    State state_;
+    StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
+  };
+
+  PjRtStreamExecutorBuffer(Shape on_host_shape, Shape on_device_shape,
+                           std::shared_ptr<TrackedDeviceBuffer> device_buffer,
+                           PjRtClient* client, PjRtDevice* device);
+  ~PjRtStreamExecutorBuffer() override;
+
+  PjRtStreamExecutorBuffer(const PjRtStreamExecutorBuffer&) = delete;
+  PjRtStreamExecutorBuffer(PjRtStreamExecutorBuffer&&) = delete;
+  PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) =
+      delete;
+  PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;
+
+  const Shape& on_host_shape() const override { return on_host_shape_; }
+  const Shape& on_device_shape() const override { return on_device_shape_; }
+  PjRtStreamExecutorDevice* device() const override { return device_; }
+  PjRtPlatformId platform_id() const { return client_->platform_id(); }
+  const std::string& platform_name() const { return client_->platform_name(); }
+  PjRtStreamExecutorClient* client() const override { return client_; }
+  bool IsEmptyTuple() const {
+    return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0;
+  }
+
+  int64 OnDeviceSizeInBytes() const override;
+
+  // Implements PjRtBuffer::ExternalReferenceHold as a wrapped
+  // ScopedHold::kExternalReference.
+  class ScopedHoldAsExternalReference
+      : public PjRtBuffer::ExternalReferenceHold {
+   public:
+    explicit ScopedHoldAsExternalReference(ScopedHold hold)
+        : external_reference_(std::move(hold)) {
+      CHECK(external_reference_.type() == ScopedHold::kExternalReference);
+    }
+
+    ~ScopedHoldAsExternalReference() override = default;
+
+    void* OpaqueDeviceMemoryDataPointer() const override {
+      return external_reference_->device_memory().front().opaque();
+    }
+
+   private:
+    ScopedHold external_reference_;
+  };
+  StatusOr<std::unique_ptr<ExternalReferenceHold>> AcquireExternalReference()
+      override;
+
+  StatusOr<absl::optional<std::shared_ptr<void>>> ReleaseDeviceMemoryOwnership(
+      bool wait_for_operations_to_complete) override;
+
+  using PjRtBuffer::ToLiteral;
+  StatusOr<std::shared_ptr<Literal>> ToLiteral(
+      bool discard_cached_copy, absl::optional<xla::Layout> layout) override;
+
+  using PjRtBuffer::CopyToHostAsync;
+  Status CopyToHostAsync(absl::optional<xla::Layout> layout) override;
+
+  // Drops the buffer's reference to its associated device memory, leaving
+  // the buffer in an invalid state. The memory will be freed lazily when all
+  // async operations using the buffer have completed, according to the
+  // allocation semantics of the underlying platform. Delete may briefly
+  // block if another thread is in the process of enqueuing an operation on
+  // this buffer, but it will never block for a stream operation to complete.
+  // If an external framework holds a reference to the TrackedDeviceBuffer
+  // via GetBufferWithExternalReference, the memory will not be freed until
+  // the external framework drops the reference.
+  void Delete() override;
+
+  bool IsDeleted() override;
+
+  // Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The
+  // PjRtBuffer retains ownership of the device buffers.
+  StatusOr<ShapedBuffer> AsShapedBuffer() const;
+
+  // Returns a hold on the TrackedDeviceBuffer holding the device
+  // buffers. See comment on ScopedHold.
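+  // A typical usage-hold sequence looks like the following sketch, where
+  // `stream` and `event` are assumed to be supplied by the caller:
+  //
+  //   ScopedHold hold = buffer->GetBufferWithUsageHold();
+  //   if (!hold.ok()) return hold.status();
+  //   ...enqueue reads of hold->device_memory() on `stream`...
+  //   hold.ConvertUsageHold(stream, event, /*reference_held=*/false);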
+  ScopedHold GetBufferWithHold(ScopedHold::Type type);
+  ScopedHold GetBufferWithUsageHold() {
+    return GetBufferWithHold(ScopedHold::kUsage);
+  }
+  ScopedHold GetBufferWithExternalReference() {
+    return GetBufferWithHold(ScopedHold::kExternalReference);
+  }
+
+  StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
+      PjRtDevice* dst_device) override;
+
+  Status CopyToRemoteDevice(absl::string_view serialized_descriptor) override;
+
+  Status BlockHostUntilReady() override;
+
+  bool IsOnCpu() const override;
+
+  // Similar to Delete, drops the buffer's reference to its associated device
+  // memory, leaving the buffer in an invalid state, but returns the
+  // TrackedDeviceBuffer rather than freeing the device memory, so that
+  // another framework can take ownership of it. The buffer returned from
+  // Release may be safely dropped at any time even if it still has pending
+  // async operations. The client should call BlockHostUntilReady before
+  // calling Release with wait_for_operations_to_complete=false, to ensure
+  // that the host has synchronized past any outstanding write operations to
+  // the buffer. If wait_for_operations_to_complete=true the host will block
+  // until any potentially outstanding asynchronous operations have completed
+  // before returning, in which case it is safe to read or mutate the
+  // returned buffer. If the buffer was shared via an external reference it
+  // is the client's responsibility that accesses via that reference do not
+  // interfere with accesses via the buffer returned from Release.
+  StatusOr<std::shared_ptr<TrackedDeviceBuffer>> Release(
+      bool wait_for_operations_to_complete);
+
+ private:
+  friend class PjRtClient;
+
+  // The cached value of the buffer on the host, produced either from a call
+  // to CopyToHost or from a call to ToLiteral. Once a value has been fetched
+  // to the host, it persists until Delete() is called or the PjRtBuffer is
+  // destroyed.
+  struct HostValue {
+    absl::Notification ready;
+    // status and value are valid for reading only after `ready` has been
+    // notified.
+    Status status;
+    std::shared_ptr<Literal> value;
+  };
+
+  // Blocks in mu_.Await until there are no more usage holds.
+  void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Blocks in mu_.Await until there is no donation hold.
+  void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Adds a hold of 'type' and returns device_buffer_. Returns an error if
+  // device_buffer_ is null, or if a donation hold was requested when there
+  // is an outstanding external hold.
+  StatusOr<std::shared_ptr<TrackedDeviceBuffer>> GetBufferForHoldLocked(
+      ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Adds a hold of hold->type() and initializes `hold` with device_buffer_.
+  // Initializes hold with an error if device_buffer_ is null, or if a
+  // donation hold was requested when there is an outstanding external hold.
+  void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a
+  // sanity check that buffer==device_buffer_ or device_buffer_==nullptr.
+  // Called after device_buffer_ was successfully enqueued on a stream.
+  void ConvertUsageHold(TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
+                        std::shared_ptr<BufferSequencingEvent> event,
+                        bool reference_held);
+
+  // Drops a donation hold and makes *this invalid for further use. Does a
+  // sanity check that buffer==device_buffer_. Called after device_buffer_
+  // was successfully donated to an execution.
+  void ConfirmDonation(TrackedDeviceBuffer* device_buffer);
+
+  // Initiates a copy of the buffer to the host. Does not block waiting for
+  // the transfer to complete. A host value is returned and, if
+  // `discard_cached_copy` is false, stored in an internal buffer so that
+  // future transfers don't have to transfer the data from the host again.
+  // If a layout is passed then a literal of this layout will be returned
+  // and possibly cached.
+  StatusOr<std::shared_ptr<HostValue>> CopyToHostAsyncInternal(
+      bool discard_cached_copy, absl::optional<xla::Layout> layout);
+
+  // Drops a hold without taking any other action. Does a sanity check that
+  // buffer==device_buffer_ or device_buffer_==nullptr.
+  void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);
+
+  StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
+                     std::shared_ptr<BufferSequencingEvent>>>
+  CopyToDeviceHelper(PjRtDevice* dst_device,
+                     LocalDeviceState* dst_local_device,
+                     LocalDeviceState* transfer_local_device,
+                     se::Stream* transfer_stream,
+                     std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);
+
+  PjRtStreamExecutorClient* const client_;
+  const Shape on_host_shape_;
+  const Shape on_device_shape_;
+  PjRtStreamExecutorDevice* const device_;
+
+  mutable absl::Mutex mu_;
+  std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
+  absl::flat_hash_map<xla::Layout, std::shared_ptr<HostValue>> host_values_
+      TF_GUARDED_BY(mu_);
+  std::shared_ptr<HostValue> host_value_ TF_GUARDED_BY(mu_);
+  // Count of holds on the buffer.
+  std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
+  // Semaphore used to ensure there is only one outstanding donation hold.
+  Semaphore donation_semaphore_;
+};
+
+// Wraps one or more XLA LocalExecutables (one per partition, as specified by
+// the build options).
+class PjRtStreamExecutorExecutable : public PjRtExecutable {
+ public:
+  PjRtStreamExecutorExecutable(
+      std::vector<std::unique_ptr<LocalExecutable>> executables,
+      bool parameter_is_tupled_arguments,
+      std::shared_ptr<DeviceAssignment> device_assignment,
+      std::vector<PjRtExecutable::LogicalDeviceIds>
+          addressable_device_logical_ids,
+      std::vector<PjRtDevice*> addressable_devices,
+      PjRtStreamExecutorClient* client);
+
+  ~PjRtStreamExecutorExecutable() override = default;
+
+  PjRtStreamExecutorClient* client() const override { return client_; }
+
+  const std::string& name() const override;
+
+  int num_replicas() const override {
+    return executables_[0]->build_options().num_replicas();
+  }
+
+  int num_partitions() const override {
+    return executables_[0]->build_options().num_partitions();
+  }
+
+  int64 SizeOfGeneratedCodeInBytes() const override {
+    int64 size = 0;
+    for (auto& executable : executables_) {
+      size += executable->executable()->SizeOfGeneratedCodeInBytes();
+    }
+    return size;
+  }
+
+  const DeviceAssignment& device_assignment() const override {
+    return *device_assignment_;
+  }
+
+  absl::Span<const PjRtExecutable::LogicalDeviceIds>
+  addressable_device_logical_ids() const override {
+    return addressable_device_logical_ids_;
+  }
+
+  absl::Span<PjRtDevice* const> addressable_devices() const override {
+    return addressable_devices_;
+  }
+
+  // Return an HloModule per partition.
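+  // (That is, one module per element of executables().)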
+  StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
+      const override;
+
+  StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
+      absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
+      const ExecuteOptions& options) const override;
+
+  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
+      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
+      const ExecuteOptions& options) const override;
+
+  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
+      absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
+      const ExecuteOptions& options) const override;
+
+  void Delete() override { executables_.clear(); }
+
+  absl::Span<const std::shared_ptr<LocalExecutable>> executables() const {
+    return executables_;
+  }
+
+ protected:
+  bool parameter_is_tupled_arguments() const {
+    return parameter_is_tupled_arguments_;
+  }
+
+ private:
+  friend class PjRtStreamExecutorClient;
+
+  // Initializes information about which arguments to which executables must
+  // be donated due to aliases that were specified by the computation.
+  Status SetUpDonation(bool tuple_inputs);
+
+  virtual bool MustDonateParameter(int executable_idx, int parameter) const;
+
+  virtual StatusOr<std::vector<ExecutionInput>>
+  MakeExecutionInputsAndWaitForEvents(
+      int device_ordinal, const ExecuteOptions& options,
+      absl::Span<PjRtBuffer* const> argument_handles,
+      absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
+      absl::flat_hash_set<BufferSequencingEvent*>& events) const;
+
+  StatusOr<ScopedShapedBuffer> EnqueueExecution(
+      absl::Span<PjRtBuffer* const> argument_handles, int replica,
+      int partition, int executable_idx, const RunId& run_id,
+      const ExecuteOptions& options, PjRtDevice* device,
+      std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
+      std::shared_ptr<DeviceAssignment> device_assignment) const;
+
+  virtual std::vector<std::unique_ptr<PjRtBuffer>> MakeOutputBuffers(
+      int device_ordinal, const ExecuteOptions& options,
+      ScopedShapedBuffer result_buffer,
+      std::shared_ptr<BufferSequencingEvent> definition_event,
+      PjRtDevice* device) const;
+
+  StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
+      absl::Span<PjRtBuffer* const> argument_handles, int replica,
+      int partition, const RunId& run_id, const ExecuteOptions& options,
+      PjRtDevice* device = nullptr) const;
+
+  // Create shared pointers so we can free them after the execution: with
+  // asynchronous execution, the process being executed can outlive the
+  // executable itself.
+  PjRtStreamExecutorClient* const client_;
+  // One executable per partition.
+  std::vector<std::shared_ptr<LocalExecutable>> executables_;
+  // Per-executable set of parameters that have any aliased buffers and thus
+  // must be donated when executing the computation.
+  std::vector<absl::flat_hash_set<int>> parameters_that_must_be_donated_;
+  std::shared_ptr<DeviceAssignment> device_assignment_;
+
+  // True if the executables were compiled expecting arguments in a single
+  // tuple.
+  const bool parameter_is_tupled_arguments_;
+
+  // The replica and partition indices of device_assignment_ to be run by
+  // this client. On single-host platforms without partitioning, this is all
+  // replicas (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this
+  // may not be the case on multi-host platforms. If there are 4 replicas and
+  // 2 partitions on a single-host platform, the size of
+  // addressable_device_logical_ids_ is 4*2 = 8.
+  std::vector<PjRtExecutable::LogicalDeviceIds>
+      addressable_device_logical_ids_;
+
+  // addressable_devices_[i] is the Device to which
+  // addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
+  // unique_ptrs to play well with the Python bindings (see xla.cc).
+  std::vector<PjRtDevice*> addressable_devices_;
+};
+
+// Executables can donate buffers so that buffers can be aliased from inputs
+// to outputs. This function returns the list of parameters that must be
+// donated when the executable is run. tuple_inputs reflects the option that
+// the executable was compiled with.
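+// For example, if the computation was built with an input/output alias that
+// maps parameter 1 to an output, then 1 appears in the returned set and the
+// corresponding argument buffer must be donated when the executable runs.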
+StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
+    const HloModule& hlo_module, bool tuple_inputs);
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.cc b/tensorflow/compiler/xla/pjrt/tpu_client.cc
index 8222874a229..830f7c66f2a 100644
--- a/tensorflow/compiler/xla/pjrt/tpu_client.cc
+++ b/tensorflow/compiler/xla/pjrt/tpu_client.cc
@@ -23,7 +23,7 @@ limitations under the License.
 #include "absl/status/status.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
-#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
 #include "tensorflow/compiler/xla/shape.h"
diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.h b/tensorflow/compiler/xla/pjrt/tpu_client.h
index f17d82a270e..d9847bb85de 100644
--- a/tensorflow/compiler/xla/pjrt/tpu_client.h
+++ b/tensorflow/compiler/xla/pjrt/tpu_client.h
@@ -20,7 +20,7 @@ limitations under the License.
 #include
 #include
 
-#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/stream_executor/tpu/tpu_topology.h"
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index e8a61c0e916..bd1d8becf5d 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -211,6 +211,7 @@ cc_library(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/pjrt:pjrt_client",
+        "//tensorflow/compiler/xla/pjrt:pjrt_stream_executor_client",  # TODO(zhangqiaorjc): Remove after adding a factory method for PjRtBuffer.
         "//tensorflow/compiler/xla/pjrt:tracked_device_buffer",
         "//tensorflow/stream_executor:device_memory",
         "//tensorflow/stream_executor:platform",
diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc
index 47bc1f66569..a67358497a6 100644
--- a/tensorflow/compiler/xla/python/dlpack.cc
+++ b/tensorflow/compiler/xla/python/dlpack.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "include/dlpack/dlpack.h"  // from @dlpack
 #include "pybind11/pytypes.h"
 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
 #include "tensorflow/compiler/xla/python/traceback.h"
diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.cc b/tensorflow/compiler/xla/python/outfeed_receiver.cc
index df4bc3025f1..92aae351085 100644
--- a/tensorflow/compiler/xla/python/outfeed_receiver.cc
+++ b/tensorflow/compiler/xla/python/outfeed_receiver.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include
 
 #include
+#include
 #include
 
 #include "absl/container/flat_hash_map.h"
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD
index 28a491c0326..3296d29c8ff 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD
@@ -26,7 +26,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:executable_build_options",
-        "//tensorflow/compiler/xla/pjrt:pjrt_client",
+        "//tensorflow/compiler/xla/pjrt:pjrt_stream_executor_client",
         "//tensorflow/compiler/xla/pjrt:semaphore",
         "//tensorflow/compiler/xla/python/tpu_driver",
         "//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver",
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h
index 89dca53bbb6..cc4e4471e8e 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h
@@ -24,7 +24,7 @@ limitations under the License.
 #include "absl/synchronization/notification.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
-#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
 #include "tensorflow/compiler/xla/service/shaped_buffer.h"