Split StreamExecutor impl of pjrt_client into its own files.
Also removed duplicate comments for subclass methods. PiperOrigin-RevId: 346807683 Change-Id: I2ab9b5038d5e5fc991bfee53d0081c5eecf51906
This commit is contained in:
parent
d9f4007ff6
commit
ec86d80f19
@ -119,12 +119,39 @@ cc_library(
|
|||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "pjrt_client",
|
name = "pjrt_client",
|
||||||
srcs = ["pjrt_client.cc"],
|
|
||||||
hdrs = ["pjrt_client.h"],
|
hdrs = ["pjrt_client.h"],
|
||||||
visibility = ["//tensorflow/compiler/xla:friends"],
|
visibility = ["//tensorflow/compiler/xla:friends"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/compiler/xla:executable_run_options",
|
||||||
|
"//tensorflow/compiler/xla:literal",
|
||||||
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
|
"//tensorflow/compiler/xla:status",
|
||||||
|
"//tensorflow/compiler/xla:statusor",
|
||||||
|
"//tensorflow/compiler/xla:util",
|
||||||
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
|
"//tensorflow/compiler/xla/client:executable_build_options",
|
||||||
|
"//tensorflow/compiler/xla/client:xla_computation",
|
||||||
|
"//tensorflow/compiler/xla/pjrt/distributed:protocol_proto_cc",
|
||||||
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
|
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/base",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
|
"@com_google_absl//absl/types:span",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "pjrt_stream_executor_client",
|
||||||
|
srcs = ["pjrt_stream_executor_client.cc"],
|
||||||
|
hdrs = ["pjrt_stream_executor_client.h"],
|
||||||
|
visibility = ["//tensorflow/compiler/xla:friends"],
|
||||||
deps = [
|
deps = [
|
||||||
":event_pool",
|
":event_pool",
|
||||||
":local_device_state",
|
":local_device_state",
|
||||||
|
":pjrt_client",
|
||||||
":tracked_device_buffer",
|
":tracked_device_buffer",
|
||||||
"//tensorflow/compiler/xla:cpu_function_runtime",
|
"//tensorflow/compiler/xla:cpu_function_runtime",
|
||||||
"//tensorflow/compiler/xla:executable_run_options",
|
"//tensorflow/compiler/xla:executable_run_options",
|
||||||
@ -181,7 +208,7 @@ cc_library(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":local_device_state",
|
":local_device_state",
|
||||||
":pjrt_client",
|
":pjrt_stream_executor_client",
|
||||||
":tracked_device_buffer",
|
":tracked_device_buffer",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:status",
|
"//tensorflow/compiler/xla:status",
|
||||||
@ -215,7 +242,7 @@ cc_library(
|
|||||||
srcs = ["interpreter_device.cc"],
|
srcs = ["interpreter_device.cc"],
|
||||||
hdrs = ["interpreter_device.h"],
|
hdrs = ["interpreter_device.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":pjrt_client",
|
":pjrt_stream_executor_client",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla/client:client_library",
|
"//tensorflow/compiler/xla/client:client_library",
|
||||||
"//tensorflow/compiler/xla/service:interpreter_plugin",
|
"//tensorflow/compiler/xla/service:interpreter_plugin",
|
||||||
@ -229,7 +256,7 @@ cc_library(
|
|||||||
srcs = ["cpu_device.cc"],
|
srcs = ["cpu_device.cc"],
|
||||||
hdrs = ["cpu_device.h"],
|
hdrs = ["cpu_device.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":pjrt_client",
|
":pjrt_stream_executor_client",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla/client:client_library",
|
"//tensorflow/compiler/xla/client:client_library",
|
||||||
"//tensorflow/compiler/xla/service:platform_util",
|
"//tensorflow/compiler/xla/service:platform_util",
|
||||||
@ -242,7 +269,7 @@ cc_library(
|
|||||||
srcs = ["gpu_device.cc"],
|
srcs = ["gpu_device.cc"],
|
||||||
hdrs = ["gpu_device.h"],
|
hdrs = ["gpu_device.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":pjrt_client",
|
":pjrt_stream_executor_client",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
|
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
@ -279,6 +306,7 @@ tf_cc_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":gpu_device",
|
":gpu_device",
|
||||||
":pjrt_client",
|
":pjrt_client",
|
||||||
|
":pjrt_stream_executor_client",
|
||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
"//tensorflow/compiler/xla/client:executable_build_options",
|
"//tensorflow/compiler/xla/client:executable_build_options",
|
||||||
"//tensorflow/compiler/xla/client:xla_builder",
|
"//tensorflow/compiler/xla/client:xla_builder",
|
||||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
|
||||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/pjrt/gpu_device.h"
|
#include "tensorflow/compiler/xla/pjrt/gpu_device.h"
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
|
||||||
|
|
||||||
#ifdef NCCL_ENABLED
|
#ifdef NCCL_ENABLED
|
||||||
#include "third_party/nccl/nccl.h"
|
#include "third_party/nccl/nccl.h"
|
||||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
|
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/core/common_runtime/bfc_allocator.h"
|
#include "tensorflow/core/common_runtime/bfc_allocator.h"
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
|
||||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
@ -20,30 +20,19 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
|
||||||
#include "absl/container/flat_hash_set.h"
|
|
||||||
#include "absl/container/inlined_vector.h"
|
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/synchronization/mutex.h"
|
|
||||||
#include "absl/synchronization/notification.h"
|
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
||||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
|
||||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||||
#include "tensorflow/compiler/xla/layout.h"
|
#include "tensorflow/compiler/xla/layout.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
|
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
|
||||||
#include "tensorflow/compiler/xla/shape.h"
|
#include "tensorflow/compiler/xla/shape.h"
|
||||||
#include "tensorflow/compiler/xla/status.h"
|
#include "tensorflow/compiler/xla/status.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/framework/allocator.h"
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/platform/casts.h"
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/platform/fingerprint.h"
|
#include "tensorflow/core/platform/fingerprint.h"
|
||||||
@ -106,78 +95,6 @@ class PjRtDevice {
|
|||||||
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const = 0;
|
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class PjRtStreamExecutorDevice : public PjRtDevice {
|
|
||||||
public:
|
|
||||||
explicit PjRtStreamExecutorDevice(
|
|
||||||
int id, std::unique_ptr<LocalDeviceState> local_device_state,
|
|
||||||
std::string device_kind, int host_id = 0)
|
|
||||||
: id_(id),
|
|
||||||
device_ordinal_(
|
|
||||||
local_device_state ? local_device_state->device_ordinal() : -1),
|
|
||||||
local_device_state_(std::move(local_device_state)),
|
|
||||||
host_id_(host_id),
|
|
||||||
device_kind_(std::move(device_kind)) {}
|
|
||||||
~PjRtStreamExecutorDevice() override {}
|
|
||||||
|
|
||||||
// Must set client exactly once.
|
|
||||||
void SetClient(PjRtClient* client) {
|
|
||||||
CHECK(client_ == nullptr);
|
|
||||||
client_ = client;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Task ID. This is always 0 on single-task setup.
|
|
||||||
int host_id() const override { return host_id_; }
|
|
||||||
|
|
||||||
// Return `platform_id` from client.
|
|
||||||
PjRtPlatformId platform_id() const;
|
|
||||||
|
|
||||||
// Return `platform_name` from client.
|
|
||||||
const std::string& platform_name() const;
|
|
||||||
|
|
||||||
PjRtClient* client() const override { return client_; }
|
|
||||||
|
|
||||||
// The ID of this device. IDs are unique among devices of this type
|
|
||||||
// (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all
|
|
||||||
// hosts' devices. This is the ID that should be used in a DeviceAssignment.
|
|
||||||
int id() const override { return id_; }
|
|
||||||
|
|
||||||
bool IsAddressable() const override { return device_ordinal_ != -1; }
|
|
||||||
|
|
||||||
int local_hardware_id() const override { return device_ordinal_; }
|
|
||||||
|
|
||||||
// If this is a device local to this host, returns a LocalDeviceState object
|
|
||||||
// that can be used to manipulate the device. Returns nullptr if the device is
|
|
||||||
// not local to this host.
|
|
||||||
LocalDeviceState* local_device_state() const {
|
|
||||||
return local_device_state_.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
// If this is a device local to this host, returns a LocalDeviceState object
|
|
||||||
// that can be used to manipulate the device. Returns an error if the device
|
|
||||||
// is not local to this host.
|
|
||||||
StatusOr<LocalDeviceState*> GetLocalDeviceState() const;
|
|
||||||
|
|
||||||
// A vendor-dependent string that uniquely identifies the kind of device.
|
|
||||||
const std::string& device_kind() const override { return device_kind_; }
|
|
||||||
|
|
||||||
std::string DebugString() const override;
|
|
||||||
|
|
||||||
// Transfer the given literal to the infeed queue of the given localdevice.
|
|
||||||
Status TransferToInfeed(const LiteralSlice& literal) const override;
|
|
||||||
|
|
||||||
// Transfer and return a value of the given shape from the outfeed of the
|
|
||||||
// given device.
|
|
||||||
StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const int id_;
|
|
||||||
const int device_ordinal_; // -1 means not local.
|
|
||||||
const std::unique_ptr<LocalDeviceState> local_device_state_;
|
|
||||||
const int host_id_;
|
|
||||||
const std::string device_kind_;
|
|
||||||
PjRtClient* client_ = nullptr;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Forward declaration.
|
// Forward declaration.
|
||||||
class PjRtBuffer;
|
class PjRtBuffer;
|
||||||
|
|
||||||
@ -333,181 +250,6 @@ class PjRtClient {
|
|||||||
virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
|
virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class PjRtStreamExecutorClient : public PjRtClient {
|
|
||||||
public:
|
|
||||||
// `allocator` may null, in which case the platform default allocator is used.
|
|
||||||
explicit PjRtStreamExecutorClient(
|
|
||||||
std::string platform_name, LocalClient* client,
|
|
||||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
|
|
||||||
int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
|
|
||||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
|
|
||||||
bool should_stage_host_to_device_transfers,
|
|
||||||
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
|
|
||||||
~PjRtStreamExecutorClient() override = default;
|
|
||||||
|
|
||||||
int host_id() const override { return host_id_; }
|
|
||||||
|
|
||||||
int device_count() const override { return devices_.size(); }
|
|
||||||
int addressable_device_count() const override {
|
|
||||||
return local_devices_.size();
|
|
||||||
}
|
|
||||||
absl::Span<PjRtDevice* const> devices() const override { return devices_; }
|
|
||||||
absl::Span<PjRtDevice* const> local_devices() const override {
|
|
||||||
return local_devices_;
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
|
|
||||||
auto it = id_to_device_.find(device_id);
|
|
||||||
if (it != id_to_device_.end()) {
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
return InvalidArgument("No matching device found for device_id %d",
|
|
||||||
device_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusOr<PjRtDevice*> LookupAddressableDevice(
|
|
||||||
int local_hardware_id) const override;
|
|
||||||
|
|
||||||
PjRtPlatformId platform_id() const override { return platform_id_; }
|
|
||||||
const std::string& platform_name() const override { return platform_name_; }
|
|
||||||
|
|
||||||
// Most platforms expect device-to-device transfers to be enqueued on the
|
|
||||||
// source d2d stream, but some platforms use the destination d2d stream. This
|
|
||||||
// function specifies which one the platform expects.
|
|
||||||
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
|
|
||||||
|
|
||||||
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
|
|
||||||
int num_replicas, int num_partitions) const override;
|
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
|
|
||||||
const XlaComputation& computation, CompileOptions options) override;
|
|
||||||
|
|
||||||
// Generates a unique fingerprint for `executable`.
|
|
||||||
StatusOr<absl::optional<std::string>> ExecutableFingerprint(
|
|
||||||
const PjRtExecutable& executable) const override {
|
|
||||||
return absl::optional<std::string>();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a backend-specific HLO cost analysis visitor.
|
|
||||||
std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() override;
|
|
||||||
|
|
||||||
// Creates a buffer on the device without initializing or copying any data.
|
|
||||||
// An optional `definition_event` may be speficied that can be used to
|
|
||||||
// ensure the buffer isn't referenced until some external mechanism has
|
|
||||||
// initialized the data.
|
|
||||||
// NOTE: The sequencing mechanism is not guaranteed to be supported by all
|
|
||||||
// future backends and so callers should avoid wherever possible.
|
|
||||||
StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
|
||||||
const Shape& shape, PjRtDevice* device) override;
|
|
||||||
StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
|
||||||
const Shape& shape, PjRtDevice* device,
|
|
||||||
std::shared_ptr<BufferSequencingEvent> definition_event);
|
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
|
|
||||||
const void* data, const Shape& shape,
|
|
||||||
HostBufferSemantics host_buffer_semantics,
|
|
||||||
std::shared_ptr<void> buffer_reference, PjRtDevice* device) override;
|
|
||||||
|
|
||||||
// Note that literal must remain in scope until the transfer has completed, so
|
|
||||||
// the caller should, for example, wait for BlockHostUntilReady() completes on
|
|
||||||
// the return value before letting literal go out of scope.
|
|
||||||
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
|
|
||||||
const LiteralSlice& literal, PjRtDevice* device) override;
|
|
||||||
|
|
||||||
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
|
|
||||||
// cross host transfers using `client` on `device'. `shapes` must be the exact
|
|
||||||
// shapes, with identical layouts, corresponding to the buffers that will be
|
|
||||||
// sent. When resources for the transfer are available, notifier will be
|
|
||||||
// called with a vector of PjRtCrossHostRecvBuffer structs, one for each
|
|
||||||
// shape in `shapes`. Each struct contains a buffer that will contain the
|
|
||||||
// received value, and an opaque string that should be transmitted to the
|
|
||||||
// sending host and used in a call to CopyToRemoteDevice. None of the recv
|
|
||||||
// buffers will become ready until *all* of the sends have completed.
|
|
||||||
void MakeCrossHostReceiveBuffers(
|
|
||||||
absl::Span<const Shape> shapes, PjRtDevice* device,
|
|
||||||
PjRtCrossHostRecvNotifier&& notifier) override;
|
|
||||||
|
|
||||||
StatusOr<ChannelHandle> CreateChannelHandle() override {
|
|
||||||
return client()->CreateChannelHandle();
|
|
||||||
}
|
|
||||||
StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
|
|
||||||
return client()->CreateDeviceToHostChannelHandle();
|
|
||||||
}
|
|
||||||
StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
|
|
||||||
return client()->CreateHostToDeviceChannelHandle();
|
|
||||||
}
|
|
||||||
|
|
||||||
LocalDeviceState& device_state(int device_ordinal) const {
|
|
||||||
return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
|
|
||||||
local_devices_.at(device_ordinal))
|
|
||||||
->local_device_state();
|
|
||||||
}
|
|
||||||
LocalClient* client() const { return client_; }
|
|
||||||
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
|
|
||||||
tensorflow::Allocator* host_memory_allocator() const {
|
|
||||||
return host_memory_allocator_.get();
|
|
||||||
}
|
|
||||||
bool should_stage_host_to_device_transfers() const {
|
|
||||||
return should_stage_host_to_device_transfers_;
|
|
||||||
}
|
|
||||||
|
|
||||||
gpu::GpuExecutableRunOptions* gpu_run_options() const {
|
|
||||||
return gpu_run_options_.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
tensorflow::thread::ThreadPool* h2d_transfer_pool() {
|
|
||||||
return &h2d_transfer_pool_;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
friend class PjRtStreamExecutorBuffer;
|
|
||||||
virtual void EnqueueCrossHostReceive(
|
|
||||||
std::vector<std::unique_ptr<PjRtBuffer>>&& buffers,
|
|
||||||
std::shared_ptr<BufferSequencingEvent> definition_event,
|
|
||||||
PjRtCrossHostRecvNotifier&& notifier) const {
|
|
||||||
notifier(Unimplemented("Cross host receives not implemented."));
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual Status CopyToRemoteDevice(
|
|
||||||
PjRtBuffer* buffer, absl::string_view serialized_descriptor) const {
|
|
||||||
return Unimplemented("Cross host sends not implemented.");
|
|
||||||
}
|
|
||||||
|
|
||||||
const PjRtPlatformId platform_id_;
|
|
||||||
const std::string platform_name_;
|
|
||||||
LocalClient* client_;
|
|
||||||
|
|
||||||
// Allocator to be used for staging memory transfers to devices.
|
|
||||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
|
|
||||||
|
|
||||||
// Includes all devices, including non-local devices on multi-host platforms.
|
|
||||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
|
|
||||||
// Pointers to `owned_devices_`.
|
|
||||||
std::vector<PjRtDevice*> devices_;
|
|
||||||
// Maps Device::id() to the corresponding Device. Includes all devices.
|
|
||||||
std::map<int, PjRtDevice*> id_to_device_;
|
|
||||||
// Local devices indexed by local device ordinal.
|
|
||||||
std::vector<PjRtDevice*> local_devices_;
|
|
||||||
int host_id_;
|
|
||||||
|
|
||||||
se::DeviceMemoryAllocator* allocator_;
|
|
||||||
std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
|
|
||||||
|
|
||||||
// Should we always prefer to stage host-to-device transfers via memory
|
|
||||||
// allocated on host_memory_allocator_? True only on GPU, where we prefer to
|
|
||||||
// transfer via pinned memory.
|
|
||||||
bool should_stage_host_to_device_transfers_;
|
|
||||||
|
|
||||||
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;
|
|
||||||
|
|
||||||
tensorflow::thread::ThreadPool h2d_transfer_pool_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Converts a 2D set of Device objects indexed by [replica][partition] into an
|
|
||||||
// xla::DeviceAssignment.
|
|
||||||
StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
|
|
||||||
absl::Span<const std::vector<PjRtDevice*>> devices);
|
|
||||||
|
|
||||||
// Holds a reference from Python to a tuple of device buffers. A PjRtBuffer
|
// Holds a reference from Python to a tuple of device buffers. A PjRtBuffer
|
||||||
// can be either valid or invalid. An invalid buffer is one that has never been
|
// can be either valid or invalid. An invalid buffer is one that has never been
|
||||||
// initialized, or a buffer that has been deleted (e.g., by calling Delete, or
|
// initialized, or a buffer that has been deleted (e.g., by calling Delete, or
|
||||||
@ -625,393 +367,6 @@ class PjRtBuffer {
|
|||||||
virtual bool IsOnCpu() const = 0;
|
virtual bool IsOnCpu() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class PjRtStreamExecutorBuffer : public PjRtBuffer {
|
|
||||||
public:
|
|
||||||
// Helper class to retain a "hold" on a PjRtBuffer. A ScopedHold may not
|
|
||||||
// outlive its parent PjRtBuffer.
|
|
||||||
//
|
|
||||||
// There are three types of hold, as follows:
|
|
||||||
//
|
|
||||||
// 1) Usage hold: a transient hold while an operation using the buffer is
|
|
||||||
// being enqueued onto a stream.
|
|
||||||
// A client acquires a usage hold by calling
|
|
||||||
// PjRtBuffer::GetBufferWithHold(kUsage) or the convenience wrapper
|
|
||||||
// GetBufferWithUsageHold(). If the enqueue completes successfully the hold
|
|
||||||
// should be released using a call to ConvertUsageHold. If the ScopedHold is
|
|
||||||
// deleted without ConvertUsageHold being called, e.g., on error, the hold is
|
|
||||||
// dropped. It is legal to drop a usage hold instead of calling
|
|
||||||
// ConvertUsageHold, even if the buffer was successfully enqueued, as long as
|
|
||||||
// the client ensures that all necessary synchronization has been done.
|
|
||||||
//
|
|
||||||
// 2) External hold: a potentially long-lived hold while the buffer is being
|
|
||||||
// shared by an external framework, e.g., NumPy.
|
|
||||||
// A client acquires an external hold by calling
|
|
||||||
// PjRtBuffer::GetBufferWithHold(kExternal) or the convenience wrapper
|
|
||||||
// GetBufferWithExternalReference and releases it by deleting the ScopedHold.
|
|
||||||
// The external framework should not modify the underlying buffer unless it is
|
|
||||||
// confident via its own synchronization that modifications do not race with
|
|
||||||
// reads from the PjRtBuffer.
|
|
||||||
//
|
|
||||||
// 3) Donation hold: a transient hold while an execution that donates the
|
|
||||||
// buffer is being enqueued onto the compute stream.
|
|
||||||
// A client acquires a donation hold by calling
|
|
||||||
// PjRtBuffer::GetBufferWithHold(kDonation). If the enqueue completes
|
|
||||||
// successfully the hold should be released using a call to ConfirmDonation
|
|
||||||
// after which the buffer is invalid. If the ScopedHold is deleted without
|
|
||||||
// ConfirmDonation being called, e.g., on error, the hold is dropped and the
|
|
||||||
// buffer remains valid. If the buffer is successfully enqueued the client
|
|
||||||
// *must* call ConfirmDonation.
|
|
||||||
//
|
|
||||||
// Donation holds behave like exclusive write locks: when a donation hold
|
|
||||||
// has been acquired, any attempt to acquire another hold of any type will
|
|
||||||
// block until the donation hold is dropped or confirmed. Acquiring a donation
|
|
||||||
// hold will fail with an error if there is any outstanding external hold, and
|
|
||||||
// will block if there are any outstanding usage holds until those holds are
|
|
||||||
// dropped or converted.
|
|
||||||
//
|
|
||||||
// Calls to PjRtBuffer::Release (and transitively to
|
|
||||||
// PjRtBuffer::Delete() and ~PjRtBuffer()) will block until all usage
|
|
||||||
// and donation holds are either deleted or converted/confirmed.
|
|
||||||
class ScopedHold {
|
|
||||||
public:
|
|
||||||
enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue };
|
|
||||||
// Use a State enum instead of encoding the state in an error Status to
|
|
||||||
// avoid creating Status values in non-error cases. Creating a Status
|
|
||||||
// entails several allocations and can add O(us) to every use of a hold.
|
|
||||||
enum State {
|
|
||||||
kUninitialized = 0,
|
|
||||||
kValid,
|
|
||||||
kMoved,
|
|
||||||
kConverted,
|
|
||||||
kReleased,
|
|
||||||
kDonated,
|
|
||||||
kError
|
|
||||||
};
|
|
||||||
|
|
||||||
~ScopedHold();
|
|
||||||
ScopedHold(ScopedHold&& other);
|
|
||||||
ScopedHold(const ScopedHold&) = delete;
|
|
||||||
ScopedHold& operator=(const ScopedHold&) = delete;
|
|
||||||
|
|
||||||
Type type() const { return type_; }
|
|
||||||
|
|
||||||
Status status() const {
|
|
||||||
// Lazily create Status values only when they are requested.
|
|
||||||
switch (state_) {
|
|
||||||
case kUninitialized:
|
|
||||||
return InvalidArgument("Buffer has not been initialized");
|
|
||||||
case kValid:
|
|
||||||
return Status::OK();
|
|
||||||
case kMoved:
|
|
||||||
return InvalidArgument("Buffer has been moved.");
|
|
||||||
case kConverted:
|
|
||||||
return InvalidArgument("Buffer has been converted");
|
|
||||||
case kReleased:
|
|
||||||
return InvalidArgument("Buffer has been released");
|
|
||||||
case kDonated:
|
|
||||||
return InvalidArgument("Buffer has been donated");
|
|
||||||
case kError:
|
|
||||||
return buffer_or_.status();
|
|
||||||
default:
|
|
||||||
CHECK(false) << "Unexpected state value " << state_;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
bool ok() const { return state_ == kValid; }
|
|
||||||
|
|
||||||
// Access to the underlying device buffer storage. Requires this->ok().
|
|
||||||
const std::shared_ptr<TrackedDeviceBuffer>& buffer() const {
|
|
||||||
CHECK_EQ(state_, kValid);
|
|
||||||
CHECK_NE(buffer_or_.ValueOrDie(), nullptr);
|
|
||||||
return buffer_or_.ValueOrDie();
|
|
||||||
}
|
|
||||||
TrackedDeviceBuffer* operator->() const { return buffer().get(); }
|
|
||||||
const TrackedDeviceBuffer& operator*() const { return *buffer(); }
|
|
||||||
|
|
||||||
// Converts the hold into a usage event. Only valid for holds of type
|
|
||||||
// kUsage.
|
|
||||||
//
|
|
||||||
// usage_stream: the stream that the buffer was used on.
|
|
||||||
// event: an event that has been recorded on usage_stream after
|
|
||||||
// the buffer was used.
|
|
||||||
// reference_held: true if and only if the caller has caused a
|
|
||||||
// reference to this->buffer() to stay live until after
|
|
||||||
// the host is sure that the usage (transfer or execution)
|
|
||||||
// has completed.
|
|
||||||
void ConvertUsageHold(se::Stream* usage_stream,
|
|
||||||
std::shared_ptr<BufferSequencingEvent> event,
|
|
||||||
bool reference_held);
|
|
||||||
|
|
||||||
// Confirms that the buffer was successfully donated to an execution.
|
|
||||||
// Only valid for holds of type kDonation. Causes the buffer to become
|
|
||||||
// invalid.
|
|
||||||
void ConfirmDonation();
|
|
||||||
|
|
||||||
// Adds the held device buffers in order to 'iterator'. Used to add the
|
|
||||||
// buffers to an ExecutionInput. We require but do not verify that
|
|
||||||
// 'iterator' when passed in is pointing to a sub-tuple of the
|
|
||||||
// ExecutionInput whose on_device_shape matches that of the
|
|
||||||
// TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run
|
|
||||||
// out of bounds. Donates the device buffers if the hold type is kDonation,
|
|
||||||
// otherwise retains ownership of the device buffers.
|
|
||||||
void AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
|
|
||||||
const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
|
|
||||||
ExecutionInput* execution_input,
|
|
||||||
se::DeviceMemoryAllocator* allocator) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend class PjRtStreamExecutorBuffer;
|
|
||||||
friend class PjRtStreamExecutorClient;
|
|
||||||
|
|
||||||
// Helper struct that makes it possible to move a ScopedHold through a
|
|
||||||
// closure.
|
|
||||||
using ForClosure =
|
|
||||||
std::tuple<PjRtStreamExecutorBuffer*, Type, State,
|
|
||||||
StatusOr<std::shared_ptr<TrackedDeviceBuffer>>>;
|
|
||||||
|
|
||||||
ScopedHold(PjRtStreamExecutorBuffer* parent, Type type)
|
|
||||||
: parent_(parent), type_(type), state_(kUninitialized) {}
|
|
||||||
explicit ScopedHold(const ForClosure& closure_helper)
|
|
||||||
: parent_(std::get<0>(closure_helper)),
|
|
||||||
type_(std::get<1>(closure_helper)),
|
|
||||||
state_(std::get<2>(closure_helper)),
|
|
||||||
buffer_or_(std::get<3>(closure_helper)) {
|
|
||||||
// Check the buffer is not in an error state.
|
|
||||||
CHECK(buffer_or_.ValueOrDie() != nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets buffer state.
|
|
||||||
void SetState(State state) { state_ = state; }
|
|
||||||
|
|
||||||
// Sets buffer_or_. Called by parent_ to initialize the hold.
|
|
||||||
void Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or);
|
|
||||||
// Releases the contents of *this, so *this can subsequently be
|
|
||||||
// deleted without releasing the parent's hold. Should be passed to the
|
|
||||||
// appropriate constructor of another ScopedHold, e.g., when a hold must be
|
|
||||||
// passed through a closure that is incompatible with std::move.
|
|
||||||
ForClosure ToClosure();
|
|
||||||
|
|
||||||
PjRtStreamExecutorBuffer* const parent_;
|
|
||||||
const Type type_;
|
|
||||||
|
|
||||||
// There is an invariant that if ok() then
|
|
||||||
// buffer_or_.ValueOrDie() != nullptr.
|
|
||||||
State state_;
|
|
||||||
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
|
|
||||||
};
|
|
||||||
|
|
||||||
PjRtStreamExecutorBuffer(Shape on_host_shape, Shape on_device_shape,
|
|
||||||
std::shared_ptr<TrackedDeviceBuffer> device_buffer,
|
|
||||||
PjRtClient* client, PjRtDevice* device);
|
|
||||||
~PjRtStreamExecutorBuffer() override;
|
|
||||||
|
|
||||||
PjRtStreamExecutorBuffer(const PjRtStreamExecutorBuffer&) = delete;
|
|
||||||
PjRtStreamExecutorBuffer(PjRtStreamExecutorBuffer&&) = delete;
|
|
||||||
PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
|
|
||||||
PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;
|
|
||||||
|
|
||||||
const Shape& on_host_shape() const override { return on_host_shape_; }
|
|
||||||
const Shape& on_device_shape() const override { return on_device_shape_; }
|
|
||||||
PjRtStreamExecutorDevice* device() const override { return device_; }
|
|
||||||
PjRtPlatformId platform_id() const { return client_->platform_id(); }
|
|
||||||
const std::string& platform_name() const { return client_->platform_name(); }
|
|
||||||
PjRtStreamExecutorClient* client() const override { return client_; }
|
|
||||||
bool IsEmptyTuple() const {
|
|
||||||
return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the size of the on-device representation of this buffer in bytes.
|
|
||||||
int64 OnDeviceSizeInBytes() const override;
|
|
||||||
|
|
||||||
// Implement PjRtBuffer::ExternalReferenceHold a wrapped
|
|
||||||
// ScopedHold::kExternalReference.
|
|
||||||
class ScopedHoldAsExternalReference
|
|
||||||
: public PjRtBuffer::ExternalReferenceHold {
|
|
||||||
public:
|
|
||||||
explicit ScopedHoldAsExternalReference(ScopedHold hold)
|
|
||||||
: external_reference_(std::move(hold)) {
|
|
||||||
CHECK(hold.type() == ScopedHold::kExternalReference);
|
|
||||||
}
|
|
||||||
|
|
||||||
~ScopedHoldAsExternalReference() override = default;
|
|
||||||
|
|
||||||
void* OpaqueDeviceMemoryDataPointer() const override {
|
|
||||||
return external_reference_->device_memory().front().opaque();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
ScopedHold external_reference_;
|
|
||||||
};
|
|
||||||
StatusOr<std::unique_ptr<ExternalReferenceHold>> AcquireExternalReference()
|
|
||||||
override;
|
|
||||||
|
|
||||||
StatusOr<absl::optional<std::shared_ptr<void>>> ReleaseDeviceMemoryOwnership(
|
|
||||||
bool wait_for_operations_to_complete) override;
|
|
||||||
|
|
||||||
// Returns the buffer's value as an XLA Literal. If the value has previously
|
|
||||||
// been prefetched to the host, then returns the prefetched version, otherwise
|
|
||||||
// copies the buffer to the host. Blocks until the value is ready. If
|
|
||||||
// `discard_cached_copy` is true then buffer will no longer keep hold of a
|
|
||||||
// cached copy of the literal (i.e. The reference to the host value will be
|
|
||||||
// removed.) If a layout is passed than a literal with this layout will be
|
|
||||||
// returned.
|
|
||||||
using PjRtBuffer::ToLiteral;
|
|
||||||
StatusOr<std::shared_ptr<Literal>> ToLiteral(
|
|
||||||
bool discard_cached_copy, absl::optional<xla::Layout> layout) override;
|
|
||||||
|
|
||||||
// Initiates a copy of the buffer to the host. Does not block waiting for
|
|
||||||
// the transfer to complete. The value can be retrieved by a later call to
|
|
||||||
// ToLiteral(). If a layout is passed then a cached copy with this layout will
|
|
||||||
// be created.
|
|
||||||
using PjRtBuffer::CopyToHostAsync;
|
|
||||||
Status CopyToHostAsync(absl::optional<xla::Layout> layout) override;
|
|
||||||
|
|
||||||
// Drops the buffer's reference to its associated device memory, leaving the
|
|
||||||
// buffer in an invalid state. The memory will be freed lazily when all async
|
|
||||||
// operations using the buffer have completed, according to the allocation
|
|
||||||
// semantics of the underlying platform. Delete may briefly block if another
|
|
||||||
// thread is in the process of enqueuing an operation on this buffer, but it
|
|
||||||
// will never block for a stream operation to complete. If an external
|
|
||||||
// framework holds a reference to the TrackedDeviceBuffer via
|
|
||||||
// GetBufferWithExternalReference, the memory will not be freed until the
|
|
||||||
// external framework drops the reference.
|
|
||||||
void Delete() override;
|
|
||||||
|
|
||||||
// True if and only if Delete or Release has previously been called.
|
|
||||||
bool IsDeleted() override;
|
|
||||||
|
|
||||||
// Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The
|
|
||||||
// PjRtBuffer retains ownership of the device buffers.
|
|
||||||
StatusOr<ShapedBuffer> AsShapedBuffer() const;
|
|
||||||
|
|
||||||
// Returns a hold on the TrackedDeviceBuffer holding the device
|
|
||||||
// buffers. See comment on ScopedHold.
|
|
||||||
ScopedHold GetBufferWithHold(ScopedHold::Type type);
|
|
||||||
ScopedHold GetBufferWithUsageHold() {
|
|
||||||
return GetBufferWithHold(ScopedHold::kUsage);
|
|
||||||
}
|
|
||||||
ScopedHold GetBufferWithExternalReference() {
|
|
||||||
return GetBufferWithHold(ScopedHold::kExternalReference);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copies the buffer to device `dst_device`, performing a d2d transfer when
|
|
||||||
// `dst_device` is sharing the same Client, and performing a d2h and h2d copy
|
|
||||||
// if `dst_device` lives on a different Client.
|
|
||||||
// Returns an error if the buffer is already on dst_device.
|
|
||||||
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
|
|
||||||
PjRtDevice* dst_device) override;
|
|
||||||
|
|
||||||
// Copies the buffer to the remote device encoded in serialized_descriptor.
|
|
||||||
// This call must be preceded by a call to MakeCrossHostReceiveBuffers on the
|
|
||||||
// remote host's destination device. MakeCrossHostReceiveBuffers takes an
|
|
||||||
// array of shapes to construct the destination buffers, and a callback
|
|
||||||
// supplies an array containing both the destination buffers, and a serialized
|
|
||||||
// descriptor for each buffer. For each destination buffer there should be a
|
|
||||||
// matching call to src->CopyToRemoteDevice on a remote host for a src buffer
|
|
||||||
// of the corresponding shape. serialized_descriptor is the string returned by
|
|
||||||
// the callback along with the corresponding destination buffer.
|
|
||||||
Status CopyToRemoteDevice(absl::string_view serialized_descriptor) override;
|
|
||||||
|
|
||||||
// Blocks the host until the buffer's value has been computed and is ready for
|
|
||||||
// immediate use on the device. Useful in particular for timing benchmarks.
|
|
||||||
Status BlockHostUntilReady() override;
|
|
||||||
|
|
||||||
// Whether this buffer is on CPU and thus allows for certain optimizations.
|
|
||||||
bool IsOnCpu() const override;
|
|
||||||
|
|
||||||
// Similar to Delete, drops the buffer's reference to its associated device
|
|
||||||
// memory, leaving the buffer in an invalid state, but returns the
|
|
||||||
// TrackedDeviceBuffer rather than freeing the device memory, so that another
|
|
||||||
// framework can take ownership of it. The buffer returned from Release may
|
|
||||||
// be safely dropped at any time even if it still has pending async
|
|
||||||
// operations. The client should call BlockHostUntilReady before calling
|
|
||||||
// Release with wait_for_operations_to_complete=false, to ensure that the host
|
|
||||||
// has synchronized past any outstanding write operations to the buffer. If
|
|
||||||
// wait_for_operations_to_complete=true the host will block until any
|
|
||||||
// potentially outstanding asynchronous operations have completed before
|
|
||||||
// returning, in which case it is safe to read or mutate the returned buffer.
|
|
||||||
// If the buffer was shared via an external reference it is the client's
|
|
||||||
// responsibility that accesses via that reference do not interfere with
|
|
||||||
// accesses via the buffer returned from Release.
|
|
||||||
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> Release(
|
|
||||||
bool wait_for_operations_to_complete);
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend class PjRtClient;
|
|
||||||
// The cached value of the buffer on the host, produced either from a call to
|
|
||||||
// CopyToHost or from a call to ToLiteral. Once a value has been fetched to
|
|
||||||
// the host, it persists Delete() is called or the PjRtBuffer is destroyed.
|
|
||||||
struct HostValue {
|
|
||||||
absl::Notification ready;
|
|
||||||
// status and value are valid for reading only after `ready` has been
|
|
||||||
// notified.
|
|
||||||
Status status;
|
|
||||||
std::shared_ptr<Literal> value;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Blocks in mu_.Await until there are no more usage holds.
|
|
||||||
void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
|
||||||
|
|
||||||
// Blocks in mu_.Await until there is no donation hold.
|
|
||||||
void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
|
||||||
|
|
||||||
// Adds a hold of 'type' and returns device_buffer_. Returns an error if
|
|
||||||
// device_buffer_ is null, or if a donation hold was requested when there is
|
|
||||||
// an outstanding external hold.
|
|
||||||
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> GetBufferForHoldLocked(
|
|
||||||
ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
|
||||||
|
|
||||||
// Adds a hold of hold->type() and initializes `hold` with device_buffer_.
|
|
||||||
// Initializes hold with an error if device_buffer_ is null, or if a donation
|
|
||||||
// hold was requested when there is an outstanding external hold.
|
|
||||||
void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
|
||||||
|
|
||||||
// Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a sanity
|
|
||||||
// check that buffer==device_buffer_ or device_buffer_==nullptr. Called after
|
|
||||||
// device_buffer_ was successfully enqueued on a stream.
|
|
||||||
void ConvertUsageHold(TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
|
|
||||||
std::shared_ptr<BufferSequencingEvent> event,
|
|
||||||
bool reference_held);
|
|
||||||
|
|
||||||
// Drops a donation hold and makes *this invalid for further use. Does a
|
|
||||||
// sanity check that buffer==device_buffer_. Called after device_buffer_ was
|
|
||||||
// successfully donated to an execution.
|
|
||||||
void ConfirmDonation(TrackedDeviceBuffer* device_buffer);
|
|
||||||
|
|
||||||
// Initiates a copy of the buffer to the host. Does not block waiting for
|
|
||||||
// the transfer to complete. A host value is returned and if
|
|
||||||
// `discard_cached_copy` is false stored in an internal buffer so that future
|
|
||||||
// transfers don't have to transfer the data from host again. If a layout is
|
|
||||||
// passed then a literal of this layout will be returned and possibly cached.
|
|
||||||
StatusOr<std::shared_ptr<HostValue>> CopyToHostAsyncInternal(
|
|
||||||
bool discard_cached_copy, absl::optional<xla::Layout> layout);
|
|
||||||
|
|
||||||
// Drops a hold without taking any other action. Does a sanity check that
|
|
||||||
// buffer==device_buffer_ or device_buffer_==nullptr.
|
|
||||||
void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);
|
|
||||||
|
|
||||||
StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
|
|
||||||
std::shared_ptr<BufferSequencingEvent>>>
|
|
||||||
CopyToDeviceHelper(PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
|
|
||||||
LocalDeviceState* transfer_local_device,
|
|
||||||
se::Stream* transfer_stream,
|
|
||||||
std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);
|
|
||||||
|
|
||||||
PjRtStreamExecutorClient* const client_;
|
|
||||||
const Shape on_host_shape_;
|
|
||||||
const Shape on_device_shape_;
|
|
||||||
PjRtStreamExecutorDevice* const device_;
|
|
||||||
|
|
||||||
mutable absl::Mutex mu_;
|
|
||||||
std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
|
|
||||||
absl::flat_hash_map<xla::Layout, std::shared_ptr<HostValue>> host_values_
|
|
||||||
TF_GUARDED_BY(mu_);
|
|
||||||
std::shared_ptr<HostValue> host_value_ TF_GUARDED_BY(mu_);
|
|
||||||
// Count of holds on the buffer.
|
|
||||||
std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
|
|
||||||
// Semaphore used to ensure there is only one outstanding donation hold.
|
|
||||||
Semaphore donation_semaphore_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class ExecuteContext {
|
class ExecuteContext {
|
||||||
public:
|
public:
|
||||||
virtual ~ExecuteContext() = default;
|
virtual ~ExecuteContext() = default;
|
||||||
@ -1103,148 +458,6 @@ class PjRtExecutable {
|
|||||||
virtual void Delete() = 0;
|
virtual void Delete() = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Wraps one or more XLA LocalExecutables (one per partition, as specified by
|
|
||||||
// the build options).
|
|
||||||
class PjRtStreamExecutorExecutable : public PjRtExecutable {
|
|
||||||
public:
|
|
||||||
PjRtStreamExecutorExecutable(
|
|
||||||
std::vector<std::unique_ptr<LocalExecutable>> executables,
|
|
||||||
bool parameter_is_tupled_arguments,
|
|
||||||
std::shared_ptr<DeviceAssignment> device_assignment,
|
|
||||||
std::vector<LogicalDeviceIds> addressable_device_logical_ids,
|
|
||||||
std::vector<PjRtDevice*> addressable_devices,
|
|
||||||
PjRtStreamExecutorClient* client);
|
|
||||||
|
|
||||||
~PjRtStreamExecutorExecutable() override = default;
|
|
||||||
|
|
||||||
PjRtStreamExecutorClient* client() const override { return client_; }
|
|
||||||
|
|
||||||
const std::string& name() const override;
|
|
||||||
|
|
||||||
int num_replicas() const override {
|
|
||||||
return executables_[0]->build_options().num_replicas();
|
|
||||||
}
|
|
||||||
|
|
||||||
int num_partitions() const override {
|
|
||||||
return executables_[0]->build_options().num_partitions();
|
|
||||||
}
|
|
||||||
|
|
||||||
int64 SizeOfGeneratedCodeInBytes() const override {
|
|
||||||
int64 size = 0;
|
|
||||||
for (auto& executable : executables_) {
|
|
||||||
size += executable->executable()->SizeOfGeneratedCodeInBytes();
|
|
||||||
}
|
|
||||||
return size;
|
|
||||||
}
|
|
||||||
|
|
||||||
const DeviceAssignment& device_assignment() const override {
|
|
||||||
return *device_assignment_;
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
|
|
||||||
const override {
|
|
||||||
return addressable_device_logical_ids_;
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Span<PjRtDevice* const> addressable_devices() const override {
|
|
||||||
return addressable_devices_;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return an HloModule per partition.
|
|
||||||
StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
|
|
||||||
const override;
|
|
||||||
|
|
||||||
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
|
|
||||||
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
|
||||||
const ExecuteOptions& options) const override;
|
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
|
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
|
||||||
const ExecuteOptions& options) const override;
|
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
|
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
|
||||||
const ExecuteOptions& options) const override;
|
|
||||||
|
|
||||||
void Delete() override { executables_.clear(); }
|
|
||||||
|
|
||||||
absl::Span<const std::shared_ptr<LocalExecutable>> executables() const {
|
|
||||||
return executables_;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
bool parameter_is_tupled_arguments() const {
|
|
||||||
return parameter_is_tupled_arguments_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend class PjRtStreamExecutorClient;
|
|
||||||
// Initializes information about which arguments to which executables must be
|
|
||||||
// donated due to aliases that were specified by the computation.
|
|
||||||
Status SetUpDonation(bool tuple_inputs);
|
|
||||||
|
|
||||||
virtual bool MustDonateParameter(int executable_idx, int parameter) const;
|
|
||||||
|
|
||||||
virtual StatusOr<std::vector<ExecutionInput>>
|
|
||||||
MakeExecutionInputsAndWaitForEvents(
|
|
||||||
int device_ordinal, const ExecuteOptions& options,
|
|
||||||
absl::Span<PjRtBuffer* const> argument_handles,
|
|
||||||
absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
|
|
||||||
absl::flat_hash_set<BufferSequencingEvent*>& events) const;
|
|
||||||
|
|
||||||
StatusOr<ScopedShapedBuffer> EnqueueExecution(
|
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
|
||||||
int partition, int executable_idx, const RunId& run_id,
|
|
||||||
const ExecuteOptions& options, PjRtDevice* device,
|
|
||||||
std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
|
|
||||||
std::shared_ptr<DeviceAssignment> device_assignment) const;
|
|
||||||
|
|
||||||
virtual std::vector<std::unique_ptr<PjRtBuffer>> MakeOutputBuffers(
|
|
||||||
int device_ordinal, const ExecuteOptions& options,
|
|
||||||
ScopedShapedBuffer result_buffer,
|
|
||||||
std::shared_ptr<BufferSequencingEvent> definition_event,
|
|
||||||
PjRtDevice* device) const;
|
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
|
|
||||||
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
|
||||||
int partition, const RunId& run_id, const ExecuteOptions& options,
|
|
||||||
PjRtDevice* device = nullptr) const;
|
|
||||||
|
|
||||||
// Create shared pointers so we can free them after the execution: with
|
|
||||||
// asynchronous execution, the process being executed can outlive the
|
|
||||||
// executable itself.
|
|
||||||
PjRtStreamExecutorClient* const client_;
|
|
||||||
// One executable per partition.
|
|
||||||
std::vector<std::shared_ptr<LocalExecutable>> executables_;
|
|
||||||
// Per-executable set of parameters that have any aliased buffers and thus
|
|
||||||
// must be donated when executing the computation.
|
|
||||||
std::vector<absl::flat_hash_set<int>> parameters_that_must_be_donated_;
|
|
||||||
std::shared_ptr<DeviceAssignment> device_assignment_;
|
|
||||||
|
|
||||||
// True if the executables were compiled expecting arguments in a single
|
|
||||||
// tuple.
|
|
||||||
const bool parameter_is_tupled_arguments_;
|
|
||||||
|
|
||||||
// The replica and partition indices of device_assignment_ to be run by this
|
|
||||||
// client. On single-host platforms without partitioning, this is all replicas
|
|
||||||
// (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
|
|
||||||
// case on multi-host platforms. If there are 4 replicas and 2 partitions on a
|
|
||||||
// single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
|
|
||||||
std::vector<LogicalDeviceIds> addressable_device_logical_ids_;
|
|
||||||
|
|
||||||
// addressable_devices_[i] is the Device to which
|
|
||||||
// addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
|
|
||||||
// unique_ptrs to play well with the Python bindings (see xla.cc).
|
|
||||||
std::vector<PjRtDevice*> addressable_devices_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Executables can donate buffers so that buffers can be aliased from inputs
|
|
||||||
// to outputs. This function returns the list of parameters that must be
|
|
||||||
// donated when executable is run. tuple_inputs reflects the option that
|
|
||||||
// executable was compiled with.
|
|
||||||
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
|
|
||||||
const HloModule& hlo_module, bool tuple_inputs);
|
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
|
#endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
|
||||||
|
@ -62,7 +62,7 @@ limitations under the License.
|
|||||||
// See the comment on LocalDeviceState::AllocationModel for a discussion of the
|
// See the comment on LocalDeviceState::AllocationModel for a discussion of the
|
||||||
// different allocation semantics on CPU, GPU, and TPU.
|
// different allocation semantics on CPU, GPU, and TPU.
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <memory>
|
#include <memory>
|
782
tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h
Normal file
782
tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h
Normal file
@ -0,0 +1,782 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
|
||||||
|
#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
|
#include "absl/container/inlined_vector.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "absl/synchronization/mutex.h"
|
||||||
|
#include "absl/synchronization/notification.h"
|
||||||
|
#include "absl/types/optional.h"
|
||||||
|
#include "absl/types/span.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||||
|
#include "tensorflow/compiler/xla/layout.h"
|
||||||
|
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
|
||||||
|
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||||
|
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape.h"
|
||||||
|
#include "tensorflow/compiler/xla/status.h"
|
||||||
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
|
#include "tensorflow/core/platform/fingerprint.h"
|
||||||
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
class PjRtStreamExecutorDevice : public PjRtDevice {
|
||||||
|
public:
|
||||||
|
explicit PjRtStreamExecutorDevice(
|
||||||
|
int id, std::unique_ptr<LocalDeviceState> local_device_state,
|
||||||
|
std::string device_kind, int host_id = 0)
|
||||||
|
: id_(id),
|
||||||
|
device_ordinal_(
|
||||||
|
local_device_state ? local_device_state->device_ordinal() : -1),
|
||||||
|
local_device_state_(std::move(local_device_state)),
|
||||||
|
host_id_(host_id),
|
||||||
|
device_kind_(std::move(device_kind)) {}
|
||||||
|
~PjRtStreamExecutorDevice() override {}
|
||||||
|
|
||||||
|
// Must set client exactly once.
|
||||||
|
void SetClient(PjRtClient* client) {
|
||||||
|
CHECK(client_ == nullptr);
|
||||||
|
client_ = client;
|
||||||
|
}
|
||||||
|
|
||||||
|
int host_id() const override { return host_id_; }
|
||||||
|
|
||||||
|
// Return `platform_id` from client.
|
||||||
|
PjRtPlatformId platform_id() const;
|
||||||
|
|
||||||
|
// Return `platform_name` from client.
|
||||||
|
const std::string& platform_name() const;
|
||||||
|
|
||||||
|
PjRtClient* client() const override { return client_; }
|
||||||
|
|
||||||
|
int id() const override { return id_; }
|
||||||
|
|
||||||
|
bool IsAddressable() const override { return device_ordinal_ != -1; }
|
||||||
|
|
||||||
|
int local_hardware_id() const override { return device_ordinal_; }
|
||||||
|
|
||||||
|
// If this is a device local to this host, returns a LocalDeviceState object
|
||||||
|
// that can be used to manipulate the device. Returns nullptr if the device is
|
||||||
|
// not local to this host.
|
||||||
|
LocalDeviceState* local_device_state() const {
|
||||||
|
return local_device_state_.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is a device local to this host, returns a LocalDeviceState object
|
||||||
|
// that can be used to manipulate the device. Returns an error if the device
|
||||||
|
// is not local to this host.
|
||||||
|
StatusOr<LocalDeviceState*> GetLocalDeviceState() const;
|
||||||
|
|
||||||
|
const std::string& device_kind() const override { return device_kind_; }
|
||||||
|
|
||||||
|
std::string DebugString() const override;
|
||||||
|
|
||||||
|
Status TransferToInfeed(const LiteralSlice& literal) const override;
|
||||||
|
|
||||||
|
StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const int id_;
|
||||||
|
const int device_ordinal_; // -1 means not local.
|
||||||
|
const std::unique_ptr<LocalDeviceState> local_device_state_;
|
||||||
|
const int host_id_;
|
||||||
|
const std::string device_kind_;
|
||||||
|
PjRtClient* client_ = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
class PjRtStreamExecutorClient : public PjRtClient {
|
||||||
|
public:
|
||||||
|
// `allocator` may null, in which case the platform default allocator is used.
|
||||||
|
explicit PjRtStreamExecutorClient(
|
||||||
|
std::string platform_name, LocalClient* client,
|
||||||
|
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
|
||||||
|
int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
|
||||||
|
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
|
||||||
|
bool should_stage_host_to_device_transfers,
|
||||||
|
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
|
||||||
|
~PjRtStreamExecutorClient() override = default;
|
||||||
|
|
||||||
|
int host_id() const override { return host_id_; }
|
||||||
|
|
||||||
|
int device_count() const override { return devices_.size(); }
|
||||||
|
int addressable_device_count() const override {
|
||||||
|
return local_devices_.size();
|
||||||
|
}
|
||||||
|
absl::Span<PjRtDevice* const> devices() const override { return devices_; }
|
||||||
|
absl::Span<PjRtDevice* const> local_devices() const override {
|
||||||
|
return local_devices_;
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<PjRtDevice*> LookupDevice(int device_id) const override {
|
||||||
|
auto it = id_to_device_.find(device_id);
|
||||||
|
if (it != id_to_device_.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
return InvalidArgument("No matching device found for device_id %d",
|
||||||
|
device_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<PjRtDevice*> LookupAddressableDevice(
|
||||||
|
int local_hardware_id) const override;
|
||||||
|
|
||||||
|
PjRtPlatformId platform_id() const override { return platform_id_; }
|
||||||
|
const std::string& platform_name() const override { return platform_name_; }
|
||||||
|
|
||||||
|
// Most platforms expect device-to-device transfers to be enqueued on the
|
||||||
|
// source d2d stream, but some platforms use the destination d2d stream. This
|
||||||
|
// function specifies which one the platform expects.
|
||||||
|
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
|
||||||
|
|
||||||
|
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
|
||||||
|
int num_replicas, int num_partitions) const override;
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
|
||||||
|
const XlaComputation& computation, CompileOptions options) override;
|
||||||
|
|
||||||
|
StatusOr<absl::optional<std::string>> ExecutableFingerprint(
|
||||||
|
const PjRtExecutable& executable) const override {
|
||||||
|
return absl::optional<std::string>();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis() override;
|
||||||
|
|
||||||
|
// Creates a buffer on the device without initializing or copying any data.
|
||||||
|
// An optional `definition_event` may be speficied that can be used to
|
||||||
|
// ensure the buffer isn't referenced until some external mechanism has
|
||||||
|
// initialized the data.
|
||||||
|
StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
||||||
|
const Shape& shape, PjRtDevice* device) override;
|
||||||
|
StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
||||||
|
const Shape& shape, PjRtDevice* device,
|
||||||
|
std::shared_ptr<BufferSequencingEvent> definition_event);
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
|
||||||
|
const void* data, const Shape& shape,
|
||||||
|
HostBufferSemantics host_buffer_semantics,
|
||||||
|
std::shared_ptr<void> buffer_reference, PjRtDevice* device) override;
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
|
||||||
|
const LiteralSlice& literal, PjRtDevice* device) override;
|
||||||
|
|
||||||
|
void MakeCrossHostReceiveBuffers(
|
||||||
|
absl::Span<const Shape> shapes, PjRtDevice* device,
|
||||||
|
PjRtCrossHostRecvNotifier&& notifier) override;
|
||||||
|
|
||||||
|
StatusOr<ChannelHandle> CreateChannelHandle() override {
|
||||||
|
return client()->CreateChannelHandle();
|
||||||
|
}
|
||||||
|
StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
|
||||||
|
return client()->CreateDeviceToHostChannelHandle();
|
||||||
|
}
|
||||||
|
StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() override {
|
||||||
|
return client()->CreateHostToDeviceChannelHandle();
|
||||||
|
}
|
||||||
|
|
||||||
|
LocalDeviceState& device_state(int device_ordinal) const {
|
||||||
|
return *tensorflow::down_cast<PjRtStreamExecutorDevice*>(
|
||||||
|
local_devices_.at(device_ordinal))
|
||||||
|
->local_device_state();
|
||||||
|
}
|
||||||
|
LocalClient* client() const { return client_; }
|
||||||
|
se::DeviceMemoryAllocator* allocator() const { return allocator_; }
|
||||||
|
tensorflow::Allocator* host_memory_allocator() const {
|
||||||
|
return host_memory_allocator_.get();
|
||||||
|
}
|
||||||
|
bool should_stage_host_to_device_transfers() const {
|
||||||
|
return should_stage_host_to_device_transfers_;
|
||||||
|
}
|
||||||
|
|
||||||
|
gpu::GpuExecutableRunOptions* gpu_run_options() const {
|
||||||
|
return gpu_run_options_.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::thread::ThreadPool* h2d_transfer_pool() {
|
||||||
|
return &h2d_transfer_pool_;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
friend class PjRtStreamExecutorBuffer;
|
||||||
|
virtual void EnqueueCrossHostReceive(
|
||||||
|
std::vector<std::unique_ptr<PjRtBuffer>>&& buffers,
|
||||||
|
std::shared_ptr<BufferSequencingEvent> definition_event,
|
||||||
|
PjRtCrossHostRecvNotifier&& notifier) const {
|
||||||
|
notifier(Unimplemented("Cross host receives not implemented."));
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual Status CopyToRemoteDevice(
|
||||||
|
PjRtBuffer* buffer, absl::string_view serialized_descriptor) const {
|
||||||
|
return Unimplemented("Cross host sends not implemented.");
|
||||||
|
}
|
||||||
|
|
||||||
|
const PjRtPlatformId platform_id_;
|
||||||
|
const std::string platform_name_;
|
||||||
|
LocalClient* client_;
|
||||||
|
|
||||||
|
// Allocator to be used for staging memory transfers to devices.
|
||||||
|
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
|
||||||
|
|
||||||
|
// Includes all devices, including non-local devices on multi-host platforms.
|
||||||
|
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> owned_devices_;
|
||||||
|
// Pointers to `owned_devices_`.
|
||||||
|
std::vector<PjRtDevice*> devices_;
|
||||||
|
// Maps Device::id() to the corresponding Device. Includes all devices.
|
||||||
|
std::map<int, PjRtDevice*> id_to_device_;
|
||||||
|
// Local devices indexed by local device ordinal.
|
||||||
|
std::vector<PjRtDevice*> local_devices_;
|
||||||
|
int host_id_;
|
||||||
|
|
||||||
|
se::DeviceMemoryAllocator* allocator_;
|
||||||
|
std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
|
||||||
|
|
||||||
|
// Should we always prefer to stage host-to-device transfers via memory
|
||||||
|
// allocated on host_memory_allocator_? True only on GPU, where we prefer to
|
||||||
|
// transfer via pinned memory.
|
||||||
|
bool should_stage_host_to_device_transfers_;
|
||||||
|
|
||||||
|
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;
|
||||||
|
|
||||||
|
tensorflow::thread::ThreadPool h2d_transfer_pool_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Converts a 2D set of Device objects indexed by [replica][partition] into an
|
||||||
|
// xla::DeviceAssignment.
|
||||||
|
StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
|
||||||
|
absl::Span<const std::vector<PjRtDevice*>> devices);
|
||||||
|
|
||||||
|
class PjRtStreamExecutorBuffer : public PjRtBuffer {
|
||||||
|
public:
|
||||||
|
// Helper class to retain a "hold" on a PjRtStreamExecutorBuffer. A ScopedHold
|
||||||
|
// may not outlive its parent PjRtStreamExecutorBuffer.
|
||||||
|
//
|
||||||
|
// There are three types of hold, as follows:
|
||||||
|
//
|
||||||
|
// 1) Usage hold: a transient hold while an operation using the buffer is
|
||||||
|
// being enqueued onto a stream.
|
||||||
|
// A client acquires a usage hold by calling
|
||||||
|
// PjRtStreamExecutorBuffer::GetBufferWithHold(kUsage) or the convenience
|
||||||
|
// wrapper GetBufferWithUsageHold(). If the enqueue completes successfully the
|
||||||
|
// hold should be released using a call to ConvertUsageHold. If the ScopedHold
|
||||||
|
// is deleted without ConvertUsageHold being called, e.g., on error, the hold
|
||||||
|
// is dropped. It is legal to drop a usage hold instead of calling
|
||||||
|
// ConvertUsageHold, even if the buffer was successfully enqueued, as long as
|
||||||
|
// the client ensures that all necessary synchronization has been done.
|
||||||
|
//
|
||||||
|
// 2) External hold: a potentially long-lived hold while the buffer is being
|
||||||
|
// shared by an external framework, e.g., NumPy.
|
||||||
|
// A client acquires an external hold by calling
|
||||||
|
// PjRtStreamExecutorBuffer::GetBufferWithHold(kExternal) or the convenience
|
||||||
|
// wrapper GetBufferWithExternalReference and releases it by deleting the
|
||||||
|
// ScopedHold. The external framework should not modify the underlying buffer
|
||||||
|
// unless it is confident via its own synchronization that modifications do
|
||||||
|
// not race with reads from the PjRtStreamExecutorBuffer.
|
||||||
|
//
|
||||||
|
// 3) Donation hold: a transient hold while an execution that donates the
|
||||||
|
// buffer is being enqueued onto the compute stream.
|
||||||
|
// A client acquires a donation hold by calling
|
||||||
|
// PjRtStreamExecutorBuffer::GetBufferWithHold(kDonation). If the enqueue
|
||||||
|
// completes successfully the hold should be released using a call to
|
||||||
|
// ConfirmDonation after which the buffer is invalid. If the ScopedHold is
|
||||||
|
// deleted without ConfirmDonation being called, e.g., on error, the hold is
|
||||||
|
// dropped and the buffer remains valid. If the buffer is successfully
|
||||||
|
// enqueued the client *must* call ConfirmDonation.
|
||||||
|
//
|
||||||
|
// Donation holds behave like exclusive write locks: when a donation hold
|
||||||
|
// has been acquired, any attempt to acquire another hold of any type will
|
||||||
|
// block until the donation hold is dropped or confirmed. Acquiring a donation
|
||||||
|
// hold will fail with an error if there is any outstanding external hold, and
|
||||||
|
// will block if there are any outstanding usage holds until those holds are
|
||||||
|
// dropped or converted.
|
||||||
|
//
|
||||||
|
// Calls to PjRtStreamExecutorBuffer::Release (and transitively to
|
||||||
|
// PjRtStreamExecutorBuffer::Delete() and ~PjRtStreamExecutorBuffer()) will
|
||||||
|
// block until all usage and donation holds are either deleted or
|
||||||
|
// converted/confirmed.
|
||||||
|
class ScopedHold {
|
||||||
|
public:
|
||||||
|
enum Type { kUsage = 0, kExternalReference, kDonation, kMaxValue };
|
||||||
|
// Use a State enum instead of encoding the state in an error Status to
|
||||||
|
// avoid creating Status values in non-error cases. Creating a Status
|
||||||
|
// entails several allocations and can add O(us) to every use of a hold.
|
||||||
|
enum State {
|
||||||
|
kUninitialized = 0,
|
||||||
|
kValid,
|
||||||
|
kMoved,
|
||||||
|
kConverted,
|
||||||
|
kReleased,
|
||||||
|
kDonated,
|
||||||
|
kError
|
||||||
|
};
|
||||||
|
|
||||||
|
~ScopedHold();
|
||||||
|
ScopedHold(ScopedHold&& other);
|
||||||
|
ScopedHold(const ScopedHold&) = delete;
|
||||||
|
ScopedHold& operator=(const ScopedHold&) = delete;
|
||||||
|
|
||||||
|
Type type() const { return type_; }
|
||||||
|
|
||||||
|
Status status() const {
|
||||||
|
// Lazily create Status values only when they are requested.
|
||||||
|
switch (state_) {
|
||||||
|
case kUninitialized:
|
||||||
|
return InvalidArgument("Buffer has not been initialized");
|
||||||
|
case kValid:
|
||||||
|
return Status::OK();
|
||||||
|
case kMoved:
|
||||||
|
return InvalidArgument("Buffer has been moved.");
|
||||||
|
case kConverted:
|
||||||
|
return InvalidArgument("Buffer has been converted");
|
||||||
|
case kReleased:
|
||||||
|
return InvalidArgument("Buffer has been released");
|
||||||
|
case kDonated:
|
||||||
|
return InvalidArgument("Buffer has been donated");
|
||||||
|
case kError:
|
||||||
|
return buffer_or_.status();
|
||||||
|
default:
|
||||||
|
CHECK(false) << "Unexpected state value " << state_;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bool ok() const { return state_ == kValid; }
|
||||||
|
|
||||||
|
// Access to the underlying device buffer storage. Requires this->ok().
|
||||||
|
const std::shared_ptr<TrackedDeviceBuffer>& buffer() const {
|
||||||
|
CHECK_EQ(state_, kValid);
|
||||||
|
CHECK_NE(buffer_or_.ValueOrDie(), nullptr);
|
||||||
|
return buffer_or_.ValueOrDie();
|
||||||
|
}
|
||||||
|
TrackedDeviceBuffer* operator->() const { return buffer().get(); }
|
||||||
|
const TrackedDeviceBuffer& operator*() const { return *buffer(); }
|
||||||
|
|
||||||
|
// Converts the hold into a usage event. Only valid for holds of type
|
||||||
|
// kUsage.
|
||||||
|
//
|
||||||
|
// usage_stream: the stream that the buffer was used on.
|
||||||
|
// event: an event that has been recorded on usage_stream after
|
||||||
|
// the buffer was used.
|
||||||
|
// reference_held: true if and only if the caller has caused a
|
||||||
|
// reference to this->buffer() to stay live until after
|
||||||
|
// the host is sure that the usage (transfer or execution)
|
||||||
|
// has completed.
|
||||||
|
void ConvertUsageHold(se::Stream* usage_stream,
|
||||||
|
std::shared_ptr<BufferSequencingEvent> event,
|
||||||
|
bool reference_held);
|
||||||
|
|
||||||
|
// Confirms that the buffer was successfully donated to an execution.
|
||||||
|
// Only valid for holds of type kDonation. Causes the buffer to become
|
||||||
|
// invalid.
|
||||||
|
void ConfirmDonation();
|
||||||
|
|
||||||
|
// Adds the held device buffers in order to 'iterator'. Used to add the
|
||||||
|
// buffers to an ExecutionInput. We require but do not verify that
|
||||||
|
// 'iterator' when passed in is pointing to a sub-tuple of the
|
||||||
|
// ExecutionInput whose on_device_shape matches that of the
|
||||||
|
// TrackedDeviceBuffer. 'end' is used to check that 'iterator' doesn't run
|
||||||
|
// out of bounds. Donates the device buffers if the hold type is kDonation,
|
||||||
|
// otherwise retains ownership of the device buffers.
|
||||||
|
void AddToInput(ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
|
||||||
|
const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
|
||||||
|
ExecutionInput* execution_input,
|
||||||
|
se::DeviceMemoryAllocator* allocator) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend class PjRtStreamExecutorBuffer;
|
||||||
|
friend class PjRtStreamExecutorClient;
|
||||||
|
|
||||||
|
// Helper struct that makes it possible to move a ScopedHold through a
|
||||||
|
// closure.
|
||||||
|
using ForClosure =
|
||||||
|
std::tuple<PjRtStreamExecutorBuffer*, Type, State,
|
||||||
|
StatusOr<std::shared_ptr<TrackedDeviceBuffer>>>;
|
||||||
|
|
||||||
|
ScopedHold(PjRtStreamExecutorBuffer* parent, Type type)
|
||||||
|
: parent_(parent), type_(type), state_(kUninitialized) {}
|
||||||
|
explicit ScopedHold(const ForClosure& closure_helper)
|
||||||
|
: parent_(std::get<0>(closure_helper)),
|
||||||
|
type_(std::get<1>(closure_helper)),
|
||||||
|
state_(std::get<2>(closure_helper)),
|
||||||
|
buffer_or_(std::get<3>(closure_helper)) {
|
||||||
|
// Check the buffer is not in an error state.
|
||||||
|
CHECK(buffer_or_.ValueOrDie() != nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets buffer state.
|
||||||
|
void SetState(State state) { state_ = state; }
|
||||||
|
|
||||||
|
// Sets buffer_or_. Called by parent_ to initialize the hold.
|
||||||
|
void Acquire(StatusOr<std::shared_ptr<TrackedDeviceBuffer>>&& buffer_or);
|
||||||
|
// Releases the contents of *this, so *this can subsequently be
|
||||||
|
// deleted without releasing the parent's hold. Should be passed to the
|
||||||
|
// appropriate constructor of another ScopedHold, e.g., when a hold must be
|
||||||
|
// passed through a closure that is incompatible with std::move.
|
||||||
|
ForClosure ToClosure();
|
||||||
|
|
||||||
|
PjRtStreamExecutorBuffer* const parent_;
|
||||||
|
const Type type_;
|
||||||
|
|
||||||
|
// There is an invariant that if ok() then
|
||||||
|
// buffer_or_.ValueOrDie() != nullptr.
|
||||||
|
State state_;
|
||||||
|
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
|
||||||
|
};
|
||||||
|
|
||||||
|
PjRtStreamExecutorBuffer(Shape on_host_shape, Shape on_device_shape,
|
||||||
|
std::shared_ptr<TrackedDeviceBuffer> device_buffer,
|
||||||
|
PjRtClient* client, PjRtDevice* device);
|
||||||
|
~PjRtStreamExecutorBuffer() override;
|
||||||
|
|
||||||
|
PjRtStreamExecutorBuffer(const PjRtStreamExecutorBuffer&) = delete;
|
||||||
|
PjRtStreamExecutorBuffer(PjRtStreamExecutorBuffer&&) = delete;
|
||||||
|
PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
|
||||||
|
PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;
|
||||||
|
|
||||||
|
const Shape& on_host_shape() const override { return on_host_shape_; }
|
||||||
|
const Shape& on_device_shape() const override { return on_device_shape_; }
|
||||||
|
PjRtStreamExecutorDevice* device() const override { return device_; }
|
||||||
|
PjRtPlatformId platform_id() const { return client_->platform_id(); }
|
||||||
|
const std::string& platform_name() const { return client_->platform_name(); }
|
||||||
|
PjRtStreamExecutorClient* client() const override { return client_; }
|
||||||
|
bool IsEmptyTuple() const {
|
||||||
|
return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 OnDeviceSizeInBytes() const override;
|
||||||
|
|
||||||
|
// Implement PjRtBuffer::ExternalReferenceHold a wrapped
|
||||||
|
// ScopedHold::kExternalReference.
|
||||||
|
class ScopedHoldAsExternalReference
|
||||||
|
: public PjRtBuffer::ExternalReferenceHold {
|
||||||
|
public:
|
||||||
|
explicit ScopedHoldAsExternalReference(ScopedHold hold)
|
||||||
|
: external_reference_(std::move(hold)) {
|
||||||
|
CHECK(hold.type() == ScopedHold::kExternalReference);
|
||||||
|
}
|
||||||
|
|
||||||
|
~ScopedHoldAsExternalReference() override = default;
|
||||||
|
|
||||||
|
void* OpaqueDeviceMemoryDataPointer() const override {
|
||||||
|
return external_reference_->device_memory().front().opaque();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
ScopedHold external_reference_;
|
||||||
|
};
|
||||||
|
StatusOr<std::unique_ptr<ExternalReferenceHold>> AcquireExternalReference()
|
||||||
|
override;
|
||||||
|
|
||||||
|
StatusOr<absl::optional<std::shared_ptr<void>>> ReleaseDeviceMemoryOwnership(
|
||||||
|
bool wait_for_operations_to_complete) override;
|
||||||
|
|
||||||
|
using PjRtBuffer::ToLiteral;
|
||||||
|
StatusOr<std::shared_ptr<Literal>> ToLiteral(
|
||||||
|
bool discard_cached_copy, absl::optional<xla::Layout> layout) override;
|
||||||
|
|
||||||
|
using PjRtBuffer::CopyToHostAsync;
|
||||||
|
Status CopyToHostAsync(absl::optional<xla::Layout> layout) override;
|
||||||
|
|
||||||
|
// Drops the buffer's reference to its associated device memory, leaving the
|
||||||
|
// buffer in an invalid state. The memory will be freed lazily when all async
|
||||||
|
// operations using the buffer have completed, according to the allocation
|
||||||
|
// semantics of the underlying platform. Delete may briefly block if another
|
||||||
|
// thread is in the process of enqueuing an operation on this buffer, but it
|
||||||
|
// will never block for a stream operation to complete. If an external
|
||||||
|
// framework holds a reference to the TrackedDeviceBuffer via
|
||||||
|
// GetBufferWithExternalReference, the memory will not be freed until the
|
||||||
|
// external framework drops the reference.
|
||||||
|
void Delete() override;
|
||||||
|
|
||||||
|
bool IsDeleted() override;
|
||||||
|
|
||||||
|
// Returns a view of the PjRtBuffer device memory as a ShapedBuffer. The
|
||||||
|
// PjRtBuffer retains ownership of the device buffers.
|
||||||
|
StatusOr<ShapedBuffer> AsShapedBuffer() const;
|
||||||
|
|
||||||
|
// Returns a hold on the TrackedDeviceBuffer holding the device
|
||||||
|
// buffers. See comment on ScopedHold.
|
||||||
|
ScopedHold GetBufferWithHold(ScopedHold::Type type);
|
||||||
|
ScopedHold GetBufferWithUsageHold() {
|
||||||
|
return GetBufferWithHold(ScopedHold::kUsage);
|
||||||
|
}
|
||||||
|
ScopedHold GetBufferWithExternalReference() {
|
||||||
|
return GetBufferWithHold(ScopedHold::kExternalReference);
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<PjRtBuffer>> CopyToDevice(
|
||||||
|
PjRtDevice* dst_device) override;
|
||||||
|
|
||||||
|
Status CopyToRemoteDevice(absl::string_view serialized_descriptor) override;
|
||||||
|
|
||||||
|
Status BlockHostUntilReady() override;
|
||||||
|
|
||||||
|
bool IsOnCpu() const override;
|
||||||
|
|
||||||
|
// Similar to Delete, drops the buffer's reference to its associated device
|
||||||
|
// memory, leaving the buffer in an invalid state, but returns the
|
||||||
|
// TrackedDeviceBuffer rather than freeing the device memory, so that another
|
||||||
|
// framework can take ownership of it. The buffer returned from Release may
|
||||||
|
// be safely dropped at any time even if it still has pending async
|
||||||
|
// operations. The client should call BlockHostUntilReady before calling
|
||||||
|
// Release with wait_for_operations_to_complete=false, to ensure that the host
|
||||||
|
// has synchronized past any outstanding write operations to the buffer. If
|
||||||
|
// wait_for_operations_to_complete=true the host will block until any
|
||||||
|
// potentially outstanding asynchronous operations have completed before
|
||||||
|
// returning, in which case it is safe to read or mutate the returned buffer.
|
||||||
|
// If the buffer was shared via an external reference it is the client's
|
||||||
|
// responsibility that accesses via that reference do not interfere with
|
||||||
|
// accesses via the buffer returned from Release.
|
||||||
|
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> Release(
|
||||||
|
bool wait_for_operations_to_complete);
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend class PjRtClient;
|
||||||
|
// The cached value of the buffer on the host, produced either from a call to
|
||||||
|
// CopyToHost or from a call to ToLiteral. Once a value has been fetched to
|
||||||
|
// the host, it persists Delete() is called or the PjRtBuffer is destroyed.
|
||||||
|
struct HostValue {
|
||||||
|
absl::Notification ready;
|
||||||
|
// status and value are valid for reading only after `ready` has been
|
||||||
|
// notified.
|
||||||
|
Status status;
|
||||||
|
std::shared_ptr<Literal> value;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Blocks in mu_.Await until there are no more usage holds.
|
||||||
|
void WaitForOutstandingUsageHolds() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
|
// Blocks in mu_.Await until there is no donation hold.
|
||||||
|
void WaitForOutstandingDonationHold() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
|
// Adds a hold of 'type' and returns device_buffer_. Returns an error if
|
||||||
|
// device_buffer_ is null, or if a donation hold was requested when there is
|
||||||
|
// an outstanding external hold.
|
||||||
|
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> GetBufferForHoldLocked(
|
||||||
|
ScopedHold::Type type) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
|
// Adds a hold of hold->type() and initializes `hold` with device_buffer_.
|
||||||
|
// Initializes hold with an error if device_buffer_ is null, or if a donation
|
||||||
|
// hold was requested when there is an outstanding external hold.
|
||||||
|
void AcquireHoldLocked(ScopedHold* hold) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
|
// Drops a usage hold and calls device_buffer_->AddUsageEvent. Does a sanity
|
||||||
|
// check that buffer==device_buffer_ or device_buffer_==nullptr. Called after
|
||||||
|
// device_buffer_ was successfully enqueued on a stream.
|
||||||
|
void ConvertUsageHold(TrackedDeviceBuffer* buffer, se::Stream* usage_stream,
|
||||||
|
std::shared_ptr<BufferSequencingEvent> event,
|
||||||
|
bool reference_held);
|
||||||
|
|
||||||
|
// Drops a donation hold and makes *this invalid for further use. Does a
|
||||||
|
// sanity check that buffer==device_buffer_. Called after device_buffer_ was
|
||||||
|
// successfully donated to an execution.
|
||||||
|
void ConfirmDonation(TrackedDeviceBuffer* device_buffer);
|
||||||
|
|
||||||
|
// Initiates a copy of the buffer to the host. Does not block waiting for
|
||||||
|
// the transfer to complete. A host value is returned and if
|
||||||
|
// `discard_cached_copy` is false stored in an internal buffer so that future
|
||||||
|
// transfers don't have to transfer the data from host again. If a layout is
|
||||||
|
// passed then a literal of this layout will be returned and possibly cached.
|
||||||
|
StatusOr<std::shared_ptr<HostValue>> CopyToHostAsyncInternal(
|
||||||
|
bool discard_cached_copy, absl::optional<xla::Layout> layout);
|
||||||
|
|
||||||
|
// Drops a hold without taking any other action. Does a sanity check that
|
||||||
|
// buffer==device_buffer_ or device_buffer_==nullptr.
|
||||||
|
void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer);
|
||||||
|
|
||||||
|
StatusOr<std::pair<std::unique_ptr<PjRtBuffer>,
|
||||||
|
std::shared_ptr<BufferSequencingEvent>>>
|
||||||
|
CopyToDeviceHelper(PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
|
||||||
|
LocalDeviceState* transfer_local_device,
|
||||||
|
se::Stream* transfer_stream,
|
||||||
|
std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);
|
||||||
|
|
||||||
|
PjRtStreamExecutorClient* const client_;
|
||||||
|
const Shape on_host_shape_;
|
||||||
|
const Shape on_device_shape_;
|
||||||
|
PjRtStreamExecutorDevice* const device_;
|
||||||
|
|
||||||
|
mutable absl::Mutex mu_;
|
||||||
|
std::shared_ptr<TrackedDeviceBuffer> device_buffer_ TF_GUARDED_BY(mu_);
|
||||||
|
absl::flat_hash_map<xla::Layout, std::shared_ptr<HostValue>> host_values_
|
||||||
|
TF_GUARDED_BY(mu_);
|
||||||
|
std::shared_ptr<HostValue> host_value_ TF_GUARDED_BY(mu_);
|
||||||
|
// Count of holds on the buffer.
|
||||||
|
std::array<int, ScopedHold::Type::kMaxValue> holds_ TF_GUARDED_BY(mu_);
|
||||||
|
// Semaphore used to ensure there is only one outstanding donation hold.
|
||||||
|
Semaphore donation_semaphore_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Wraps one or more XLA LocalExecutables (one per partition, as specified by
|
||||||
|
// the build options).
|
||||||
|
class PjRtStreamExecutorExecutable : public PjRtExecutable {
|
||||||
|
public:
|
||||||
|
PjRtStreamExecutorExecutable(
|
||||||
|
std::vector<std::unique_ptr<LocalExecutable>> executables,
|
||||||
|
bool parameter_is_tupled_arguments,
|
||||||
|
std::shared_ptr<DeviceAssignment> device_assignment,
|
||||||
|
std::vector<LogicalDeviceIds> addressable_device_logical_ids,
|
||||||
|
std::vector<PjRtDevice*> addressable_devices,
|
||||||
|
PjRtStreamExecutorClient* client);
|
||||||
|
|
||||||
|
~PjRtStreamExecutorExecutable() override = default;
|
||||||
|
|
||||||
|
PjRtStreamExecutorClient* client() const override { return client_; }
|
||||||
|
|
||||||
|
const std::string& name() const override;
|
||||||
|
|
||||||
|
int num_replicas() const override {
|
||||||
|
return executables_[0]->build_options().num_replicas();
|
||||||
|
}
|
||||||
|
|
||||||
|
int num_partitions() const override {
|
||||||
|
return executables_[0]->build_options().num_partitions();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 SizeOfGeneratedCodeInBytes() const override {
|
||||||
|
int64 size = 0;
|
||||||
|
for (auto& executable : executables_) {
|
||||||
|
size += executable->executable()->SizeOfGeneratedCodeInBytes();
|
||||||
|
}
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
const DeviceAssignment& device_assignment() const override {
|
||||||
|
return *device_assignment_;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Span<const LogicalDeviceIds> addressable_device_logical_ids()
|
||||||
|
const override {
|
||||||
|
return addressable_device_logical_ids_;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Span<PjRtDevice* const> addressable_devices() const override {
|
||||||
|
return addressable_devices_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return an HloModule per partition.
|
||||||
|
StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
|
||||||
|
const override;
|
||||||
|
|
||||||
|
StatusOr<std::vector<std::vector<std::unique_ptr<PjRtBuffer>>>> Execute(
|
||||||
|
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
||||||
|
const ExecuteOptions& options) const override;
|
||||||
|
|
||||||
|
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteSharded(
|
||||||
|
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||||
|
const ExecuteOptions& options) const override;
|
||||||
|
|
||||||
|
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecutePortable(
|
||||||
|
absl::Span<PjRtBuffer* const> argument_handles, PjRtDevice* device,
|
||||||
|
const ExecuteOptions& options) const override;
|
||||||
|
|
||||||
|
void Delete() override { executables_.clear(); }
|
||||||
|
|
||||||
|
absl::Span<const std::shared_ptr<LocalExecutable>> executables() const {
|
||||||
|
return executables_;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
bool parameter_is_tupled_arguments() const {
|
||||||
|
return parameter_is_tupled_arguments_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend class PjRtStreamExecutorClient;
|
||||||
|
// Initializes information about which arguments to which executables must be
|
||||||
|
// donated due to aliases that were specified by the computation.
|
||||||
|
Status SetUpDonation(bool tuple_inputs);
|
||||||
|
|
||||||
|
virtual bool MustDonateParameter(int executable_idx, int parameter) const;
|
||||||
|
|
||||||
|
virtual StatusOr<std::vector<ExecutionInput>>
|
||||||
|
MakeExecutionInputsAndWaitForEvents(
|
||||||
|
int device_ordinal, const ExecuteOptions& options,
|
||||||
|
absl::Span<PjRtBuffer* const> argument_handles,
|
||||||
|
absl::Span<const PjRtStreamExecutorBuffer::ScopedHold> device_buffers,
|
||||||
|
absl::flat_hash_set<BufferSequencingEvent*>& events) const;
|
||||||
|
|
||||||
|
StatusOr<ScopedShapedBuffer> EnqueueExecution(
|
||||||
|
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
||||||
|
int partition, int executable_idx, const RunId& run_id,
|
||||||
|
const ExecuteOptions& options, PjRtDevice* device,
|
||||||
|
std::vector<PjRtStreamExecutorBuffer::ScopedHold>* device_buffers,
|
||||||
|
std::shared_ptr<DeviceAssignment> device_assignment) const;
|
||||||
|
|
||||||
|
virtual std::vector<std::unique_ptr<PjRtBuffer>> MakeOutputBuffers(
|
||||||
|
int device_ordinal, const ExecuteOptions& options,
|
||||||
|
ScopedShapedBuffer result_buffer,
|
||||||
|
std::shared_ptr<BufferSequencingEvent> definition_event,
|
||||||
|
PjRtDevice* device) const;
|
||||||
|
|
||||||
|
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
|
||||||
|
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
||||||
|
int partition, const RunId& run_id, const ExecuteOptions& options,
|
||||||
|
PjRtDevice* device = nullptr) const;
|
||||||
|
|
||||||
|
// Create shared pointers so we can free them after the execution: with
|
||||||
|
// asynchronous execution, the process being executed can outlive the
|
||||||
|
// executable itself.
|
||||||
|
PjRtStreamExecutorClient* const client_;
|
||||||
|
// One executable per partition.
|
||||||
|
std::vector<std::shared_ptr<LocalExecutable>> executables_;
|
||||||
|
// Per-executable set of parameters that have any aliased buffers and thus
|
||||||
|
// must be donated when executing the computation.
|
||||||
|
std::vector<absl::flat_hash_set<int>> parameters_that_must_be_donated_;
|
||||||
|
std::shared_ptr<DeviceAssignment> device_assignment_;
|
||||||
|
|
||||||
|
// True if the executables were compiled expecting arguments in a single
|
||||||
|
// tuple.
|
||||||
|
const bool parameter_is_tupled_arguments_;
|
||||||
|
|
||||||
|
// The replica and partition indices of device_assignment_ to be run by this
|
||||||
|
// client. On single-host platforms without partitioning, this is all replicas
|
||||||
|
// (i.e. addressable_device_logical_ids_[i] = (i, 0)), but this may not be the
|
||||||
|
// case on multi-host platforms. If there are 4 replicas and 2 partitions on a
|
||||||
|
// single host platform, size of addressable_device_logical_ids_ is 4*2 = 8.
|
||||||
|
std::vector<LogicalDeviceIds> addressable_device_logical_ids_;
|
||||||
|
|
||||||
|
// addressable_devices_[i] is the Device to which
|
||||||
|
// addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
|
||||||
|
// unique_ptrs to play well with the Python bindings (see xla.cc).
|
||||||
|
std::vector<PjRtDevice*> addressable_devices_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Executables can donate buffers so that buffers can be aliased from inputs
|
||||||
|
// to outputs. This function returns the list of parameters that must be
|
||||||
|
// donated when executable is run. tuple_inputs reflects the option that
|
||||||
|
// executable was compiled with.
|
||||||
|
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
|
||||||
|
const HloModule& hlo_module, bool tuple_inputs);
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_STREAM_EXECUTOR_CLIENT_H_
|
@ -23,7 +23,7 @@ limitations under the License.
|
|||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
|
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
|
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
|
||||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||||
#include "tensorflow/compiler/xla/shape.h"
|
#include "tensorflow/compiler/xla/shape.h"
|
||||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_topology.h"
|
#include "tensorflow/stream_executor/tpu/tpu_topology.h"
|
||||||
|
|
||||||
|
@ -211,6 +211,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||||
|
"//tensorflow/compiler/xla/pjrt:pjrt_stream_executor_client", # TODO(zhangqiaorjc): Remove after adding a factory method for PjRtBuffer.
|
||||||
"//tensorflow/compiler/xla/pjrt:tracked_device_buffer",
|
"//tensorflow/compiler/xla/pjrt:tracked_device_buffer",
|
||||||
"//tensorflow/stream_executor:device_memory",
|
"//tensorflow/stream_executor:device_memory",
|
||||||
"//tensorflow/stream_executor:platform",
|
"//tensorflow/stream_executor:platform",
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "include/dlpack/dlpack.h" // from @dlpack
|
#include "include/dlpack/dlpack.h" // from @dlpack
|
||||||
#include "pybind11/pytypes.h"
|
#include "pybind11/pytypes.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||||
|
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
|
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
|
||||||
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
|
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
|
||||||
#include "tensorflow/compiler/xla/python/traceback.h"
|
#include "tensorflow/compiler/xla/python/traceback.h"
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include <sys/types.h>
|
#include <sys/types.h>
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <queue>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
@ -26,7 +26,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/compiler/xla/client:executable_build_options",
|
"//tensorflow/compiler/xla/client:executable_build_options",
|
||||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
"//tensorflow/compiler/xla/pjrt:pjrt_stream_executor_client",
|
||||||
"//tensorflow/compiler/xla/pjrt:semaphore",
|
"//tensorflow/compiler/xla/pjrt:semaphore",
|
||||||
"//tensorflow/compiler/xla/python/tpu_driver",
|
"//tensorflow/compiler/xla/python/tpu_driver",
|
||||||
"//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver",
|
"//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver",
|
||||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
|||||||
#include "absl/synchronization/notification.h"
|
#include "absl/synchronization/notification.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
|
||||||
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
|
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
|
||||||
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
|
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
|
||||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||||
|
Loading…
Reference in New Issue
Block a user