Refactor PJRT.

- Replace the PjRtPlatformId enum with Fingerprint64(platform_name)
  - The set of platform ids is now open: a new platform needs no enum change
  - The id stays an integer, so comparisons are fast and avoid string comparison
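
The pattern in miniature (a standalone sketch, not the PJRT code itself: std::hash stands in for tensorflow::Fingerprint64, which is a stable 64-bit fingerprint rather than a per-process hash, and Client stands in for PjRtClient):

#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <utility>

using PlatformId = std::uint64_t;

// Stand-in for tensorflow::Fingerprint64: any stable 64-bit fingerprint
// of the name works; std::hash is used here only to keep the sketch
// dependency-free.
PlatformId FingerprintName(const std::string& name) {
  return std::hash<std::string>{}(name);
}

const PlatformId kCpuId = FingerprintName("cpu");

class Client {
 public:
  // The caller passes a name; the client derives the id itself, so the
  // set of platforms is open-ended (no enum to extend).
  explicit Client(std::string platform_name)
      : platform_id_(FingerprintName(platform_name)),
        platform_name_(std::move(platform_name)) {}

  PlatformId platform_id() const { return platform_id_; }
  const std::string& platform_name() const { return platform_name_; }

 private:
  PlatformId platform_id_;  // fingerprint of platform_name_
  std::string platform_name_;
};

int main() {
  Client client("cpu");
  // Hot-path checks compare integers, not strings.
  std::cout << (client.platform_id() == kCpuId) << "\n";  // prints 1
}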

PiperOrigin-RevId: 338776283
Change-Id: I122fcc89205cc9ac235973fdacfdfb2f94322e4f
Qiao Zhang authored 2020-10-23 17:47:57 -07:00, committed by TensorFlower Gardener
commit df7d1daff4 (parent c711e2b251)
7 changed files with 23 additions and 40 deletions


@@ -58,7 +58,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
   }
 
   return std::make_unique<PjRtClient>(
-      PjRtPlatformId::kCpu, client, std::move(devices), /*host_id=*/0,
+      kCpuName, client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr);


@@ -52,7 +52,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
   devices.push_back(std::move(device));
 
   return std::make_unique<PjRtClient>(
-      PjRtPlatformId::kInterpreter, client, std::move(devices), /*host_id=*/0,
+      "interpreter", client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr);


@@ -323,7 +323,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetNvidiaGpuClient(
   }
 
   return std::unique_ptr<PjRtClient>(std::make_unique<GpuClient>(
-      PjRtPlatformId::kNvidiaGpu, xla_client, std::move(devices),
+      kGpuName, xla_client, std::move(devices),
       /*node_id=*/node_id, std::move(allocator),
       std::move(host_memory_allocator),
       /*should_stage_host_to_device_transfers=*/true,


@@ -98,6 +98,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/types.h"
@@ -182,14 +183,14 @@ class CpuAllocator : public tensorflow::Allocator {
 };
 
 PjRtClient::PjRtClient(
-    PjRtPlatformId platform_id, LocalClient* client,
+    std::string platform_name, LocalClient* client,
     std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
     std::unique_ptr<se::DeviceMemoryAllocator> allocator,
     std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
     bool should_stage_host_to_device_transfers,
     std::unique_ptr<GpuExecutableRunOptions> gpu_run_options)
-    : platform_id_(platform_id),
-      platform_name_(Name(platform_id)),
+    : platform_id_(tensorflow::Fingerprint64(platform_name)),
+      platform_name_(std::move(platform_name)),
       client_(client),
       host_memory_allocator_(std::move(host_memory_allocator)),
       devices_(std::move(devices)),
@@ -528,9 +529,7 @@ void PjRtBuffer::ScopedHold::AddToInput(
   }
 }
 
-bool PjRtBuffer::IsOnCpu() const {
-  return client()->platform_id() == PjRtPlatformId::kCpu;
-}
+bool PjRtBuffer::IsOnCpu() const { return client()->platform_id() == kCpuId; }
 
 StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
     const void* data, const Shape& shape,

@@ -45,39 +45,23 @@ limitations under the License.
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
 
 // API notes:
 // PjRt stands for "Pretty much Just another RunTime".
 
 namespace xla {
 
 // TODO(zhangqiaorjc): Add a registration mechanism to add new platforms.
-enum class PjRtPlatformId : int {
-  kCpu = 0,
-  kNvidiaGpu = 1,
-  kAmdGpu = 2,
-  kTpu = 3,
-  kEdgeTpu = 4,
-  kInterpreter = 5
-};
-constexpr const char* Name(PjRtPlatformId platform_id) {
-  switch (platform_id) {
-    case PjRtPlatformId::kCpu:
-      return "cpu";
-    case PjRtPlatformId::kNvidiaGpu:
-      // TODO(zhangqiaorjc): Rename to nvidia_gpu when we add AMD support.
-      return "gpu";
-    case PjRtPlatformId::kAmdGpu:
-      return "amd_gpu";
-    case PjRtPlatformId::kTpu:
-      return "tpu";
-    case PjRtPlatformId::kEdgeTpu:
-      return "edge_tpu";
-    case PjRtPlatformId::kInterpreter:
-      return "interpreter";
-  }
-}
+using PjRtPlatformId = uint64;
+
+constexpr char kCpuName[] = "cpu";
+constexpr char kGpuName[] = "gpu";
+constexpr char kTpuName[] = "tpu";
+static const PjRtPlatformId kCpuId = tensorflow::Fingerprint64(kCpuName);
+static const PjRtPlatformId kGpuId = tensorflow::Fingerprint64(kGpuName);
+static const PjRtPlatformId kTpuId = tensorflow::Fingerprint64(kTpuName);
 
 class PjRtClient;
@@ -196,7 +180,7 @@ class PjRtClient {
  public:
   // `allocator` may null, in which case the platform default allocator is used.
   explicit PjRtClient(
-      PjRtPlatformId platform_id, LocalClient* client,
+      std::string platform_name, LocalClient* client,
       std::vector<std::unique_ptr<PjRtDevice>> devices, int host_id,
       std::unique_ptr<se::DeviceMemoryAllocator> allocator,
       std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
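
With the enum gone, the platform set is open: a new backend only picks a name and passes it to this constructor. A hypothetical sketch (the "my_accelerator" name and its constants are invented for illustration, not part of this change), following the kCpuId/kGpuId/kTpuId pattern above:

// Hypothetical out-of-tree backend; nothing in pjrt_client.h changes.
constexpr char kMyAcceleratorName[] = "my_accelerator";
static const PjRtPlatformId kMyAcceleratorId =
    tensorflow::Fingerprint64(kMyAcceleratorName);

// Construction mirrors GetCpuClient above: pass the name, and the client
// fingerprints it into platform_id() internally.
// auto client = std::make_unique<PjRtClient>(
//     kMyAcceleratorName, local_client, std::move(devices), /*host_id=*/0,
//     /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
//     /*should_stage_host_to_device_transfers=*/false,
//     /*gpu_run_options=*/nullptr);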


@@ -118,7 +118,7 @@ PjRtTpuClient::PjRtTpuClient(LocalClient* client,
                              std::vector<std::unique_ptr<PjRtDevice>> devices,
                              int host_id,
                              tf_tpu::TpuPlatformInterface* tpu_platform)
-    : PjRtClient(PjRtPlatformId::kTpu, client, std::move(devices), host_id,
+    : PjRtClient(kTpuName, client, std::move(devices), host_id,
                  /*allocator=*/nullptr,
                  /*host_memory_allocator=*/nullptr,
                  /*should_stage_host_to_device_transfers=*/false,
@@ -145,7 +145,7 @@ StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
     return InvalidArgument(
         "Passed executable from different client (platform '%s') to "
         "PjRtTpuClient::ExecutableFingerprint",
-        Name(executable.client()->platform_id()));
+        executable.client()->platform_name());
   }
   if (executable.executables().size() > 1) {
     LOG(INFO) << "ExecutableFingerprint not fully implemented for MPMD "


@@ -236,14 +236,14 @@ StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient& client,
                                          const DLContext& context) {
   switch (context.device_type) {
     case kDLCPU:
-      if (client.platform_id() != PjRtPlatformId::kCpu) {
+      if (client.platform_id() != kCpuId) {
         return InvalidArgument(
             "DLPack CPU device type mismatch with PjRtClient platform %s",
             client.platform_name());
       }
       return client.LookupLocalDevice(context.device_id);
     case kDLGPU:
-      if (client.platform_id() != PjRtPlatformId::kNvidiaGpu) {
+      if (client.platform_id() != kGpuId) {
         return InvalidArgument(
             "DLPack GPU device type mismatch with PjRtClient platform %s",
             client.platform_name());