diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc
index 37c4ab3b7c5..9b0f060f392 100644
--- a/tensorflow/compiler/xla/pjrt/cpu_device.cc
+++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc
@@ -58,7 +58,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
   }
 
   return std::make_unique<PjRtClient>(
-      PjRtPlatformId::kCpu, client, std::move(devices), /*host_id=*/0,
+      kCpuName, client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr);
diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.cc b/tensorflow/compiler/xla/pjrt/interpreter_device.cc
index 53a4bed8bb5..2819cabf258 100644
--- a/tensorflow/compiler/xla/pjrt/interpreter_device.cc
+++ b/tensorflow/compiler/xla/pjrt/interpreter_device.cc
@@ -52,7 +52,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
   devices.push_back(std::move(device));
 
   return std::make_unique<PjRtClient>(
-      PjRtPlatformId::kInterpreter, client, std::move(devices), /*host_id=*/0,
+      "interpreter", client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr);
diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc
index 5003c8a7cde..fde6016e5f9 100644
--- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc
+++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc
@@ -323,7 +323,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetNvidiaGpuClient(
   }
 
   return std::unique_ptr<PjRtClient>(std::make_unique<GpuClient>(
-      PjRtPlatformId::kNvidiaGpu, xla_client, std::move(devices),
+      kGpuName, xla_client, std::move(devices),
       /*node_id=*/node_id, std::move(allocator),
       std::move(host_memory_allocator),
       /*should_stage_host_to_device_transfers=*/true,
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc
index 83ed61cfe63..b2e02a0450f 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc
+++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc
@@ -98,6 +98,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" @@ -182,14 +183,14 @@ class CpuAllocator : public tensorflow::Allocator { }; PjRtClient::PjRtClient( - PjRtPlatformId platform_id, LocalClient* client, + std::string platform_name, LocalClient* client, std::vector> devices, int host_id, std::unique_ptr allocator, std::unique_ptr host_memory_allocator, bool should_stage_host_to_device_transfers, std::unique_ptr gpu_run_options) - : platform_id_(platform_id), - platform_name_(Name(platform_id)), + : platform_id_(tensorflow::Fingerprint64(platform_name)), + platform_name_(std::move(platform_name)), client_(client), host_memory_allocator_(std::move(host_memory_allocator)), devices_(std::move(devices)), @@ -528,9 +529,7 @@ void PjRtBuffer::ScopedHold::AddToInput( } } -bool PjRtBuffer::IsOnCpu() const { - return client()->platform_id() == PjRtPlatformId::kCpu; -} +bool PjRtBuffer::IsOnCpu() const { return client()->platform_id() == kCpuId; } StatusOr> PjRtClient::BufferFromHostBuffer( const void* data, const Shape& shape, diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index 86805182525..38d2610ff93 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -45,39 +45,23 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" // API notes: // PjRt stands for "Pretty much Just another RunTime". namespace xla { -// TODO(zhangqiaorjc): Add a registration mechanism to add new platforms. -enum class PjRtPlatformId : int { - kCpu = 0, - kNvidiaGpu = 1, - kAmdGpu = 2, - kTpu = 3, - kEdgeTpu = 4, - kInterpreter = 5 -}; -constexpr const char* Name(PjRtPlatformId platform_id) { - switch (platform_id) { - case PjRtPlatformId::kCpu: - return "cpu"; - case PjRtPlatformId::kNvidiaGpu: - // TODO(zhangqiaorjc): Rename to nvidia_gpu when we add AMD support. - return "gpu"; - case PjRtPlatformId::kAmdGpu: - return "amd_gpu"; - case PjRtPlatformId::kTpu: - return "tpu"; - case PjRtPlatformId::kEdgeTpu: - return "edge_tpu"; - case PjRtPlatformId::kInterpreter: - return "interpreter"; - } -} +using PjRtPlatformId = uint64; + +constexpr char kCpuName[] = "cpu"; +constexpr char kGpuName[] = "gpu"; +constexpr char kTpuName[] = "tpu"; +static const PjRtPlatformId kCpuId = tensorflow::Fingerprint64(kCpuName); +static const PjRtPlatformId kGpuId = tensorflow::Fingerprint64(kGpuName); +static const PjRtPlatformId kTpuId = tensorflow::Fingerprint64(kTpuName); class PjRtClient; @@ -196,7 +180,7 @@ class PjRtClient { public: // `allocator` may null, in which case the platform default allocator is used. 
   explicit PjRtClient(
-      PjRtPlatformId platform_id, LocalClient* client,
+      std::string platform_name, LocalClient* client,
       std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
       std::unique_ptr<se::DeviceMemoryAllocator> allocator,
       std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.cc b/tensorflow/compiler/xla/pjrt/tpu_client.cc
index 5a28d82335e..b0c8b7cb62f 100644
--- a/tensorflow/compiler/xla/pjrt/tpu_client.cc
+++ b/tensorflow/compiler/xla/pjrt/tpu_client.cc
@@ -118,7 +118,7 @@ PjRtTpuClient::PjRtTpuClient(LocalClient* client,
                              std::vector<std::unique_ptr<PjRtDevice>> devices,
                              int host_id,
                              tf_tpu::TpuPlatformInterface* tpu_platform)
-    : PjRtClient(PjRtPlatformId::kTpu, client, std::move(devices), host_id,
+    : PjRtClient(kTpuName, client, std::move(devices), host_id,
                  /*allocator=*/nullptr,
                  /*host_memory_allocator=*/nullptr,
                  /*should_stage_host_to_device_transfers=*/false,
@@ -145,7 +145,7 @@ StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
     return InvalidArgument(
         "Passed executable from different client (platform '%s') to "
         "PjRtTpuClient::ExecutableFingerprint",
-        Name(executable.client()->platform_id()));
+        executable.client()->platform_name());
   }
   if (executable.executables().size() > 1) {
     LOG(INFO) << "ExecutableFingerprint not fully implemented for MPMD "
diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc
index 8f4045a0e7c..85252256657 100644
--- a/tensorflow/compiler/xla/python/dlpack.cc
+++ b/tensorflow/compiler/xla/python/dlpack.cc
@@ -236,14 +236,14 @@ StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient& client,
                                          const DLContext& context) {
   switch (context.device_type) {
     case kDLCPU:
-      if (client.platform_id() != PjRtPlatformId::kCpu) {
+      if (client.platform_id() != kCpuId) {
         return InvalidArgument(
             "DLPack CPU device type mismatch with PjRtClient platform %s",
             client.platform_name());
       }
       return client.LookupLocalDevice(context.device_id);
     case kDLGPU:
-      if (client.platform_id() != PjRtPlatformId::kNvidiaGpu) {
+      if (client.platform_id() != kGpuId) {
         return InvalidArgument(
             "DLPack GPU device type mismatch with PjRtClient platform %s",
             client.platform_name());
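
The core of the change: a PjRtPlatformId is no longer a value of a closed enum but the Fingerprint64 hash of the platform's name, so adding a platform means picking a unique name rather than editing a central header. Below is a minimal standalone sketch of that scheme, assuming only tensorflow/core/platform/fingerprint.h from this patch; the "my_accelerator" platform and its constants are hypothetical, for illustration only.

// Sketch of the fingerprint-based platform-ID scheme introduced above.
// "my_accelerator" is a hypothetical platform name, not part of the patch.
#include <iostream>

#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/types.h"

using PjRtPlatformId = tensorflow::uint64;

// "Registering" a platform now amounts to choosing a name; the ID is derived.
constexpr char kMyAcceleratorName[] = "my_accelerator";
static const PjRtPlatformId kMyAcceleratorId =
    tensorflow::Fingerprint64(kMyAcceleratorName);

int main() {
  // Fingerprint64 is a stable hash: the same name always yields the same
  // 64-bit ID, so IDs can be compared across binaries and processes.
  std::cout << kMyAcceleratorName << " -> " << kMyAcceleratorId << "\n";

  // Checks like the patched PjRtBuffer::IsOnCpu() reduce to one integer
  // comparison against a precomputed fingerprint (kCpuId in the patch).
  const PjRtPlatformId some_platform_id = tensorflow::Fingerprint64("cpu");
  std::cout << std::boolalpha
            << (some_platform_id == tensorflow::Fingerprint64("cpu")) << "\n";
  return 0;
}

This also explains why kCpuId, kGpuId, and kTpuId are static const rather than constexpr: Fingerprint64 is not a constexpr function, so the IDs are computed once during static initialization.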