diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc index c571ef2a4df..37c4ab3b7c5 100644 --- a/tensorflow/compiler/xla/pjrt/cpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/service/platform_util.h" namespace xla { @@ -25,7 +26,7 @@ static const char kCpuPlatformName[] = "cpu"; CpuDevice::CpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state) - : PjRtDevice(id, std::move(local_device_state), kCpuPlatformName, + : PjRtDevice(id, std::move(local_device_state), /*device_kind=*/kCpuPlatformName) {} StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) { @@ -57,7 +58,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) { } return std::make_unique<PjRtClient>( - kCpuPlatformName, client, std::move(devices), /*host_id=*/0, + PjRtPlatformId::kCpu, client, std::move(devices), /*host_id=*/0, /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, /*should_stage_host_to_device_transfers=*/false, /*gpu_run_options=*/nullptr); diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.cc b/tensorflow/compiler/xla/pjrt/interpreter_device.cc index 376d8687892..53a4bed8bb5 100644 --- a/tensorflow/compiler/xla/pjrt/interpreter_device.cc +++ b/tensorflow/compiler/xla/pjrt/interpreter_device.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/service/platform_util.h" namespace xla { @@ -25,7 +26,7 @@ static const char kInterpreterPlatformName[] = "interpreter"; InterpreterDevice::InterpreterDevice( int id, std::unique_ptr<LocalDeviceState> local_device_state) - : PjRtDevice(id, std::move(local_device_state), kInterpreterPlatformName, + : PjRtDevice(id, std::move(local_device_state), /*device_kind=*/kInterpreterPlatformName) {} StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() { @@ -51,7 +52,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() { devices.push_back(std::move(device)); return std::make_unique<PjRtClient>( - kInterpreterPlatformName, client, std::move(devices), /*host_id=*/0, + PjRtPlatformId::kInterpreter, client, std::move(devices), /*host_id=*/0, /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, /*should_stage_host_to_device_transfers=*/false, /*gpu_run_options=*/nullptr); diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc index df92921c39d..5003c8a7cde 100644 --- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc @@ -30,8 +30,6 @@ limitations under the License. namespace xla { namespace { -static const char kGpuPlatformName[] = "gpu"; - // A custom PjRtClient that overrides the device assignment method. 
class GpuClient : public xla::PjRtClient { public: @@ -298,8 +296,8 @@ Status BuildDistributedDevices( GpuDevice::GpuDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state, std::string device_kind, int node_id) - : PjRtDevice(id, std::move(local_device_state), kGpuPlatformName, - std::move(device_kind), node_id) {} + : PjRtDevice(id, std::move(local_device_state), std::move(device_kind), + node_id) {} StatusOr<std::unique_ptr<PjRtClient>> GetNvidiaGpuClient( bool asynchronous, const GpuAllocatorConfig& allocator_config, @@ -325,7 +323,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetNvidiaGpuClient( } return std::unique_ptr<PjRtClient>(std::make_unique<GpuClient>( - "gpu", xla_client, std::move(devices), + PjRtPlatformId::kNvidiaGpu, xla_client, std::move(devices), /*node_id=*/node_id, std::move(allocator), std::move(host_memory_allocator), /*should_stage_host_to_device_transfers=*/true, diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index 8752b6260f6..18de057395b 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -113,6 +113,13 @@ limitations under the License. 
namespace xla { +PjRtPlatformId PjRtDevice::platform_id() const { + return client_->platform_id(); +} +const std::string& PjRtDevice::platform_name() const { + return client_->platform_name(); +} + StatusOr<LocalDeviceState*> PjRtDevice::GetLocalDeviceState() const { if (local_device_state_) { return local_device_state_.get(); @@ -145,8 +152,8 @@ StatusOr<DeviceAssignment> DevicesToDeviceAssignment( devices[replica].size(), replica, devices[0].size()); } for (int partition = 0; partition < devices[replica].size(); ++partition) { - if (devices[0][0]->platform_name() != - devices[replica][partition]->platform_name()) { + if (devices[0][0]->platform_id() != + devices[replica][partition]->platform_id()) { return InvalidArgument( "Device assignment passed to Compile() must have devices of a " "single kind, got %s for replica 0 partition 0 and %s for replica " @@ -175,13 +182,14 @@ class CpuAllocator : public tensorflow::Allocator { }; PjRtClient::PjRtClient( - std::string platform_name, LocalClient* client, + PjRtPlatformId platform_id, LocalClient* client, std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator, std::unique_ptr<tensorflow::Allocator> host_memory_allocator, bool should_stage_host_to_device_transfers, std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options) - : platform_name_(std::move(platform_name)), + : platform_id_(platform_id), + platform_name_(Name(platform_id)), client_(client), host_memory_allocator_(std::move(host_memory_allocator)), devices_(std::move(devices)), @@ -206,15 +214,15 @@ PjRtClient::PjRtClient( CHECK(id_to_device_.insert({device->id(), device.get()}).second) << "Duplicate device id: " << device->id(); - if (device->local_device_state()) { - int idx = device->local_device_state()->device_ordinal(); + if (device->IsLocalDevice()) { + int idx = device->local_device_id(); if (idx >= local_devices_.size()) { local_devices_.resize(idx + 1); } CHECK(local_devices_[idx] == nullptr) << idx; local_devices_[idx] = device.get(); } - device->client_ = this; + device->SetClient(this); } for (int idx = 0; idx < 
local_devices_.size(); ++idx) { CHECK(local_devices_[idx] != nullptr) << idx; @@ -576,6 +584,10 @@ void PjRtBuffer::ScopedHold::AddToInput( } } +bool PjRtBuffer::IsOnCpu() const { + return client()->platform_id() == PjRtPlatformId::kCpu; +} + StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer( const void* data, const Shape& shape, HostBufferSemantics host_buffer_semantics, @@ -865,6 +877,16 @@ StatusOr<Literal> PjRtDevice::TransferFromOutfeed(const Shape& shape) const { shape, local_device->device_ordinal()); } +StatusOr<PjRtDevice*> PjRtClient::LookupLocalDevice(int local_device_id) const { + for (auto* device : local_devices_) { + if (local_device_id == device->local_device_id()) { + return device; + } + } + return InvalidArgument("No matching device found for local_device_id %d", + local_device_id); +} + PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape, std::shared_ptr<TrackedDeviceBuffer> device_buffer, PjRtClient* client, PjRtDevice* device) @@ -1985,6 +2007,19 @@ PjRtExecutable::ExecuteOnLocalDevices( return wrapped_results; } +StatusOr<std::vector<std::shared_ptr<HloModule>>> +PjRtExecutable::GetHloModules() { + std::vector<std::shared_ptr<HloModule>> modules; + modules.reserve(executables().size()); + for (const auto& local_exec : executables()) { + if (!local_exec->executable()->has_module()) { + return InvalidArgument("Executable does not have HLO modules."); + } + modules.push_back(local_exec->executable()->shared_module()); + } + return std::move(modules); +} + namespace { StatusOr<Shape> GetShardedShape(const Shape& shape, diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index 3331bf890cc..4ec129eb49d 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -36,11 +36,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -50,26 +52,63 @@ limitations under the License. namespace xla { +// TODO(zhangqiaorjc): Add a registration mechanism to add new platforms. +enum class PjRtPlatformId : int { + kCpu = 0, + kNvidiaGpu = 1, + kAmdGpu = 2, + kTpu = 3, + kEdgeTpu = 4, + kInterpreter = 5 +}; +constexpr const char* Name(PjRtPlatformId platform_id) { + switch (platform_id) { + case PjRtPlatformId::kCpu: + return "cpu"; + case PjRtPlatformId::kNvidiaGpu: + // TODO(zhangqiaorjc): Rename to nvidia_gpu when we add AMD support. + return "gpu"; + case PjRtPlatformId::kAmdGpu: + return "amd_gpu"; + case PjRtPlatformId::kTpu: + return "tpu"; + case PjRtPlatformId::kEdgeTpu: + return "edge_tpu"; + case PjRtPlatformId::kInterpreter: + return "interpreter"; + } +} + class PjRtClient; class PjRtDevice { public: explicit PjRtDevice(int id, std::unique_ptr<LocalDeviceState> local_device_state, - std::string platform_name, std::string device_kind, - int host_id = 0) + std::string device_kind, int host_id = 0) : id_(id), + local_device_id_( + local_device_state ? local_device_state->device_ordinal() : -1), local_device_state_(std::move(local_device_state)), host_id_(host_id), - platform_name_(std::move(platform_name)), device_kind_(std::move(device_kind)) {} virtual ~PjRtDevice() {} + // Must set client exactly once. 
+ void SetClient(PjRtClient* client) { + CHECK(client_ == nullptr); + client_ = client; + } + // The ID of this device. IDs are unique among devices of this type // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all // hosts' devices. This is the ID that should be used in a DeviceAssignment. int id() const { return id_; } + bool IsLocalDevice() const { return local_device_id_ != -1; } + + int local_device_id() const { return local_device_id_; } + // If this is a device local to this host, returns a LocalDeviceState object // that can be used to manipulate the device. Returns nullptr if the device is // not local to this host. @@ -85,7 +124,11 @@ class PjRtDevice { // The ID of this device's host. This is always 0 on single-host platforms. int host_id() const { return host_id_; } - const std::string& platform_name() const { return platform_name_; } + // Return `platform_id` from client. + PjRtPlatformId platform_id() const; + + // Return `platform_name` from client. + const std::string& platform_name() const; // A vendor-dependent string that uniquely identifies the kind of device. const std::string& device_kind() const { return device_kind_; } @@ -102,12 +145,10 @@ class PjRtDevice { virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const; private: - friend class PjRtClient; - const int id_; + const int local_device_id_; // -1 means not local. const std::unique_ptr<LocalDeviceState> local_device_state_; const int host_id_; - const std::string platform_name_; const std::string device_kind_; PjRtClient* client_ = nullptr; }; @@ -155,7 +196,7 @@ class PjRtClient { public: // `allocator` may null, in which case the platform default allocator is used. 
explicit PjRtClient( - std::string platform_name, LocalClient* client, + PjRtPlatformId platform_id, LocalClient* client, std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator, std::unique_ptr<tensorflow::Allocator> host_memory_allocator, @@ -178,12 +219,16 @@ class PjRtClient { return id_to_device_; } int host_id() const { return host_id_; } + PjRtPlatformId platform_id() const { return platform_id_; } const std::string& platform_name() const { return platform_name_; } LocalDeviceState& device_state(int device_ordinal) const { return *local_devices_.at(device_ordinal)->local_device_state(); } + // Return a local PjRtDevice for a given `local_device_id`. + virtual StatusOr<PjRtDevice*> LookupLocalDevice(int local_device_id) const; + LocalClient* client() const { return client_; } se::DeviceMemoryAllocator* allocator() const { return allocator_; } tensorflow::Allocator* host_memory_allocator() const { @@ -280,6 +325,16 @@ class PjRtClient { absl::Span<const Shape> shapes, PjRtDevice* device, PjRtCrossHostRecvNotifier&& notifier); + virtual StatusOr<ChannelHandle> CreateChannelHandle() { + return client()->CreateChannelHandle(); + } + virtual StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() { + return client()->CreateDeviceToHostChannelHandle(); + } + virtual StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() { + return client()->CreateHostToDeviceChannelHandle(); + } + protected: friend class PjRtBuffer; virtual void EnqueueCrossHostReceive( @@ -293,7 +348,8 @@ class PjRtClient { return Unimplemented("Cross host sends not implemented."); } - std::string platform_name_; + const PjRtPlatformId platform_id_; + const std::string platform_name_; LocalClient* client_; // Allocator to be used for staging memory transfers to devices. 
@@ -509,7 +565,7 @@ class PjRtBuffer { PjRtBuffer(Shape on_host_shape, Shape on_device_shape, std::shared_ptr<TrackedDeviceBuffer> device_buffer, PjRtClient* client, PjRtDevice* device); - ~PjRtBuffer(); + virtual ~PjRtBuffer(); PjRtBuffer(const PjRtBuffer&) = delete; PjRtBuffer(PjRtBuffer&&) = delete; @@ -519,6 +575,7 @@ class PjRtBuffer { const Shape& on_host_shape() const { return on_host_shape_; } const Shape& on_device_shape() const { return on_device_shape_; } PjRtDevice* device() const { return device_; } + PjRtPlatformId platform_id() const { return client_->platform_id(); } const std::string& platform_name() const { return client_->platform_name(); } PjRtClient* client() const { return client_; } bool IsEmptyTuple() const { @@ -611,6 +668,9 @@ class PjRtBuffer { // immediate use on the device. Useful in particular for timing benchmarks. Status BlockHostUntilReady(); + // Whether this buffer is on CPU and thus allows for certain optimizations. + bool IsOnCpu() const; + private: friend class PjRtClient; // The cached value of the buffer on the host, produced either from a call to @@ -782,6 +842,9 @@ class PjRtExecutable { const string& name() const; + // Return an HloModule per partition. 
+ StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules(); + protected: bool parameter_is_tupled_arguments() const { return parameter_is_tupled_arguments_; diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.cc b/tensorflow/compiler/xla/pjrt/tpu_client.cc index b2af6e79980..5a28d82335e 100644 --- a/tensorflow/compiler/xla/pjrt/tpu_client.cc +++ b/tensorflow/compiler/xla/pjrt/tpu_client.cc @@ -118,7 +118,7 @@ PjRtTpuClient::PjRtTpuClient(LocalClient* client, std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id, tf_tpu::TpuPlatformInterface* tpu_platform) - : PjRtClient("tpu", client, std::move(devices), host_id, + : PjRtClient(PjRtPlatformId::kTpu, client, std::move(devices), host_id, /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, /*should_stage_host_to_device_transfers=*/false, @@ -145,7 +145,7 @@ StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint( return InvalidArgument( "Passed executable from different client (platform '%s') to " "PjRtTpuClient::ExecutableFingerprint", - executable.client()->platform_name()); + Name(executable.client()->platform_id())); } if (executable.executables().size() > 1) { LOG(INFO) << "ExecutableFingerprint not fully implemented for MPMD " diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.h b/tensorflow/compiler/xla/pjrt/tpu_client.h index 1a458c1480b..cdc68bc9606 100644 --- a/tensorflow/compiler/xla/pjrt/tpu_client.h +++ b/tensorflow/compiler/xla/pjrt/tpu_client.h @@ -33,7 +33,7 @@ class PjRtTpuDevice : public PjRtDevice { int host_id, const std::array<int, 3>& coords, std::string device_kind) : PjRtDevice(core.Id(), std::move(local_device_state), - /*platform_name=*/"tpu", std::move(device_kind), host_id), + std::move(device_kind), host_id), core_(core), coords_(coords) {} diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc index d0013c50dc6..8f4045a0e7c 100644 --- a/tensorflow/compiler/xla/python/dlpack.cc +++ b/tensorflow/compiler/xla/python/dlpack.cc @@ -228,35 +228,31 @@ StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) { 
StatusOr<DLContext> DLContextForDevice(const PjRtDevice& device) { DLContext context; TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device)); - context.device_id = device.local_device_state()->device_ordinal(); + context.device_id = device.local_device_id(); return context; } StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient& client, const DLContext& context) { - se::Platform::Id platform_id; switch (context.device_type) { case kDLCPU: - platform_id = se::host::kHostPlatformId; - break; + if (client.platform_id() != PjRtPlatformId::kCpu) { + return InvalidArgument( + "DLPack CPU device type mismatch with PjRtClient platform %s", + client.platform_name()); + } + return client.LookupLocalDevice(context.device_id); case kDLGPU: - platform_id = se::cuda::kCudaPlatformId; - break; + if (client.platform_id() != PjRtPlatformId::kNvidiaGpu) { + return InvalidArgument( + "DLPack GPU device type mismatch with PjRtClient platform %s", + client.platform_name()); + } + return client.LookupLocalDevice(context.device_id); default: return InvalidArgument("Unknown/unsupported DLPack device type %d", context.device_type); } - auto it = absl::c_find_if(client.local_devices(), [&](PjRtDevice* device) { - return device->local_device_state()->executor()->platform()->id() == - platform_id && - device->local_device_state()->device_ordinal() == context.device_id; - }); - if (it == client.local_devices().end()) { - return InvalidArgument( - "No matching device found for DLPack device_type %d device_id %d", - context.device_type, context.device_id); - } - return *it; } } // namespace @@ -301,8 +297,7 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer, pack->tensor.manager_ctx = pack.get(); pack->tensor.deleter = DLPackTensorDeleter; TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device())); - dt.ctx.device_id = - buffer->buffer()->device()->local_device_state()->device_ordinal(); + dt.ctx.device_id = buffer->buffer()->device()->local_device_id(); dt.ndim = 
buffer->buffer()->on_host_shape().dimensions_size(); TF_ASSIGN_OR_RETURN(dt.dtype, PrimitiveTypeToDLDataType( diff --git a/tensorflow/compiler/xla/python/py_buffer.cc b/tensorflow/compiler/xla/python/py_buffer.cc index b32fe047530..cac14142b75 100644 --- a/tensorflow/compiler/xla/python/py_buffer.cc +++ b/tensorflow/compiler/xla/python/py_buffer.cc @@ -144,7 +144,7 @@ int PjRtBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) { // Additionally we call BlockHostUntilReady() below, which may block. py::gil_scoped_release gil_release; - if (buffer.device()->platform_name() != "cpu") { + if (!buffer.IsOnCpu()) { return InvalidArgument( "Python buffer protocol is only defined for CPU buffers."); } diff --git a/tensorflow/compiler/xla/python/py_client.h b/tensorflow/compiler/xla/python/py_client.h index 224f8278bb1..37f5333ea1c 100644 --- a/tensorflow/compiler/xla/python/py_client.h +++ b/tensorflow/compiler/xla/python/py_client.h @@ -112,13 +112,13 @@ class PyClient : public std::enable_shared_from_this<PyClient> { int num_replicas); StatusOr<ChannelHandle> CreateChannelHandle() { - return pjrt_client_->client()->CreateChannelHandle(); + return pjrt_client_->CreateChannelHandle(); } StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() { - return pjrt_client_->client()->CreateDeviceToHostChannelHandle(); + return pjrt_client_->CreateDeviceToHostChannelHandle(); } StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() { - return pjrt_client_->client()->CreateHostToDeviceChannelHandle(); + return pjrt_client_->CreateHostToDeviceChannelHandle(); } StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval( diff --git a/tensorflow/compiler/xla/python/py_executable.cc b/tensorflow/compiler/xla/python/py_executable.cc index 53891b96846..9d1b89a1cbc 100644 --- a/tensorflow/compiler/xla/python/py_executable.cc +++ b/tensorflow/compiler/xla/python/py_executable.cc @@ -135,15 +135,7 @@ PyExecutable::ExecuteOnLocalDevices( StatusOr<std::vector<std::shared_ptr<HloModule>>> PyExecutable::HloModules() const { - std::vector<std::shared_ptr<HloModule>> modules; - modules.reserve(executable_->executables().size()); - 
for (const auto& local_exec : executable_->executables()) { - if (!local_exec->executable()->has_module()) { - return InvalidArgument("Executable does not have HLO modules."); - } - modules.push_back(local_exec->executable()->shared_module()); - } - return std::move(modules); + return executable_->GetHloModules(); } } // namespace xla diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index 0602d096aaa..6cd55d0e631 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -37,7 +37,7 @@ namespace xla { TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords, int core_on_chip) - : xla::PjRtDevice(id, /*local_device_state=*/nullptr, kTpuPlatform, + : xla::PjRtDevice(id, /*local_device_state=*/nullptr, /*device_kind=*/"Cloud TPU", host_id), coords_(coords), core_on_chip_(core_on_chip) {} diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 2101191be86..fb8c0ba0ba4 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -641,9 +641,7 @@ PYBIND11_MODULE(xla_extension, m) { [](py::object buffer_obj) -> StatusOr<py::object> { GlobalPyRefManager()->CollectGarbage(); PyBuffer* buffer = buffer_obj.cast<PyBuffer*>(); - LocalDeviceState* state = - buffer->buffer()->device()->local_device_state(); - if (state->executor()->platform_kind() == se::PlatformKind::kHost && + if (buffer->buffer()->IsOnCpu() && buffer->buffer()->on_device_shape().IsArray() && buffer->buffer()->on_device_shape().element_type() != BF16) { py::object out = py::reinterpret_steal<py::object>(