Refactor PJRT.
- Replace the PjRtPlatformId enum with Fingerprint(platform_name):
  - the set of platforms becomes an open set (no central enum to extend);
  - comparisons stay fast, avoiding string comparison.

PiperOrigin-RevId: 338776283
Change-Id: I122fcc89205cc9ac235973fdacfdfb2f94322e4f
This commit is contained in:
parent
c711e2b251
commit
df7d1daff4
@ -58,7 +58,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
|
||||
}
|
||||
|
||||
return std::make_unique<PjRtClient>(
|
||||
PjRtPlatformId::kCpu, client, std::move(devices), /*host_id=*/0,
|
||||
kCpuName, client, std::move(devices), /*host_id=*/0,
|
||||
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
|
||||
/*should_stage_host_to_device_transfers=*/false,
|
||||
/*gpu_run_options=*/nullptr);
|
||||
|
@ -52,7 +52,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
|
||||
devices.push_back(std::move(device));
|
||||
|
||||
return std::make_unique<PjRtClient>(
|
||||
PjRtPlatformId::kInterpreter, client, std::move(devices), /*host_id=*/0,
|
||||
"interpreter", client, std::move(devices), /*host_id=*/0,
|
||||
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
|
||||
/*should_stage_host_to_device_transfers=*/false,
|
||||
/*gpu_run_options=*/nullptr);
|
||||
|
@ -323,7 +323,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetNvidiaGpuClient(
|
||||
}
|
||||
|
||||
return std::unique_ptr<PjRtClient>(std::make_unique<GpuClient>(
|
||||
PjRtPlatformId::kNvidiaGpu, xla_client, std::move(devices),
|
||||
kGpuName, xla_client, std::move(devices),
|
||||
/*node_id=*/node_id, std::move(allocator),
|
||||
std::move(host_memory_allocator),
|
||||
/*should_stage_host_to_device_transfers=*/true,
|
||||
|
@ -98,6 +98,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -182,14 +183,14 @@ class CpuAllocator : public tensorflow::Allocator {
|
||||
};
|
||||
|
||||
PjRtClient::PjRtClient(
|
||||
PjRtPlatformId platform_id, LocalClient* client,
|
||||
std::string platform_name, LocalClient* client,
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
|
||||
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
|
||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
|
||||
bool should_stage_host_to_device_transfers,
|
||||
std::unique_ptr<GpuExecutableRunOptions> gpu_run_options)
|
||||
: platform_id_(platform_id),
|
||||
platform_name_(Name(platform_id)),
|
||||
: platform_id_(tensorflow::Fingerprint64(platform_name)),
|
||||
platform_name_(std::move(platform_name)),
|
||||
client_(client),
|
||||
host_memory_allocator_(std::move(host_memory_allocator)),
|
||||
devices_(std::move(devices)),
|
||||
@ -528,9 +529,7 @@ void PjRtBuffer::ScopedHold::AddToInput(
|
||||
}
|
||||
}
|
||||
|
||||
bool PjRtBuffer::IsOnCpu() const {
|
||||
return client()->platform_id() == PjRtPlatformId::kCpu;
|
||||
}
|
||||
bool PjRtBuffer::IsOnCpu() const { return client()->platform_id() == kCpuId; }
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
|
||||
const void* data, const Shape& shape,
|
||||
|
@ -45,39 +45,23 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
// API notes:
|
||||
// PjRt stands for "Pretty much Just another RunTime".
|
||||
|
||||
namespace xla {
|
||||
|
||||
// TODO(zhangqiaorjc): Add a registration mechanism to add new platforms.
|
||||
// Closed set of platforms known to PjRt. The integer id is used for fast
// comparisons instead of comparing platform-name strings.
enum class PjRtPlatformId : int {
  kCpu = 0,
  kNvidiaGpu = 1,
  kAmdGpu = 2,
  kTpu = 3,
  kEdgeTpu = 4,
  kInterpreter = 5
};

// Returns the canonical platform name for `platform_id`.
constexpr const char* Name(PjRtPlatformId platform_id) {
  switch (platform_id) {
    case PjRtPlatformId::kCpu:
      return "cpu";
    case PjRtPlatformId::kNvidiaGpu:
      // TODO(zhangqiaorjc): Rename to nvidia_gpu when we add AMD support.
      return "gpu";
    case PjRtPlatformId::kAmdGpu:
      return "amd_gpu";
    case PjRtPlatformId::kTpu:
      return "tpu";
    case PjRtPlatformId::kEdgeTpu:
      return "edge_tpu";
    case PjRtPlatformId::kInterpreter:
      return "interpreter";
  }
  // An enum class may hold any value of its underlying type; without this
  // fallback an out-of-range id would flow off the end of a value-returning
  // function, which is undefined behavior.
  return "unknown";
}
|
||||
using PjRtPlatformId = uint64;
|
||||
|
||||
constexpr char kCpuName[] = "cpu";
|
||||
constexpr char kGpuName[] = "gpu";
|
||||
constexpr char kTpuName[] = "tpu";
|
||||
static const PjRtPlatformId kCpuId = tensorflow::Fingerprint64(kCpuName);
|
||||
static const PjRtPlatformId kGpuId = tensorflow::Fingerprint64(kGpuName);
|
||||
static const PjRtPlatformId kTpuId = tensorflow::Fingerprint64(kTpuName);
|
||||
|
||||
class PjRtClient;
|
||||
|
||||
@ -196,7 +180,7 @@ class PjRtClient {
|
||||
public:
|
||||
// `allocator` may null, in which case the platform default allocator is used.
|
||||
explicit PjRtClient(
|
||||
PjRtPlatformId platform_id, LocalClient* client,
|
||||
std::string platform_name, LocalClient* client,
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
|
||||
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
|
||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
|
||||
|
@ -118,7 +118,7 @@ PjRtTpuClient::PjRtTpuClient(LocalClient* client,
|
||||
std::vector<std::unique_ptr<PjRtDevice>> devices,
|
||||
int host_id,
|
||||
tf_tpu::TpuPlatformInterface* tpu_platform)
|
||||
: PjRtClient(PjRtPlatformId::kTpu, client, std::move(devices), host_id,
|
||||
: PjRtClient(kTpuName, client, std::move(devices), host_id,
|
||||
/*allocator=*/nullptr,
|
||||
/*host_memory_allocator=*/nullptr,
|
||||
/*should_stage_host_to_device_transfers=*/false,
|
||||
@ -145,7 +145,7 @@ StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
|
||||
return InvalidArgument(
|
||||
"Passed executable from different client (platform '%s') to "
|
||||
"PjRtTpuClient::ExecutableFingerprint",
|
||||
Name(executable.client()->platform_id()));
|
||||
executable.client()->platform_name());
|
||||
}
|
||||
if (executable.executables().size() > 1) {
|
||||
LOG(INFO) << "ExecutableFingerprint not fully implemented for MPMD "
|
||||
|
@ -236,14 +236,14 @@ StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient& client,
|
||||
const DLContext& context) {
|
||||
switch (context.device_type) {
|
||||
case kDLCPU:
|
||||
if (client.platform_id() != PjRtPlatformId::kCpu) {
|
||||
if (client.platform_id() != kCpuId) {
|
||||
return InvalidArgument(
|
||||
"DLPack CPU device type mismatch with PjRtClient platform %s",
|
||||
client.platform_name());
|
||||
}
|
||||
return client.LookupLocalDevice(context.device_id);
|
||||
case kDLGPU:
|
||||
if (client.platform_id() != PjRtPlatformId::kNvidiaGpu) {
|
||||
if (client.platform_id() != kGpuId) {
|
||||
return InvalidArgument(
|
||||
"DLPack GPU device type mismatch with PjRtClient platform %s",
|
||||
client.platform_name());
|
||||
|
Loading…
x
Reference in New Issue
Block a user