Qiao Zhang 2d0d477194 Mark Execute methods non-const allow more flexibility, e.g., lazily load programs on first execution.
PiperOrigin-RevId: 347870420
Change-Id: Ie4c21280c14dd6201d24d42b9f1fd967a9617307
2020-12-16 12:13:08 -08:00

464 lines
21 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
// API notes:
// PjRt stands for "Pretty much Just another RunTime".
namespace xla {
using PjRtPlatformId = uint64;
constexpr char kCpuName[] = "cpu";
constexpr char kGpuName[] = "gpu";
constexpr char kTpuName[] = "tpu";
static const PjRtPlatformId kCpuId = tensorflow::Fingerprint64(kCpuName);
static const PjRtPlatformId kGpuId = tensorflow::Fingerprint64(kGpuName);
static const PjRtPlatformId kTpuId = tensorflow::Fingerprint64(kTpuName);
class PjRtClient;
class PjRtDevice {
public:
virtual ~PjRtDevice() {}
// Return the client that owns this device.
virtual PjRtClient* client() const = 0;
// Whether client can issue command to this device.
virtual bool IsAddressable() const = 0;
// The ID of this device. IDs are unique among devices of this type
// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
// hosts' devices. This is the ID that should be used in a DeviceAssignment.
virtual int id() const = 0;
// The task ID of this device according to TpuTopology. This is not the same
// as PjRtClient::host_id() in a multi-task setting, where each client can see
// devices from all tasks, but only a subset of them are addressable and have
// the same task_id as the client.
virtual int host_id() const = 0;
// Opaque hardware ID, e.g., the CUDA device number, useful for identifying
// which GPU when interacting with non-JAX code. In general, not guaranteed to
// be dense, and -1 if undefined.
virtual int local_hardware_id() const = 0;
// A vendor-dependent string that uniquely identifies the kind of device,
// e.g., "Tesla V100-SXM2-16GB". May be used to determine whether two GPUs are
// compatible compilation.
virtual const std::string& device_kind() const = 0;
virtual std::string DebugString() const = 0;
// Transfer the given literal to the infeed queue.
virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0;
// Transfer and return a value of the given shape from the outfeed queue.
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const = 0;
};
// Forward declaration.
class PjRtBuffer;
// Helper struct for cross host transfers, returned by the callback from a call
// to PjRtBuffer::MakeCrossHostReceiveBuffers.
struct PjRtCrossHostRecvBuffer {
// serialized_descriptor should be transmitted to the sender and passed to a
// call to src_buffer->CopyToRemoteDevice.
std::string serialized_descriptor;
// The buffer that will hold the result of the transfer.
std::unique_ptr<PjRtBuffer> buffer;
};
using PjRtCrossHostRecvNotifier =
std::function<void(StatusOr<std::vector<PjRtCrossHostRecvBuffer>>&&)>;
struct CompileOptions {
// The layouts of the arguments that the computation should expect.
absl::optional<std::vector<Shape>> argument_layouts;
// If true, the supplied computation expects its arguments to be wrapped in a
// tuple and passed as a single parameter.
bool parameter_is_tupled_arguments = false;
// XLA's compilation time options.
ExecutableBuildOptions executable_build_options;
// If true, the executable can be run on any device. May only be true if
// !executable_build_options.has_device_assignment(), so only applies to
// single-device executables. Beware: on GPUs, sometimes an executable
// compiled for one device doesn't run on another.
bool compile_portable_executable = false;
};
class PjRtExecutable;
// Encapsulates the state of Python session with XLA.
//
// It is the responsibility of the client of this API to keep the PjRtClient
// alive as long as any of the other runtime objects are alive.
class PjRtClient {
public:
virtual ~PjRtClient() = default;
// TODO(zhangqiaorjc): Rename to task_id.
// Return the task id of this client. In single-task setting, always 0.
virtual int host_id() const = 0;
// Return the number of devices in the entire computation. In multi-headed
// client setting, some are addressable by this client, some are not. In a
// single-client setting, this is equal to the number of addressable devices.
virtual int device_count() const = 0;
// Return number of addressable devices. Addressable devices are those that
// the client can issue commands to.
virtual int addressable_device_count() const = 0;
// Return all devices in the entire computation, including addressable and
// non-addressable devices.
virtual absl::Span<PjRtDevice* const> devices() const = 0;
// TODO(zhangqiaorjc): Rename to addressable_devices.
// Return only addressable devices.
virtual absl::Span<PjRtDevice* const> local_devices() const = 0;
// Lookup any PjRtDevice for a given PjRtDevice::id().
virtual StatusOr<PjRtDevice*> LookupDevice(int device_id) const = 0;
// Return an addressable PjRtDevice for a given
// PjRtDevice::local_hardware_id().
virtual StatusOr<PjRtDevice*> LookupAddressableDevice(
int local_hardware_id) const = 0;
// Return an ID that identifies the platform (CPU/GPU/TPU).
virtual PjRtPlatformId platform_id() const = 0;
// Returns a string that identifies the platform (CPU/GPU/TPU).
virtual const std::string& platform_name() const = 0;
// Return a device-specific default device assignment, e.g., GPU and TPU may
// be different.
virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const = 0;
// Returns a backend-specific HLO cost analysis visitor.
virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() = 0;
// Compile `computation` with given `options`.
virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
const XlaComputation& computation, CompileOptions options) = 0;
// Generates a unique fingerprint for `executable`, may be absl::nullopt.
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
const PjRtExecutable& executable) const = 0;
// Creates a buffer on the device without initializing or copying any data.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device) = 0;
// Describes the semantics the caller to BufferFromHostBuffer expects from the
// runtime, in a total order from most restrictive to least restrictive.
enum class HostBufferSemantics {
// The runtime may not hold references to `data` after the call to
// `BufferFromHostBuffer` completes. The caller promises that `data` is
// immutable and will not be freed only for the duration of the
// BufferFromHostBuffer call. `buffer_reference` will be freed by the time
// `BufferFromHostBuffer` returns.
kImmutableOnlyDuringCall,
// The runtime may hold onto `data` after the call to `BufferFromHostBuffer`
// returns while the runtime completes a transfer to the device. The caller
// promises not to mutate or free `data` until the transfer completes, at
// which point the runtime will release `buffer_reference`. It is also
// correct to wait on the host (directly or indirectly) for the buffer's
// definition event to complete.
kImmutableUntilTransferCompletes,
// The PjRtBuffer may alias `data` internally and the runtime may use the
// `data` contents as long as the buffer is alive. The caller promises to
// keep `data` alive and not to mutate its contents as long as the buffer is
// alive; to notify the caller that the buffer may be freed, the runtime
// will release its `buffer_reference` when the PjRtBuffer is freed. On
// non-CPU platforms this acts identically to
// kImmutableUntilTransferCompletes.
kZeroCopy,
};
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device) = 0;
// Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait for BlockHostUntilReady() completes on
// the return value before letting literal go out of scope.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device) = 0;
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
// cross host transfers using `client` on `device'. `shapes` must be the exact
// shapes, with identical layouts, corresponding to the buffers that will be
// sent. When resources for the transfer are available, notifier will be
// called with a vector of PjRtCrossHostRecvBuffer structs, one for each
// shape in `shapes`. Each struct contains a buffer that will contain the
// received value, and an opaque string that should be transmitted to the
// sending host and used in a call to CopyToRemoteDevice. None of the recv
// buffers will become ready until *all* of the sends have completed.
virtual void MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier) = 0;
// Create ChannelHandles for XLA send/recv.
virtual StatusOr<ChannelHandle> CreateChannelHandle() = 0;
virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() = 0;
virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
};
// Holds a reference from Python to a tuple of device buffers. A PjRtBuffer
// can be either valid or invalid. An invalid buffer is one that has never been
// initialized, or a buffer that has been deleted (e.g., by calling Delete, or
// by donating it to a computation that aliases an input parameter to an
// output). We allow PjRtBuffer objects to outlive the underlying device
// buffers so we can decouple buffer lifetimes from the corresponding Python
// references if needed. Thread-safe.
class PjRtBuffer {
public:
virtual ~PjRtBuffer() = default;
virtual const Shape& on_host_shape() const = 0;
virtual const Shape& on_device_shape() const = 0;
virtual PjRtDevice* device() const = 0;
virtual PjRtClient* client() const = 0;
// Returns the size of the on-device representation of this buffer in bytes.
virtual int64 OnDeviceSizeInBytes() const = 0;
// ExternalReferenceHold is a potentially long-lived hold while the buffer is
// being shared by an external framework, e.g., NumPy. A client acquires an
// external hold by calling PjRtBuffer::AcquireExternalReference() and
// releases it by deleting the ExternalReferenceHold. The external framework
// should not modify the underlying buffer unless it is confident via its own
// synchronization that modifications do not race with reads from the
// PjRtBuffer.
struct ExternalReferenceHold {
virtual ~ExternalReferenceHold() = default;
// Return opaque device memory pointer to root buffer.
virtual void* OpaqueDeviceMemoryDataPointer() const = 0;
};
virtual StatusOr<std::unique_ptr<ExternalReferenceHold>>
AcquireExternalReference() = 0;
// Returns the buffer's value as an XLA Literal. If the value has previously
// been prefetched to the host, then returns the prefetched version, otherwise
// copies the buffer to the host. Blocks until the value is ready. If
// `discard_cached_copy` is true then buffer will no longer keep hold of a
// cached copy of the literal (i.e. The reference to the host value will be
// removed.) If a layout is passed than a literal with this layout will be
// returned.
StatusOr<std::shared_ptr<Literal>> ToLiteral() {
return ToLiteral(/*discard_cached_copy=*/false, /*layout=*/{});
}
StatusOr<std::shared_ptr<Literal>> ToLiteral(bool discard_cached_copy) {
return ToLiteral(discard_cached_copy, /*layout=*/{});
}
virtual StatusOr<std::shared_ptr<Literal>> ToLiteral(
bool discard_cached_copy, absl::optional<xla::Layout> layout) = 0;
// Initiates a copy of the buffer to the host. Does not block waiting for
// the transfer to complete. The value can be retrieved by a later call to
// ToLiteral(). If a layout is passed then a cached copy with this layout will
// be created.
Status CopyToHostAsync() { return CopyToHostAsync(/*layout=*/{}); }
virtual Status CopyToHostAsync(absl::optional<xla::Layout> layout) = 0;
// Drops the buffer's reference to its associated device memory, leaving the
// buffer in an invalid state. The memory will be freed lazily when all async
// operations using the buffer have completed, according to the allocation
// semantics of the underlying platform. Delete may briefly block if another
// thread is in the process of enqueuing an operation on this buffer, but it
// will never block for a stream operation to complete. If an external
// framework holds a reference to the TrackedDeviceBuffer via
// GetBufferWithExternalReference, the memory will not be freed until the
// external framework drops the reference.
virtual void Delete() = 0;
// Similar to Delete, drops the buffer's reference to its associated device
// memory, leaving the buffer in an invalid state, but transfers the device
// memory ownership out via absl::optional<std::shared_ptr<void>> rather than
// freeing the device memory, so that another framework can take ownership of
// it. A return value of absl::nullopt indicates that PjRtBuffer has been
// deleted. The buffer returned from Release may be safely dropped at any time
// even if it still has pending async operations. The client should call
// BlockHostUntilReady before calling ReleaseDeviceMemoryOwnership with
// wait_for_operations_to_complete=false, to ensure that the host has
// synchronized past any outstanding write operations to the buffer. If
// wait_for_operations_to_complete=true the host will block until any
// potentially outstanding asynchronous operations have completed before
// returning, in which case it is safe to read or mutate the returned buffer.
// If the buffer was shared via an external reference it is the client's
// responsibility that accesses via that reference do not interfere with
// accesses via the buffer returned from ReleaseDeviceMemoryOwnership.
virtual StatusOr<absl::optional<std::shared_ptr<void>>>
ReleaseDeviceMemoryOwnership(bool wait_for_operations_to_complete) = 0;
// True if and only if Delete or Release has previously been called.
virtual bool IsDeleted() = 0;
// Copies the buffer to device `dst_device`, performing a d2d transfer when
// `dst_device` is sharing the same Client, and performing a d2h and h2d copy
// if `dst_device` lives on a different Client.
// Returns an error if the buffer is already on dst_device.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
PjRtDevice* dst_device) = 0;
// Copies the buffer to the remote device encoded in serialized_descriptor.
// This call must be preceded by a call to MakeCrossHostReceiveBuffers on the
// remote host's destination device. MakeCrossHostReceiveBuffers takes an
// array of shapes to construct the destination buffers, and a callback
// supplies an array containing both the destination buffers, and a serialized
// descriptor for each buffer. For each destination buffer there should be a
// matching call to src->CopyToRemoteDevice on a remote host for a src buffer
// of the corresponding shape. serialized_descriptor is the string returned by
// the callback along with the corresponding destination buffer.
virtual Status CopyToRemoteDevice(
absl::string_view serialized_descriptor) = 0;
// Blocks the host until the buffer's value has been computed and is ready for
// immediate use on the device. Useful in particular for timing benchmarks.
virtual Status BlockHostUntilReady() = 0;
// Whether this buffer is on CPU and thus allows for certain optimizations.
virtual bool IsOnCpu() const = 0;
};
class ExecuteContext {
public:
virtual ~ExecuteContext() = default;
};
struct ExecuteOptions {
// If true, the client must pass a single PjRtBuffer which contains all of
// the arguments as a single XLA tuple, otherwise each argument must be
// passed in its own PjRtBuffer. May only be true if the executable was
// compiled with parameter_is_tupled_arguments==true.
bool arguments_are_tupled = false;
// If true, the computation must return a tuple, which will be destructured
// into its elements.
bool untuple_result = false;
// If non-zero, identifies this execution as part of a potentially
// multi-device launch. This can be used to detect scheduling errors, e.g. if
// multi-host programs are launched in different orders on different hosts,
// the launch IDs may be used by the runtime to detect the mismatch.
int32 launch_id = 0;
// If non-null, an opaque context passed to an execution that may be used to
// supply additional arguments to a derived class of PjRtExecutable.
const ExecuteContext* context = nullptr;
};
// Represents a compiled computation that can be executed given handles to
// device-allocated literals. If any input/output alias has been specified in
// the computation, the parameter containing the input buffer will be donated
// when passed to the execution.
class PjRtExecutable {
public:
virtual ~PjRtExecutable() = default;
virtual PjRtClient* client() const = 0;
// Unique name for this executable, e.g., HloModule name.
virtual const std::string& name() const = 0;
virtual int num_replicas() const = 0;
virtual int num_partitions() const = 0;
virtual int64 SizeOfGeneratedCodeInBytes() const = 0;
virtual const DeviceAssignment& device_assignment() const = 0;
// The replica and partition indices of device_assignment to be run by this
// client. On single-host platforms without partitioning, this is all replicas
// (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
// case on multi-host platforms. If there are 4 replicas and 2 partitions on a
// single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
struct LogicalDeviceIds {
int replica;
int partition;
};
virtual absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
const = 0;
// An addressable_device is one which the client can issue commands to.
// addressable_devices()[i] is the Device to which
// addressable_device_logical_ids()[i] is assigned.
virtual absl::Span<PjRtDevice* const> addressable_devices() const = 0;
// Return an HloModule (optimized) per partition.
virtual StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
const = 0;
// Executes on devices addressable by the client. Requires executable has a
// device_assignment and all devices in the device_assignment are addressable
// by the client.
virtual StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>>
Execute(absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options) = 0;
// Execute the assigned replica/partition on a given `device`. Requires
// executable has a device_assignment, `device` is present in the
// device_assignment and addressable by the client.
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) = 0;
// Execute on a given `device`. Requires `device` to be addressable by client.
// Requires executable has exactly 1 replica and 1 partition and no
// device_assignment (thus portable).
virtual StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
const ExecuteOptions& options) = 0;
// Asynchronously free resources after the last execution completes.
virtual void Delete() = 0;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_