Split StreamExecutor impl of pjrt_client into its own files.

Also removed duplicate comments for subclass methods.

PiperOrigin-RevId: 346807683
Change-Id: I2ab9b5038d5e5fc991bfee53d0081c5eecf51906
Qiao Zhang 2020-12-10 10:01:36 -08:00 committed by TensorFlower Gardener
parent d9f4007ff6
commit ec86d80f19
17 changed files with 830 additions and 803 deletions


@@ -119,12 +119,39 @@ cc_library(
cc_library(
name = "pjrt_client",
srcs = ["pjrt_client.cc"],
hdrs = ["pjrt_client.h"],
visibility = ["//tensorflow/compiler/xla:friends"],
deps = [
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/pjrt/distributed:protocol_proto_cc",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/core:lib",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "pjrt_stream_executor_client",
srcs = ["pjrt_stream_executor_client.cc"],
hdrs = ["pjrt_stream_executor_client.h"],
visibility = ["//tensorflow/compiler/xla:friends"],
deps = [
":event_pool",
":local_device_state",
":pjrt_client",
":tracked_device_buffer",
"//tensorflow/compiler/xla:cpu_function_runtime",
"//tensorflow/compiler/xla:executable_run_options",
@@ -181,7 +208,7 @@ cc_library(
],
deps = [
":local_device_state",
":pjrt_client",
":pjrt_stream_executor_client",
":tracked_device_buffer",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
@@ -215,7 +242,7 @@ cc_library(
srcs = ["interpreter_device.cc"],
hdrs = ["interpreter_device.h"],
deps = [
":pjrt_client",
":pjrt_stream_executor_client",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/service:interpreter_plugin",
@@ -229,7 +256,7 @@ cc_library(
srcs = ["cpu_device.cc"],
hdrs = ["cpu_device.h"],
deps = [
":pjrt_client",
":pjrt_stream_executor_client",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/service:platform_util",
@@ -242,7 +269,7 @@ cc_library(
srcs = ["gpu_device.cc"],
hdrs = ["gpu_device.h"],
deps = [
":pjrt_client",
":pjrt_stream_executor_client",
"@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
"//tensorflow/compiler/xla:statusor",
@@ -279,6 +306,7 @@ tf_cc_test(
deps = [
":gpu_device",
":pjrt_client",
":pjrt_stream_executor_client",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:xla_builder",


@@ -17,7 +17,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
namespace xla {


@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {


@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/pjrt/gpu_device.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#ifdef NCCL_ENABLED
#include "third_party/nccl/nccl.h"


@@ -19,7 +19,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/bfc_allocator.h"


@@ -17,7 +17,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
namespace xla {


@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {


@@ -20,30 +20,19 @@ limitations under the License.
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/fingerprint.h"
@@ -106,78 +95,6 @@ class PjRtDevice {
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const = 0;
};
class PjRtStreamExecutorDevice : public PjRtDevice {
public:
explicit PjRtStreamExecutorDevice(
int id, std::unique_ptr<LocalDeviceState> local_device_state,
std::string device_kind, int host_id = 0)
: id_(id),
device_ordinal_(
local_device_state ? local_device_state->device_ordinal() : -1),
local_device_state_(std::move(local_device_state)),
host_id_(host_id),
device_kind_(std::move(device_kind)) {}
~PjRtStreamExecutorDevice() override {}
// Must set client exactly once.
void SetClient(PjRtClient* client) {
CHECK(client_ == nullptr);
client_ = client;
}
// Task ID. This is always 0 in a single-task setup.
int host_id() const override { return host_id_; }
// Return `platform_id` from client.
PjRtPlatformId platform_id() const;
// Return `platform_name` from client.
const std::string& platform_name() const;
PjRtClient* client() const override { return client_; }
// The ID of this device. IDs are unique among devices of this type
// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
// hosts' devices. This is the ID that should be used in a DeviceAssignment.
int id() const override { return id_; }
bool IsAddressable() const override { return device_ordinal_ != -1; }
int local_hardware_id() const override { return device_ordinal_; }
// If this is a device local to this host, returns a LocalDeviceState object
// that can be used to manipulate the device. Returns nullptr if the device is
// not local to this host.
LocalDeviceState* local_device_state() const {
return local_device_state_.get();
}
// If this is a device local to this host, returns a LocalDeviceState object
// that can be used to manipulate the device. Returns an error if the device
// is not local to this host.
StatusOr<LocalDeviceState*> GetLocalDeviceState() const;
// A vendor-dependent string that uniquely identifies the kind of device.
const std::string& device_kind() const override { return device_kind_; }
std::string DebugString() const override;
// Transfer the given literal to the infeed queue of the given local device.
Status TransferToInfeed(const LiteralSlice& literal) const override;
// Transfer and return a value of the given shape from the outfeed of the
// given device.
StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override;
private:
const int id_;
const int device_ordinal_; // -1 means not local.
const std::unique_ptr<LocalDeviceState> local_device_state_;
const int host_id_;
const std::string device_kind_;
PjRtClient* client_ = nullptr;
};
// Forward declaration.
class PjRtBuffer;
@@ -333,181 +250,6 @@ class PjRtClient {
virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
};
class PjRtStreamExecutorClient : public PjRtClient {
public:
// `allocator` may be null, in which case the platform default allocator is used.
explicit PjRtStreamExecutorClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
~PjRtStreamExecutorClient() override = default;
int host_id() const override { return host_id_; }
int device_count() const override { return devices_.size(); }
int addressable_device_count() const override {
return local_devices_.size();
}
absl::Span<PjRtDevice* const> devices() const override { return devices_; }
absl::Span<PjRtDevice* const> local_devices() const override {
return local_devices_;
}
StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
auto it = id_to_device_.find(device_id);
if (it != id_to_device_.end()) {
return it->second;
}
return InvalidArgument("No matching device found for device_id %d",
device_id);
}
StatusOr<PjRtDevice*> LookupAddressableDevice(
int local_hardware_id) const override;
PjRtPlatformId platform_id() const override { return platform_id_; }
const std::string& platform_name() const override { return platform_name_; }
// Most platforms expect device-to-device transfers to be enqueued on the
// source d2d stream, but some platforms use the destination d2d stream. This
// function specifies which one the platform expects.
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
const XlaComputation& computation, CompileOptions options) override;
// Generates a unique fingerprint for `executable`.
StatusOr<absl::optional<std::string>> ExecutableFingerprint(
const PjRtExecutable& executable) const override {
return absl::optional<std::string>();
}
// Returns a backend-specific HLO cost analysis visitor.
std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() override;
// Creates a buffer on the device without initializing or copying any data.
// An optional `definition_event` may be specified that can be used to
// ensure the buffer isn't referenced until some external mechanism has
// initialized the data.
// NOTE: The sequencing mechanism is not guaranteed to be supported by all
// future backends, so callers should avoid it wherever possible.
StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device) override;
StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device,
std::shared_ptr<BufferSequencingEvent> definition_event);
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device) override;
// Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait until BlockHostUntilReady() completes on
// the return value before letting literal go out of scope.
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device) override;
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
// cross host transfers using `client` on `device`. `shapes` must be the exact
// shapes, with identical layouts, corresponding to the buffers that will be
// sent. When resources for the transfer are available, notifier will be
// called with a vector of PjRtCrossHostRecvBuffer structs, one for each
// shape in `shapes`. Each struct contains a buffer that will contain the
// received value, and an opaque string that should be transmitted to the
// sending host and used in a call to CopyToRemoteDevice. None of the recv
// buffers will become ready until *all* of the sends have completed.
void MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier) override;
StatusOr<ChannelHandle> CreateChannelHandle() override {
return client()->CreateChannelHandle();
}
StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
return client()->CreateDeviceToHostChannelHandle();
}
StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
return client()->CreateHostToDeviceChannelHandle();
}
LocalDeviceState& device_state(int device_ordinal) const {
return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
local_devices_.at(device_ordinal))
->local_device_state();
}
LocalClient* client() const { return client_; }
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
tensorflow::Allocator* host_memory_allocator() const {
return host_memory_allocator_.get();
}
bool should_stage_host_to_device_transfers() const {
return should_stage_host_to_device_transfers_;
}
gpu::GpuExecutableRunOptions* gpu_run_options() const {
return gpu_run_options_.get();
}
tensorflow::thread::ThreadPool* h2d_transfer_pool() {
return &h2d_transfer_pool_;
}
protected:
friend class PjRtStreamExecutorBuffer;
virtual void EnqueueCrossHostReceive(
std::vector<std::unique_ptr<PjRtBuffer>>&& buffers,
std::shared_ptr<BufferSequencingEvent> definition_event,
PjRtCrossHostRecvNotifier&& notifier) const {
notifier(Unimplemented("Cross host receives not implemented."));
}
virtual Status CopyToRemoteDevice(
PjRtBuffer* buffer, absl::string_view serialized_descriptor) const {
return Unimplemented("Cross host sends not implemented.");
}
const PjRtPlatformId platform_id_;
const std::string platform_name_;
LocalClient* client_;
// Allocator to be used for staging memory transfers to devices.
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
// Includes all devices, including non-local devices on multi-host platforms.
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
// Pointers to `owned_devices_`.
std::vector<PjRtDevice*> devices_;
// Maps Device::id() to the corresponding Device. Includes all devices.
std::map<int, PjRtDevice*> id_to_device_;
// Local devices indexed by local device ordinal.
std::vector<PjRtDevice*> local_devices_;
int host_id_;
se::DeviceMemoryAllocator* allocator_;
std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
// Should we always prefer to stage host-to-device transfers via memory
// allocated on host_memory_allocator_? True only on GPU, where we prefer to
// transfer via pinned memory.
bool should_stage_host_to_device_transfers_;
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;
tensorflow::thread::ThreadPool h2d_transfer_pool_;
};
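For orientation, a minimal sketch of compiling through the client interface above. GetCpuClient() and BuildComputation() are hypothetical placeholders for a concrete client factory and an XlaBuilder-built computation, neither of which appears in this diff; with default CompileOptions the client's default device assignment applies.

#include <memory>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/platform/logging.h"

xla::Status CompileOnce() {
  std::unique_ptr<xla::PjRtClient> client = GetCpuClient();  // hypothetical
  xla::XlaComputation computation = BuildComputation();      // hypothetical
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtExecutable> executable,
                      client->Compile(computation, xla::CompileOptions()));
  LOG(INFO) << "Compiled " << executable->name();
  return xla::Status::OK();
}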
// Converts a 2D set of Device objects indexed by [replica][partition] into an
// xla::DeviceAssignment.
StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
absl::Span<const std::vector<PjRtDevice*>> devices);
// Holds a reference from Python to a tuple of device buffers. A PjRtBuffer
// can be either valid or invalid. An invalid buffer is one that has never been
// initialized, or a buffer that has been deleted (e.g., by calling Delete, or
@@ -625,393 +367,6 @@ class PjRtBuffer {
virtual bool IsOnCpu() const = 0;
};
class PjRtStreamExecutorBuffer : public PjRtBuffer {
public:
// Helper class to retain a "hold" on a PjRtBuffer. A ScopedHold may not
// outlive its parent PjRtBuffer.
//
// There are three types of hold, as follows:
//
// 1) Usage hold: a transient hold while an operation using the buffer is
// being enqueued onto a stream.
// A client acquires a usage hold by calling
// PjRtBuffer::GetBufferWithHold(kUsage) or the convenience wrapper
// GetBufferWithUsageHold(). If the enqueue completes successfully the hold
// should be released using a call to ConvertUsageHold. If the ScopedHold is
// deleted without ConvertUsageHold being called, e.g., on error, the hold is
// dropped. It is legal to drop a usage hold instead of calling
// ConvertUsageHold, even if the buffer was successfully enqueued, as long as
// the client ensures that all necessary synchronization has been done.
//
// 2) External hold: a potentially long-lived hold while the buffer is being
// shared by an external framework, e.g., NumPy.
// A client acquires an external hold by calling
// PjRtBuffer::GetBufferWithHold(kExternal) or the convenience wrapper
// GetBufferWithExternalReference and releases it by deleting the ScopedHold.
// The external framework should not modify the underlying buffer unless it is
// confident via its own synchronization that modifications do not race with
// reads from the PjRtBuffer.
//
// 3) Donation hold: a transient hold while an execution that donates the
// buffer is being enqueued onto the compute stream.
// A client acquires a donation hold by calling
// PjRtBuffer::GetBufferWithHold(kDonation). If the enqueue completes
// successfully the hold should be released using a call to ConfirmDonation
// after which the buffer is invalid. If the ScopedHold is deleted without
// ConfirmDonation being called, e.g., on error, the hold is dropped and the
// buffer remains valid. If the buffer is successfully enqueued the client
// *must* call ConfirmDonation.
//
// Donation holds behave like exclusive write locks: when a donation hold
// has been acquired, any attempt to acquire another hold of any type will
// block until the donation hold is dropped or confirmed. Acquiring a donation
// hold will fail with an error if there is any outstanding external hold, and
// will block if there are any outstanding usage holds until those holds are
// dropped or converted.
//
// Calls to PjRtBuffer::Release (and transitively to
// PjRtBuffer::Delete() and ~PjRtBuffer()) will block until all usage
// and donation holds are either deleted or converted/confirmed.
class ScopedHold {
public:
enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue };
// Use a State enum instead of encoding the state in an error Status to
// avoid creating Status values in non-error cases. Creating a Status
// entails several allocations and can add O(us) to every use of a hold.
enum State {
kUninitialized = 0,
kValid,
kMoved,
kConverted,
kReleased,
kDonated,
kError
};
~ScopedHold();
ScopedHold(ScopedHold&& other);
ScopedHold(const ScopedHold&) = delete;
ScopedHold& operator=(const ScopedHold&) = delete;
Type type() const { return type_; }
Status status() const {
// Lazily create Status values only when they are requested.
switch (state_) {
case kUninitialized:
return InvalidArgument("Buffer has not been initialized");
case kValid:
return Status::OK();
case kMoved:
return InvalidArgument("Buffer has been moved.");
case kConverted:
return InvalidArgument("Buffer has been converted");
case kReleased:
return InvalidArgument("Buffer has been released");
case kDonated:
return InvalidArgument("Buffer has been donated");
case kError:
return buffer_or_.status();
default:
CHECK(false) << "Unexpected state value " << state_;
}
}
bool ok() const { return state_ == kValid; }
// Access to the underlying device buffer storage. Requires this->ok().
const std::shared_ptr<TrackedDeviceBuffer>& buffer() const {
CHECK_EQ(state_, kValid);
CHECK_NE(buffer_or_.ValueOrDie(), nullptr);
return buffer_or_.ValueOrDie();
}
TrackedDeviceBuffer* operator->() const { return buffer().get(); }
const TrackedDeviceBuffer& operator*() const { return *buffer(); }
// Converts the hold into a usage event. Only valid for holds of type
// kUsage.
//
// usage_stream: the stream that the buffer was used on.
// event: an event that has been recorded on usage_stream after
// the buffer was used.
// reference_held: true if and only if the caller has caused a
// reference to this->buffer() to stay live until after
// the host is sure that the usage (transfer or execution)
// has completed.
void ConvertUsageHold(se::Stream* usage_stream,
std::shared_ptr<BufferSequencingEvent> event,
bool reference_held);
// Confirms that the buffer was successfully donated to an execution.
// Only valid for holds of type kDonation. Causes the buffer to become
// invalid.
void ConfirmDonation();
// Adds the held device buffers, in order, to 'iterator'. Used to add the
// buffers to an ExecutionInput. We require but do not verify that
// 'iterator' when passed in is pointing to a sub-tuple of the
// ExecutionInput whose on_device_shape matches that of the
// TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run
// out of bounds. Donates the device buffers if the hold type is kDonation,
// otherwise retains ownership of the device buffers.
void AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
ExecutionInput* execution_input,
se::DeviceMemoryAllocator* allocator) const;
private:
friend class PjRtStreamExecutorBuffer;
friend class PjRtStreamExecutorClient;
// Helper struct that makes it possible to move a ScopedHold through a
// closure.
using ForClosure =
std::tuple<PjRtStreamExecutorBuffer*, Type, State,
StatusOr<std::shared_ptr<TrackedDeviceBuffer>>>;
ScopedHold(PjRtStreamExecutorBuffer* parent, Type type)
: parent_(parent), type_(type), state_(kUninitialized) {}
explicit ScopedHold(const ForClosure& closure_helper)
: parent_(std::get<0>(closure_helper)),
type_(std::get<1>(closure_helper)),
state_(std::get<2>(closure_helper)),
buffer_or_(std::get<3>(closure_helper)) {
// Check the buffer is not in an error state.
CHECK(buffer_or_.ValueOrDie() != nullptr);
}
// Sets buffer state.
void SetState(State state) { state_ = state; }
// Sets buffer_or_. Called by parent_ to initialize the hold.
void Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or);
// Releases the contents of *this, so *this can subsequently be
// deleted without releasing the parent's hold. Should be passed to the
// appropriate constructor of another ScopedHold, e.g., when a hold must be
// passed through a closure that is incompatible with std::move.
ForClosure ToClosure();
PjRtStreamExecutorBuffer* const parent_;
const Type type_;
// There is an invariant that if ok() then
// buffer_or_.ValueOrDie() != nullptr.
State state_;
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
};
PjRtStreamExecutorBuffer(Shape on_host_shape, Shape on_device_shape,
std::shared_ptr<TrackedDeviceBuffer> device_buffer,
PjRtClient* client, PjRtDevice* device);
~PjRtStreamExecutorBuffer() override;
PjRtStreamExecutorBuffer(const PjRtStreamExecutorBuffer&) = delete;
PjRtStreamExecutorBuffer(PjRtStreamExecutorBuffer&&) = delete;
PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;
const Shape& on_host_shape() const override { return on_host_shape_; }
const Shape& on_device_shape() const override { return on_device_shape_; }
PjRtStreamExecutorDevice* device() const override { return device_; }
PjRtPlatformId platform_id() const { return client_->platform_id(); }
const std::string& platform_name() const { return client_->platform_name(); }
PjRtStreamExecutorClient* client() const override { return client_; }
bool IsEmptyTuple() const {
return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0;
}
// Returns the size of the on-device representation of this buffer in bytes.
int64 OnDeviceSizeInBytes() const override;
// Implements PjRtBuffer::ExternalReferenceHold as a wrapped
// ScopedHold::kExternalReference.
class ScopedHoldAsExternalReference
: public PjRtBuffer::ExternalReferenceHold {
public:
explicit ScopedHoldAsExternalReference(ScopedHold hold)
: external_reference_(std::move(hold)) {
CHECK(hold.type() == ScopedHold::kExternalReference);
}
~ScopedHoldAsExternalReference() override = default;
void* OpaqueDeviceMemoryDataPointer() const override {
return external_reference_->device_memory().front().opaque();
}
private:
ScopedHold external_reference_;
};
StatusOr<std::unique_ptr<ExternalReferenceHold>> AcquireExternalReference()
override;
StatusOr<absl::optional<std::shared_ptr<void>>> ReleaseDeviceMemoryOwnership(
bool wait_for_operations_to_complete) override;
// Returns the buffer's value as an XLA Literal. If the value has previously
// been prefetched to the host, then returns the prefetched version, otherwise
// copies the buffer to the host. Blocks until the value is ready. If
// `discard_cached_copy` is true then the buffer will no longer keep hold of
// a cached copy of the literal (i.e., the reference to the host value will
// be removed). If a layout is passed then a literal with this layout will
// be returned.
using PjRtBuffer::ToLiteral;
StatusOr<std::shared_ptr<Literal>> ToLiteral(
bool discard_cached_copy, absl::optional<xla::Layout> layout) override;
// Initiates a copy of the buffer to the host. Does not block waiting for
// the transfer to complete. The value can be retrieved by a later call to
// ToLiteral(). If a layout is passed then a cached copy with this layout will
// be created.
using PjRtBuffer::CopyToHostAsync;
Status CopyToHostAsync(absl::optional<xla::Layout> layout) override;
// Drops the buffer's reference to its associated device memory, leaving the
// buffer in an invalid state. The memory will be freed lazily when all async
// operations using the buffer have completed, according to the allocation
// semantics of the underlying platform. Delete may briefly block if another
// thread is in the process of enqueuing an operation on this buffer, but it
// will never block for a stream operation to complete. If an external
// framework holds a reference to the TrackedDeviceBuffer via
// GetBufferWithExternalReference, the memory will not be freed until the
// external framework drops the reference.
void Delete() override;
// True if and only if Delete or Release has previously been called.
bool IsDeleted() override;
// Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The
// PjRtBuffer retains ownership of the device buffers.
StatusOr<ShapedBuffer> AsShapedBuffer() const;
// Returns a hold on the TrackedDeviceBuffer holding the device
// buffers. See comment on ScopedHold.
ScopedHold GetBufferWithHold(ScopedHold::Type type);
ScopedHold GetBufferWithUsageHold() {
return GetBufferWithHold(ScopedHold::kUsage);
}
ScopedHold GetBufferWithExternalReference() {
return GetBufferWithHold(ScopedHold::kExternalReference);
}
// Copies the buffer to device `dst_device`, performing a d2d transfer when
// `dst_device` is sharing the same Client, and performing a d2h and h2d copy
// if `dst_device` lives on a different Client.
// Returns an error if the buffer is already on dst_device.
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
PjRtDevice* dst_device) override;
// Copies the buffer to the remote device encoded in serialized_descriptor.
// This call must be preceded by a call to MakeCrossHostReceiveBuffers on the
// remote host's destination device. MakeCrossHostReceiveBuffers takes an
// array of shapes to construct the destination buffers, and a callback
// supplies an array containing both the destination buffers, and a serialized
// descriptor for each buffer. For each destination buffer there should be a
// matching call to src->CopyToRemoteDevice on a remote host for a src buffer
// of the corresponding shape. serialized_descriptor is the string returned by
// the callback along with the corresponding destination buffer.
Status CopyToRemoteDevice(absl::string_view serialized_descriptor) override;
// Blocks the host until the buffer's value has been computed and is ready for
// immediate use on the device. Useful in particular for timing benchmarks.
Status BlockHostUntilReady() override;
// Whether this buffer is on CPU and thus allows for certain optimizations.
bool IsOnCpu() const override;
// Similar to Delete, drops the buffer's reference to its associated device
// memory, leaving the buffer in an invalid state, but returns the
// TrackedDeviceBuffer rather than freeing the device memory, so that another
// framework can take ownership of it. The buffer returned from Release may
// be safely dropped at any time even if it still has pending async
// operations. The client should call BlockHostUntilReady before calling
// Release with wait_for_operations_to_complete=false, to ensure that the host
// has synchronized past any outstanding write operations to the buffer. If
// wait_for_operations_to_complete=true the host will block until any
// potentially outstanding asynchronous operations have completed before
// returning, in which case it is safe to read or mutate the returned buffer.
// If the buffer was shared via an external reference it is the client's
// responsibility that accesses via that reference do not interfere with
// accesses via the buffer returned from Release.
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> Release(
bool wait_for_operations_to_complete);
private:
friend class PjRtClient;
// The cached value of the buffer on the host, produced either from a call to
// CopyToHost or from a call to ToLiteral. Once a value has been fetched to
// the host, it persists until Delete() is called or the PjRtBuffer is
// destroyed.
struct HostValue {
absl::Notification ready;
// status and value are valid for reading only after `ready` has been
// notified.
Status status;
std::shared_ptr<Literal> value;
};
// Blocks in mu_.Await until there are no more usage holds.
void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Blocks in mu_.Await until there is no donation hold.
void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Adds a hold of 'type' and returns device_buffer_. Returns an error if
// device_buffer_ is null, or if a donation hold was requested when there is
// an outstanding external hold.
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> GetBufferForHoldLocked(
ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Adds a hold of hold->type() and initializes `hold` with device_buffer_.
// Initializes hold with an error if device_buffer_ is null, or if a donation
// hold was requested when there is an outstanding external hold.
void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a sanity
// check that buffer==device_buffer_ or device_buffer_==nullptr. Called after
// device_buffer_ was successfully enqueued on a stream.
void ConvertUsageHold(TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
std::shared_ptr<BufferSequencingEvent> event,
bool reference_held);
// Drops a donation hold and makes *this invalid for further use. Does a
// sanity check that buffer==device_buffer_. Called after device_buffer_ was
// successfully donated to an execution.
void ConfirmDonation(TrackedDeviceBuffer* device_buffer);
// Initiates a copy of the buffer to the host. Does not block waiting for
// the transfer to complete. A host value is returned and if
// `discard_cached_copy` is false stored in an internal buffer so that future
// reads don't have to fetch the data from the device again. If a layout is
// passed then a literal of this layout will be returned and possibly cached.
StatusOr<std::shared_ptr<HostValue>> CopyToHostAsyncInternal(
bool discard_cached_copy, absl::optional<xla::Layout> layout);
// Drops a hold without taking any other action. Does a sanity check that
// buffer==device_buffer_ or device_buffer_==nullptr.
void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);
StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
std::shared_ptr<BufferSequencingEvent>>>
CopyToDeviceHelper(PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
LocalDeviceState* transfer_local_device,
se::Stream* transfer_stream,
std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);
PjRtStreamExecutorClient* const client_;
const Shape on_host_shape_;
const Shape on_device_shape_;
PjRtStreamExecutorDevice* const device_;
mutable absl::Mutex mu_;
std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
absl::flat_hash_map<xla::Layout, std::shared_ptr<HostValue>> host_values_
TF_GUARDED_BY(mu_);
std::shared_ptr<HostValue> host_value_ TF_GUARDED_BY(mu_);
// Count of holds on the buffer.
std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
// Semaphore used to ensure there is only one outstanding donation hold.
Semaphore donation_semaphore_;
};
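A sketch of the usage-hold protocol documented on ScopedHold above; the stream and sequencing event are assumed to come from the device's LocalDeviceState, and the enqueued device work itself is elided.

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

xla::Status EnqueueRead(
    xla::PjRtStreamExecutorBuffer* buffer,
    stream_executor::Stream* usage_stream,
    std::shared_ptr<xla::BufferSequencingEvent> event) {
  // Acquire a transient usage hold while the read is enqueued.
  xla::PjRtStreamExecutorBuffer::ScopedHold hold =
      buffer->GetBufferWithUsageHold();
  if (!hold.ok()) return hold.status();
  // ... enqueue work on `usage_stream` that reads the device memory
  // reachable via hold->device_memory() ...
  // Record the usage so lifetime tracking knows when the read completes.
  hold.ConvertUsageHold(usage_stream, std::move(event),
                        /*reference_held=*/true);
  return xla::Status::OK();
}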
class ExecuteContext {
public:
virtual ~ExecuteContext() = default;
@@ -1103,148 +458,6 @@ class PjRtExecutable {
virtual void Delete() = 0;
};
// Wraps one or more XLA LocalExecutables (one per partition, as specified by
// the build options).
class PjRtStreamExecutorExecutable : public PjRtExecutable {
public:
PjRtStreamExecutorExecutable(
std::vector<std::unique_ptr<LocalExecutable>> executables,
bool parameter_is_tupled_arguments,
std::shared_ptr<DeviceAssignment> device_assignment,
std::vector<LogicalDeviceIds> addressable_device_logical_ids,
std::vector<PjRtDevice*> addressable_devices,
PjRtStreamExecutorClient* client);
~PjRtStreamExecutorExecutable() override = default;
PjRtStreamExecutorClient* client() const override { return client_; }
const std::string& name() const override;
int num_replicas() const override {
return executables_[0]->build_options().num_replicas();
}
int num_partitions() const override {
return executables_[0]->build_options().num_partitions();
}
int64 SizeOfGeneratedCodeInBytes() const override {
int64 size = 0;
for (auto& executable : executables_) {
size += executable->executable()->SizeOfGeneratedCodeInBytes();
}
return size;
}
const DeviceAssignment& device_assignment() const override {
return *device_assignment_;
}
absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
const override {
return addressable_device_logical_ids_;
}
absl::Span<PjRtDevice* const> addressable_devices() const override {
return addressable_devices_;
}
// Return an HloModule per partition.
StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
const override;
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options) const override;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) const override;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) const override;
void Delete() override { executables_.clear(); }
absl::Span<const std::shared_ptr<LocalExecutable>> executables() const {
return executables_;
}
protected:
bool parameter_is_tupled_arguments() const {
return parameter_is_tupled_arguments_;
}
private:
friend class PjRtStreamExecutorClient;
// Initializes information about which arguments to which executables must be
// donated due to aliases that were specified by the computation.
Status SetUpDonation(bool tuple_inputs);
virtual bool MustDonateParameter(int executable_idx, int parameter) const;
virtual StatusOr<std::vector<ExecutionInput>>
MakeExecutionInputsAndWaitForEvents(
int device_ordinal, const ExecuteOptions& options,
absl::Span<PjRtBuffer* const> argument_handles,
absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
absl::flat_hash_set<BufferSequencingEvent*>& events) const;
StatusOr<ScopedShapedBuffer> EnqueueExecution(
absl::Span<PjRtBuffer* const> argument_handles, int replica,
int partition, int executable_idx, const RunId& run_id,
const ExecuteOptions& options, PjRtDevice* device,
std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
std::shared_ptr<DeviceAssignment> device_assignment) const;
virtual std::vector<std::unique_ptr<PjRtBuffer>> MakeOutputBuffers(
int device_ordinal, const ExecuteOptions& options,
ScopedShapedBuffer result_buffer,
std::shared_ptr<BufferSequencingEvent> definition_event,
PjRtDevice* device) const;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
absl::Span<PjRtBuffer* const> argument_handles, int replica,
int partition, const RunId& run_id, const ExecuteOptions& options,
PjRtDevice* device = nullptr) const;
// Create shared pointers so we can free them after the execution: with
// asynchronous execution, the process being executed can outlive the
// executable itself.
PjRtStreamExecutorClient* const client_;
// One executable per partition.
std::vector<std::shared_ptr<LocalExecutable>> executables_;
// Per-executable set of parameters that have any aliased buffers and thus
// must be donated when executing the computation.
std::vector<absl::flat_hash_set<int>> parameters_that_must_be_donated_;
std::shared_ptr<DeviceAssignment> device_assignment_;
// True if the executables were compiled expecting arguments in a single
// tuple.
const bool parameter_is_tupled_arguments_;
// The replica and partition indices of device_assignment_ to be run by this
// client. On single-host platforms without partitioning, this is all replicas
// (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
// case on multi-host platforms. If there are 4 replicas and 2 partitions on a
// single-host platform, the size of addressable_device_logical_ids_ is
// 4*2 = 8.
std::vector<LogicalDeviceIds> addressable_device_logical_ids_;
// addressable_devices_[i] is the Device to which
// addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
// unique_ptrs to play well with the Python bindings (see xla.cc).
std::vector<PjRtDevice*> addressable_devices_;
};
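A minimal execution sketch against the interface above; the argument buffers are assumed to have been created elsewhere (for example via BufferFromHostLiteral).

#include <memory>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

xla::StatusOr<std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>>>
RunReplicated(
    xla::PjRtExecutable* executable,
    absl::Span<const std::vector<xla::PjRtBuffer*>> argument_handles) {
  // Default options suffice for a simple synchronous-looking run; results
  // are indexed by [device][output].
  xla::ExecuteOptions options;
  return executable->Execute(argument_handles, options);
}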
// Executables can donate buffers so that buffers can be aliased from inputs
// to outputs. This function returns the list of parameters that must be
// donated when executable is run. tuple_inputs reflects the option that
// executable was compiled with.
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
const HloModule& hlo_module, bool tuple_inputs);
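The donation helper above can be queried directly; a sketch, assuming a module compiled without tupled inputs:

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/status_macros.h"

xla::StatusOr<int> CountDonatedParameters(const xla::HloModule& hlo_module) {
  // Parameters in this set alias outputs and must be donated at execution.
  TF_ASSIGN_OR_RETURN(
      absl::flat_hash_set<int> donated,
      xla::GetParametersThatMustBeDonated(hlo_module, /*tuple_inputs=*/false));
  return static_cast<int>(donated.size());
}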
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_


@@ -62,7 +62,7 @@ limitations under the License.
// See the comment on LocalDeviceState::AllocationModel for a discussion of the
// different allocation semantics on CPU, GPU, and TPU.
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include <cstddef>
#include <memory>


@@ -0,0 +1,782 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
#include <memory>
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
class PjRtStreamExecutorDevice : public PjRtDevice {
public:
explicit PjRtStreamExecutorDevice(
int id, std::unique_ptr<LocalDeviceState> local_device_state,
std::string device_kind, int host_id = 0)
: id_(id),
device_ordinal_(
local_device_state ? local_device_state->device_ordinal() : -1),
local_device_state_(std::move(local_device_state)),
host_id_(host_id),
device_kind_(std::move(device_kind)) {}
~PjRtStreamExecutorDevice() override {}
// Must set client exactly once.
void SetClient(PjRtClient* client) {
CHECK(client_ == nullptr);
client_ = client;
}
int host_id() const override { return host_id_; }
// Return `platform_id` from client.
PjRtPlatformId platform_id() const;
// Return `platform_name` from client.
const std::string& platform_name() const;
PjRtClient* client() const override { return client_; }
int id() const override { return id_; }
bool IsAddressable() const override { return device_ordinal_ != -1; }
int local_hardware_id() const override { return device_ordinal_; }
// If this is a device local to this host, returns a LocalDeviceState object
// that can be used to manipulate the device. Returns nullptr if the device is
// not local to this host.
LocalDeviceState* local_device_state() const {
return local_device_state_.get();
}
// If this is a device local to this host, returns a LocalDeviceState object
// that can be used to manipulate the device. Returns an error if the device
// is not local to this host.
StatusOr<LocalDeviceState*> GetLocalDeviceState() const;
const std::string& device_kind() const override { return device_kind_; }
std::string DebugString() const override;
Status TransferToInfeed(const LiteralSlice& literal) const override;
StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override;
private:
const int id_;
const int device_ordinal_; // -1 means not local.
const std::unique_ptr<LocalDeviceState> local_device_state_;
const int host_id_;
const std::string device_kind_;
PjRtClient* client_ = nullptr;
};
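A construction sketch for the device class above. Building the LocalDeviceState itself requires a StreamExecutor and streams that this diff does not show, so it is taken as a parameter here.

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"

std::unique_ptr<xla::PjRtStreamExecutorDevice> MakeDevice(
    xla::PjRtClient* client,
    std::unique_ptr<xla::LocalDeviceState> local_device_state) {
  auto device = std::make_unique<xla::PjRtStreamExecutorDevice>(
      /*id=*/0, std::move(local_device_state), /*device_kind=*/"cpu");
  device->SetClient(client);  // Must be called exactly once.
  return device;
}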
class PjRtStreamExecutorClient : public PjRtClient {
public:
// `allocator` may be null, in which case the platform default allocator is used.
explicit PjRtStreamExecutorClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
~PjRtStreamExecutorClient() override = default;
int host_id() const override { return host_id_; }
int device_count() const override { return devices_.size(); }
int addressable_device_count() const override {
return local_devices_.size();
}
absl::Span<PjRtDevice* const> devices() const override { return devices_; }
absl::Span<PjRtDevice* const> local_devices() const override {
return local_devices_;
}
StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
auto it = id_to_device_.find(device_id);
if (it != id_to_device_.end()) {
return it->second;
}
return InvalidArgument("No matching device found for device_id %d",
device_id);
}
StatusOr<PjRtDevice*> LookupAddressableDevice(
int local_hardware_id) const override;
PjRtPlatformId platform_id() const override { return platform_id_; }
const std::string& platform_name() const override { return platform_name_; }
// Most platforms expect device-to-device transfers to be enqueued on the
// source d2d stream, but some platforms use the destination d2d stream. This
// function specifies which one the platform expects.
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
const XlaComputation& computation, CompileOptions options) override;
StatusOr<absl::optional<std::string>> ExecutableFingerprint(
const PjRtExecutable& executable) const override {
return absl::optional<std::string>();
}
std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() override;
// Creates a buffer on the device without initializing or copying any data.
// An optional `definition_event` may be specified that can be used to
// ensure the buffer isn't referenced until some external mechanism has
// initialized the data.
StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device) override;
StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device,
std::shared_ptr<BufferSequencingEvent> definition_event);
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device) override;
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device) override;
void MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier) override;
StatusOr<ChannelHandle> CreateChannelHandle() override {
return client()->CreateChannelHandle();
}
StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
return client()->CreateDeviceToHostChannelHandle();
}
StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
return client()->CreateHostToDeviceChannelHandle();
}
LocalDeviceState& device_state(int device_ordinal) const {
return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
local_devices_.at(device_ordinal))
->local_device_state();
}
LocalClient* client() const { return client_; }
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
tensorflow::Allocator* host_memory_allocator() const {
return host_memory_allocator_.get();
}
bool should_stage_host_to_device_transfers() const {
return should_stage_host_to_device_transfers_;
}
gpu::GpuExecutableRunOptions* gpu_run_options() const {
return gpu_run_options_.get();
}
tensorflow::thread::ThreadPool* h2d_transfer_pool() {
return &h2d_transfer_pool_;
}
protected:
friend class PjRtStreamExecutorBuffer;
virtual void EnqueueCrossHostReceive(
std::vector<std::unique_ptr<PjRtBuffer>>&& buffers,
std::shared_ptr<BufferSequencingEvent> definition_event,
PjRtCrossHostRecvNotifier&& notifier) const {
notifier(Unimplemented("Cross host receives not implemented."));
}
virtual Status CopyToRemoteDevice(
PjRtBuffer* buffer, absl::string_view serialized_descriptor) const {
return Unimplemented("Cross host sends not implemented.");
}
const PjRtPlatformId platform_id_;
const std::string platform_name_;
LocalClient* client_;
// Allocator to be used for staging memory transfers to devices.
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
// Includes all devices, including non-local devices on multi-host platforms.
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
// Pointers to `owned_devices_`.
std::vector<PjRtDevice*> devices_;
// Maps Device::id() to the corresponding Device. Includes all devices.
std::map<int, PjRtDevice*> id_to_device_;
// Local devices indexed by local device ordinal.
std::vector<PjRtDevice*> local_devices_;
int host_id_;
se::DeviceMemoryAllocator* allocator_;
std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
// Should we always prefer to stage host-to-device transfers via memory
// allocated on host_memory_allocator_? True only on GPU, where we prefer to
// transfer via pinned memory.
bool should_stage_host_to_device_transfers_;
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;
tensorflow::thread::ThreadPool h2d_transfer_pool_;
};
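The virtual methods above are the intended extension points for platform plugins. As a hedged illustration, a hypothetical platform whose copy engine expects device-to-device transfers on the destination stream could override just that choice:

#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"

class MyPlatformClient : public xla::PjRtStreamExecutorClient {
 public:
  // Inherit the StreamExecutor constructor; only the D2D stream choice
  // differs on this hypothetical platform.
  using PjRtStreamExecutorClient::PjRtStreamExecutorClient;
  bool EnqueueD2DTransfersOnSrcStream() const override { return false; }
};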
// Converts a 2D set of Device objects indexed by [replica][partition] into an
// xla::DeviceAssignment.
StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
absl::Span<const std::vector<PjRtDevice*>> devices);
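A sketch of the conversion declared above: the outer index is the replica and the inner index the partition, so two single-partition devices form a 2x1 assignment.

#include <vector>

#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"

xla::StatusOr<xla::DeviceAssignment> AssignTwoReplicas(
    xla::PjRtDevice* replica0, xla::PjRtDevice* replica1) {
  // devices[replica][partition], matching the comment above.
  std::vector<std::vector<xla::PjRtDevice*>> devices = {{replica0},
                                                        {replica1}};
  return xla::DevicesToDeviceAssignment(devices);
}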
class PjRtStreamExecutorBuffer : public PjRtBuffer {
public:
// Helper class to retain a "hold" on a PjRtStreamExecutorBuffer. A ScopedHold
// may not outlive its parent PjRtStreamExecutorBuffer.
//
// There are three types of hold, as follows:
//
// 1) Usage hold: a transient hold while an operation using the buffer is
// being enqueued onto a stream.
// A client acquires a usage hold by calling
// PjRtStreamExecutorBuffer::GetBufferWithHold(kUsage) or the convenience
// wrapper GetBufferWithUsageHold(). If the enqueue completes successfully the
// hold should be released using a call to ConvertUsageHold. If the ScopedHold
// is deleted without ConvertUsageHold being called, e.g., on error, the hold
// is dropped. It is legal to drop a usage hold instead of calling
// ConvertUsageHold, even if the buffer was successfully enqueued, as long as
// the client ensures that all necessary synchronization has been done.
//
// 2) External hold: a potentially long-lived hold while the buffer is being
// shared by an external framework, e.g., NumPy.
// A client acquires an external hold by calling
// PjRtStreamExecutorBuffer::GetBufferWithHold(kExternal) or the convenience
// wrapper GetBufferWithExternalReference and releases it by deleting the
// ScopedHold. The external framework should not modify the underlying buffer
// unless it is confident via its own synchronization that modifications do
// not race with reads from the PjRtStreamExecutorBuffer.
//
// 3) Donation hold: a transient hold while an execution that donates the
// buffer is being enqueued onto the compute stream.
// A client acquires a donation hold by calling
// PjRtStreamExecutorBuffer::GetBufferWithHold(kDonation). If the enqueue
// completes successfully the hold should be released using a call to
// ConfirmDonation after which the buffer is invalid. If the ScopedHold is
// deleted without ConfirmDonation being called, e.g., on error, the hold is
// dropped and the buffer remains valid. If the buffer is successfully
// enqueued the client *must* call ConfirmDonation.
//
// Donation holds behave like exclusive write locks: when a donation hold
// has been acquired, any attempt to acquire another hold of any type will
// block until the donation hold is dropped or confirmed. Acquiring a donation
// hold will fail with an error if there is any outstanding external hold, and
// will block if there are any outstanding usage holds until those holds are
// dropped or converted.
//
// Calls to PjRtStreamExecutorBuffer::Release (and transitively to
// PjRtStreamExecutorBuffer::Delete() and ~PjRtStreamExecutorBuffer()) will
// block until all usage and donation holds are either deleted or
// converted/confirmed.
class ScopedHold {
public:
enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue };
// Use a State enum instead of encoding the state in an error Status to
// avoid creating Status values in non-error cases. Creating a Status
// entails several allocations and can add O(us) to every use of a hold.
enum State {
kUninitialized = 0,
kValid,
kMoved,
kConverted,
kReleased,
kDonated,
kError
};
~ScopedHold();
ScopedHold(ScopedHold&& other);
ScopedHold(const ScopedHold&) = delete;
ScopedHold& operator=(const ScopedHold&) = delete;
Type type() const { return type_; }
Status status() const {
// Lazily create Status values only when they are requested.
switch (state_) {
case kUninitialized:
return InvalidArgument("Buffer has not been initialized");
case kValid:
return Status::OK();
case kMoved:
return InvalidArgument("Buffer has been moved.");
case kConverted:
return InvalidArgument("Buffer has been converted");
case kReleased:
return InvalidArgument("Buffer has been released");
case kDonated:
return InvalidArgument("Buffer has been donated");
case kError:
return buffer_or_.status();
default:
CHECK(false) << "Unexpected state value " << state_;
}
}
bool ok() const { return state_ == kValid; }
// Access to the underlying device buffer storage. Requires this->ok().
const std::shared_ptr<TrackedDeviceBuffer>& buffer() const {
CHECK_EQ(state_, kValid);
CHECK_NE(buffer_or_.ValueOrDie(), nullptr);
return buffer_or_.ValueOrDie();
}
TrackedDeviceBuffer* operator->() const { return buffer().get(); }
const TrackedDeviceBuffer& operator*() const { return *buffer(); }
// Converts the hold into a usage event. Only valid for holds of type
// kUsage.
//
// usage_stream: the stream that the buffer was used on.
// event: an event that has been recorded on usage_stream after
// the buffer was used.
// reference_held: true if and only if the caller has caused a
// reference to this->buffer() to stay live until after
// the host is sure that the usage (transfer or execution)
// has completed.
void ConvertUsageHold(se::Stream* usage_stream,
std::shared_ptr<BufferSequencingEvent> event,
bool reference_held);
// Confirms that the buffer was successfully donated to an execution.
// Only valid for holds of type kDonation. Causes the buffer to become
// invalid.
void ConfirmDonation();
// Adds the held device buffers, in order, to 'iterator'. Used to add the
// buffers to an ExecutionInput. We require but do not verify that
// 'iterator' when passed in is pointing to a sub-tuple of the
// ExecutionInput whose on_device_shape matches that of the
// TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run
// out of bounds. Donates the device buffers if the hold type is kDonation,
// otherwise retains ownership of the device buffers.
void AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
ExecutionInput* execution_input,
se::DeviceMemoryAllocator* allocator) const;
private:
friend class PjRtStreamExecutorBuffer;
friend class PjRtStreamExecutorClient;
// Helper struct that makes it possible to move a ScopedHold through a
// closure.
using ForClosure =
std::tuple<PjRtStreamExecutorBuffer*, Type, State,
StatusOr<std::shared_ptr<TrackedDeviceBuffer>>>;
ScopedHold(PjRtStreamExecutorBuffer* parent, Type type)
: parent_(parent), type_(type), state_(kUninitialized) {}
explicit ScopedHold(const ForClosure& closure_helper)
: parent_(std::get<0>(closure_helper)),
type_(std::get<1>(closure_helper)),
state_(std::get<2>(closure_helper)),
buffer_or_(std::get<3>(closure_helper)) {
// Check the buffer is not in an error state.
CHECK(buffer_or_.ValueOrDie() != nullptr);
}
// Sets buffer state.
void SetState(State state) { state_ = state; }
// Sets buffer_or_. Called by parent_ to initialize the hold.
void Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or);
// Releases the contents of *this, so *this can subsequently be
// deleted without releasing the parent's hold. Should be passed to the
// appropriate constructor of another ScopedHold, e.g., when a hold must be
// passed through a closure that is incompatible with std::move.
ForClosure ToClosure();
PjRtStreamExecutorBuffer* const parent_;
const Type type_;
// There is an invariant that if ok() then
// buffer_or_.ValueOrDie() != nullptr.
State state_;
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
};
PjRtStreamExecutorBuffer(Shape on_host_shape, Shape on_device_shape,
std::shared_ptr<TrackedDeviceBuffer> device_buffer,
PjRtClient* client, PjRtDevice* device);
~PjRtStreamExecutorBuffer() override;
PjRtStreamExecutorBuffer(const PjRtStreamExecutorBuffer&) = delete;
PjRtStreamExecutorBuffer(PjRtStreamExecutorBuffer&&) = delete;
PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;
const Shape& on_host_shape() const override { return on_host_shape_; }
const Shape& on_device_shape() const override { return on_device_shape_; }
PjRtStreamExecutorDevice* device() const override { return device_; }
PjRtPlatformId platform_id() const { return client_->platform_id(); }
const std::string& platform_name() const { return client_->platform_name(); }
PjRtStreamExecutorClient* client() const override { return client_; }
bool IsEmptyTuple() const {
return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0;
}
int64 OnDeviceSizeInBytes() const override;
  // Implements PjRtBuffer::ExternalReferenceHold as a wrapped
  // ScopedHold::kExternalReference.
class ScopedHoldAsExternalReference
: public PjRtBuffer::ExternalReferenceHold {
public:
    explicit ScopedHoldAsExternalReference(ScopedHold hold)
        : external_reference_(std::move(hold)) {
      CHECK(external_reference_.type() == ScopedHold::kExternalReference);
    }
~ScopedHoldAsExternalReference() override = default;
void* OpaqueDeviceMemoryDataPointer() const override {
return external_reference_->device_memory().front().opaque();
}
private:
ScopedHold external_reference_;
};
StatusOr<std::unique_ptr<ExternalReferenceHold>> AcquireExternalReference()
override;
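  // Illustrative sketch (hypothetical caller code): pinning the device
  // memory for an external framework; the memory stays live until `ref`
  // is destroyed:
  //
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<ExternalReferenceHold> ref,
  //                       buffer->AcquireExternalReference());
  //   void* data = ref->OpaqueDeviceMemoryDataPointer();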
StatusOr<absl::optional<std::shared_ptr<void>>> ReleaseDeviceMemoryOwnership(
bool wait_for_operations_to_complete) override;
using PjRtBuffer::ToLiteral;
StatusOr<std::shared_ptr<Literal>> ToLiteral(
bool discard_cached_copy, absl::optional<xla::Layout> layout) override;
using PjRtBuffer::CopyToHostAsync;
Status CopyToHostAsync(absl::optional<xla::Layout> layout) override;
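  // Illustrative sketch (hypothetical caller code): starting the
  // device-to-host transfer early so that the later ToLiteral call hits
  // the cache instead of issuing a new transfer:
  //
  //   TF_RETURN_IF_ERROR(buffer->CopyToHostAsync(absl::nullopt));
  //   // ... other host-side work ...
  //   TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal,
  //                       buffer->ToLiteral(/*discard_cached_copy=*/false,
  //                                         absl::nullopt));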
// Drops the buffer's reference to its associated device memory, leaving the
// buffer in an invalid state. The memory will be freed lazily when all async
// operations using the buffer have completed, according to the allocation
// semantics of the underlying platform. Delete may briefly block if another
// thread is in the process of enqueuing an operation on this buffer, but it
// will never block for a stream operation to complete. If an external
// framework holds a reference to the TrackedDeviceBuffer via
// GetBufferWithExternalReference, the memory will not be freed until the
// external framework drops the reference.
void Delete() override;
bool IsDeleted() override;
// Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The
// PjRtBuffer retains ownership of the device buffers.
StatusOr<ShapedBuffer> AsShapedBuffer() const;
// Returns a hold on the TrackedDeviceBuffer holding the device
// buffers. See comment on ScopedHold.
ScopedHold GetBufferWithHold(ScopedHold::Type type);
ScopedHold GetBufferWithUsageHold() {
return GetBufferWithHold(ScopedHold::kUsage);
}
ScopedHold GetBufferWithExternalReference() {
return GetBufferWithHold(ScopedHold::kExternalReference);
}
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
PjRtDevice* dst_device) override;
Status CopyToRemoteDevice(absl::string_view serialized_descriptor) override;
Status BlockHostUntilReady() override;
bool IsOnCpu() const override;
// Similar to Delete, drops the buffer's reference to its associated device
// memory, leaving the buffer in an invalid state, but returns the
// TrackedDeviceBuffer rather than freeing the device memory, so that another
// framework can take ownership of it. The buffer returned from Release may
// be safely dropped at any time even if it still has pending async
// operations. The client should call BlockHostUntilReady before calling
// Release with wait_for_operations_to_complete=false, to ensure that the host
// has synchronized past any outstanding write operations to the buffer. If
// wait_for_operations_to_complete=true the host will block until any
// potentially outstanding asynchronous operations have completed before
// returning, in which case it is safe to read or mutate the returned buffer.
// If the buffer was shared via an external reference it is the client's
// responsibility that accesses via that reference do not interfere with
// accesses via the buffer returned from Release.
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> Release(
bool wait_for_operations_to_complete);
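  // Illustrative sketch (hypothetical caller code): handing the device
  // memory to another framework without blocking on in-flight operations,
  // per the contract described above:
  //
  //   TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
  //   TF_ASSIGN_OR_RETURN(
  //       std::shared_ptr<TrackedDeviceBuffer> tracked,
  //       buffer->Release(/*wait_for_operations_to_complete=*/false));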
private:
friend class PjRtClient;
  // The cached value of the buffer on the host, produced either from a call
  // to CopyToHostAsync or from a call to ToLiteral. Once a value has been
  // fetched to the host, it persists until Delete() is called or the
  // PjRtBuffer is destroyed.
struct HostValue {
absl::Notification ready;
// status and value are valid for reading only after `ready` has been
// notified.
Status status;
std::shared_ptr<Literal> value;
};
// Blocks in mu_.Await until there are no more usage holds.
void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Blocks in mu_.Await until there is no donation hold.
void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Adds a hold of 'type' and returns device_buffer_. Returns an error if
// device_buffer_ is null, or if a donation hold was requested when there is
// an outstanding external hold.
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> GetBufferForHoldLocked(
ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Adds a hold of hold->type() and initializes `hold` with device_buffer_.
// Initializes hold with an error if device_buffer_ is null, or if a donation
// hold was requested when there is an outstanding external hold.
void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a sanity
// check that buffer==device_buffer_ or device_buffer_==nullptr. Called after
// device_buffer_ was successfully enqueued on a stream.
void ConvertUsageHold(TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
std::shared_ptr<BufferSequencingEvent> event,
bool reference_held);
// Drops a donation hold and makes *this invalid for further use. Does a
// sanity check that buffer==device_buffer_. Called after device_buffer_ was
// successfully donated to an execution.
void ConfirmDonation(TrackedDeviceBuffer* device_buffer);
  // Initiates a copy of the buffer to the host. Does not block waiting for
  // the transfer to complete. A host value is returned and, if
  // `discard_cached_copy` is false, stored in an internal cache so that
  // future calls don't have to transfer the data from the device again. If a
  // layout is passed then a literal with this layout will be returned and
  // possibly cached.
StatusOr<std::shared_ptr<HostValue>> CopyToHostAsyncInternal(
bool discard_cached_copy, absl::optional<xla::Layout> layout);
// Drops a hold without taking any other action. Does a sanity check that
// buffer==device_buffer_ or device_buffer_==nullptr.
void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);
StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
std::shared_ptr<BufferSequencingEvent>>>
CopyToDeviceHelper(PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
LocalDeviceState* transfer_local_device,
se::Stream* transfer_stream,
std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);
PjRtStreamExecutorClient* const client_;
const Shape on_host_shape_;
const Shape on_device_shape_;
PjRtStreamExecutorDevice* const device_;
mutable absl::Mutex mu_;
std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
absl::flat_hash_map<xla::Layout, std::shared_ptr<HostValue>> host_values_
TF_GUARDED_BY(mu_);
std::shared_ptr<HostValue> host_value_ TF_GUARDED_BY(mu_);
  // Count of outstanding holds on the buffer, indexed by hold type.
std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
// Semaphore used to ensure there is only one outstanding donation hold.
Semaphore donation_semaphore_;
};
// Wraps one or more XLA LocalExecutables (one per partition, as specified by
// the build options).
class PjRtStreamExecutorExecutable : public PjRtExecutable {
public:
PjRtStreamExecutorExecutable(
std::vector<std::unique_ptr<LocalExecutable>> executables,
bool parameter_is_tupled_arguments,
std::shared_ptr<DeviceAssignment> device_assignment,
std::vector<LogicalDeviceIds> addressable_device_logical_ids,
std::vector<PjRtDevice*> addressable_devices,
PjRtStreamExecutorClient* client);
~PjRtStreamExecutorExecutable() override = default;
PjRtStreamExecutorClient* client() const override { return client_; }
const std::string& name() const override;
int num_replicas() const override {
return executables_[0]->build_options().num_replicas();
}
int num_partitions() const override {
return executables_[0]->build_options().num_partitions();
}
int64 SizeOfGeneratedCodeInBytes() const override {
int64 size = 0;
for (auto& executable : executables_) {
size += executable->executable()->SizeOfGeneratedCodeInBytes();
}
return size;
}
const DeviceAssignment& device_assignment() const override {
return *device_assignment_;
}
absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
const override {
return addressable_device_logical_ids_;
}
absl::Span<PjRtDevice* const> addressable_devices() const override {
return addressable_devices_;
}
  // Returns an HloModule per partition.
StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
const override;
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options) const override;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) const override;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) const override;
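  // Illustrative sketch (hypothetical caller code): a replicated launch
  // with default options, one argument vector per addressable device:
  //
  //   std::vector<std::vector<PjRtBuffer*>> args = /* per-device inputs */;
  //   TF_ASSIGN_OR_RETURN(
  //       std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> outputs,
  //       executable->Execute(args, ExecuteOptions()));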
void Delete() override { executables_.clear(); }
absl::Span<const std::shared_ptr<LocalExecutable>> executables() const {
return executables_;
}
protected:
bool parameter_is_tupled_arguments() const {
return parameter_is_tupled_arguments_;
}
private:
friend class PjRtStreamExecutorClient;
  // Initializes, for each executable, the set of parameters that must be
  // donated due to aliases specified by the computation.
Status SetUpDonation(bool tuple_inputs);
virtual bool MustDonateParameter(int executable_idx, int parameter) const;
virtual StatusOr<std::vector<ExecutionInput>>
MakeExecutionInputsAndWaitForEvents(
int device_ordinal, const ExecuteOptions& options,
absl::Span<PjRtBuffer* const> argument_handles,
absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
absl::flat_hash_set<BufferSequencingEvent*>& events) const;
StatusOr<ScopedShapedBuffer> EnqueueExecution(
absl::Span<PjRtBuffer* const> argument_handles, int replica,
int partition, int executable_idx, const RunId& run_id,
const ExecuteOptions& options, PjRtDevice* device,
std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
std::shared_ptr<DeviceAssignment> device_assignment) const;
virtual std::vector<std::unique_ptr<PjRtBuffer>> MakeOutputBuffers(
int device_ordinal, const ExecuteOptions& options,
ScopedShapedBuffer result_buffer,
std::shared_ptr<BufferSequencingEvent> definition_event,
PjRtDevice* device) const;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
absl::Span<PjRtBuffer* const> argument_handles, int replica,
int partition, const RunId& run_id, const ExecuteOptions& options,
PjRtDevice* device = nullptr) const;
  // Create shared pointers so we can free them after the execution: with
  // asynchronous execution, the computation being executed can outlive the
  // executable object itself.
PjRtStreamExecutorClient* const client_;
// One executable per partition.
std::vector<std::shared_ptr<LocalExecutable>> executables_;
// Per-executable set of parameters that have any aliased buffers and thus
// must be donated when executing the computation.
std::vector<absl::flat_hash_set<int>> parameters_that_must_be_donated_;
std::shared_ptr<DeviceAssignment> device_assignment_;
// True if the executables were compiled expecting arguments in a single
// tuple.
const bool parameter_is_tupled_arguments_;
  // The replica and partition indices of device_assignment_ to be run by this
  // client. On single-host platforms without partitioning, this is all
  // replicas (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may
  // not be the case on multi-host platforms. For example, with 4 replicas and
  // 2 partitions on a single-host platform, the size of
  // addressable_device_logical_ids_ is 4*2 = 8.
std::vector<LogicalDeviceIds> addressable_device_logical_ids_;
  // addressable_devices_[i] is the Device to which
  // addressable_device_logical_ids_[i] is assigned.
std::vector<PjRtDevice*> addressable_devices_;
};
// Executables can donate buffers so that buffers can be aliased from inputs
// to outputs. This function returns the set of parameters that must be
// donated when the executable is run. tuple_inputs reflects the option that
// the executable was compiled with.
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
const HloModule& hlo_module, bool tuple_inputs);
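// Illustrative sketch (hypothetical caller code), assuming `hlo_module` is
// the compiled computation's HloModule:
//
//   TF_ASSIGN_OR_RETURN(
//       absl::flat_hash_set<int> donated_params,
//       GetParametersThatMustBeDonated(hlo_module, /*tuple_inputs=*/false));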
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape.h"

View File

@ -20,7 +20,7 @@ limitations under the License.
#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/stream_executor/tpu/tpu_topology.h"

View File

@ -211,6 +211,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/compiler/xla/pjrt:pjrt_stream_executor_client", # TODO(zhangqiaorjc): Remove after adding a factory method for PjRtBuffer.
"//tensorflow/compiler/xla/pjrt:tracked_device_buffer",
"//tensorflow/stream_executor:device_memory",
"//tensorflow/stream_executor:platform",

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "include/dlpack/dlpack.h" // from @dlpack
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/traceback.h"

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <sys/types.h>
#include <memory>
#include <queue>
#include <sstream>
#include "absl/container/flat_hash_map.h"

View File

@ -26,7 +26,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/compiler/xla/pjrt:pjrt_stream_executor_client",
"//tensorflow/compiler/xla/pjrt:semaphore",
"//tensorflow/compiler/xla/python/tpu_driver",
"//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver",

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"