[XLA:Python] Add Device.platform field.
This can be used to associate a Device with a particular Backend. We use the platform string instead of having a reference to the Backend itself since Backend in defined in Python and Device is defined in C++. The motivation for this is to allow JAX to derive the correct Backend to use from a Device. PiperOrigin-RevId: 281858391 Change-Id: If329f5723e6eefcef1bafddc53b8a58150168cf7
This commit is contained in:
parent
2b4b547e8d
commit
72e362bba7
@ -157,10 +157,10 @@ static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
|
||||
static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
|
||||
int id, int local_device_ordinal) {
|
||||
if (platform_name == "cpu") {
|
||||
return std::make_shared<CpuDevice>(id, local_device_ordinal);
|
||||
return std::make_shared<CpuDevice>(id, local_device_ordinal, platform_name);
|
||||
} else {
|
||||
CHECK_EQ(platform_name, "gpu");
|
||||
return std::make_shared<GpuDevice>(id, local_device_ordinal);
|
||||
return std::make_shared<GpuDevice>(id, local_device_ordinal, platform_name);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -43,10 +43,12 @@ class PyLocalExecutable;
|
||||
|
||||
class Device {
|
||||
public:
|
||||
explicit Device(int id, int local_device_ordinal, int host_id = 0)
|
||||
explicit Device(int id, int local_device_ordinal,
|
||||
absl::string_view platform_name, int host_id = 0)
|
||||
: id_(id),
|
||||
local_device_ordinal_(local_device_ordinal),
|
||||
host_id_(host_id) {}
|
||||
host_id_(host_id),
|
||||
platform_name_(platform_name) {}
|
||||
virtual ~Device() {}
|
||||
|
||||
// The ID of this device. IDs are unique among devices of this type
|
||||
@ -65,12 +67,15 @@ class Device {
|
||||
// The ID of this device's host. This is always 0 on single-host platforms.
|
||||
int host_id() const { return host_id_; }
|
||||
|
||||
const std::string& platform_name() const { return platform_name_; }
|
||||
|
||||
virtual std::string DebugString() const = 0;
|
||||
|
||||
private:
|
||||
const int id_;
|
||||
const int local_device_ordinal_;
|
||||
const int host_id_;
|
||||
const std::string platform_name_;
|
||||
};
|
||||
|
||||
class CpuDevice : public Device {
|
||||
|
@ -42,7 +42,7 @@ static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
|
||||
int id, int local_device_ordinal) {
|
||||
CHECK_EQ(platform_name, "tpu");
|
||||
CHECK_EQ(id, local_device_ordinal); // Every device must be local for now.
|
||||
return std::make_shared<TpuDevice>(id, local_device_ordinal);
|
||||
return std::make_shared<TpuDevice>(id, local_device_ordinal, "tpu");
|
||||
}
|
||||
|
||||
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
|
||||
|
@ -61,6 +61,12 @@ PYBIND11_MODULE(tpu_client_extension, m) {
|
||||
std::shared_ptr<Device> device)
|
||||
-> StatusOr<std::unique_ptr<PyTpuBuffer>> {
|
||||
CHECK(device != nullptr);
|
||||
auto iter = client->id_to_device().find(device->id());
|
||||
if (iter->second != device) {
|
||||
return InvalidArgument(
|
||||
"Cannot copy value to device '%s' with '%s' backend",
|
||||
device->DebugString(), client->platform_name());
|
||||
}
|
||||
GlobalPyRefManager()->CollectGarbage();
|
||||
TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
|
||||
GetPythonBufferTree(argument));
|
||||
@ -105,8 +111,15 @@ PYBIND11_MODULE(tpu_client_extension, m) {
|
||||
.def_static("make_tuple",
|
||||
[](const std::vector<PyTpuBuffer*> buffers,
|
||||
std::shared_ptr<PyTpuClient> client,
|
||||
std::shared_ptr<Device> device) {
|
||||
std::shared_ptr<Device> device)
|
||||
-> StatusOr<std::unique_ptr<PyTpuBuffer>> {
|
||||
CHECK(device != nullptr);
|
||||
auto iter = client->id_to_device().find(device->id());
|
||||
if (iter->second != device) {
|
||||
return InvalidArgument(
|
||||
"Cannot make tuple on device '%s' with '%s' backend",
|
||||
device->DebugString(), client->platform_name());
|
||||
}
|
||||
return PyTpuBuffer::MakeTuple(
|
||||
buffers, client, device->local_device_ordinal());
|
||||
})
|
||||
|
@ -318,6 +318,7 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.def_property_readonly("host_id", &Device::host_id,
|
||||
"Integer ID of this device's host.\n\n"
|
||||
"This is always 0 except on multi-host platforms.")
|
||||
.def_property_readonly("platform", &Device::platform_name)
|
||||
.def("__str__", &Device::DebugString);
|
||||
|
||||
py::class_<CpuDevice, Device, std::shared_ptr<CpuDevice>>(m, "CpuDevice")
|
||||
@ -391,6 +392,12 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
std::shared_ptr<Device> device)
|
||||
-> StatusOr<std::unique_ptr<PyLocalBuffer>> {
|
||||
CHECK(device != nullptr);
|
||||
auto iter = client->id_to_device().find(device->id());
|
||||
if (iter->second != device) {
|
||||
return InvalidArgument(
|
||||
"Cannot copy value to device '%s' with '%s' backend",
|
||||
device->DebugString(), client->platform_name());
|
||||
}
|
||||
GlobalPyRefManager()->CollectGarbage();
|
||||
TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
|
||||
GetPythonBufferTree(argument));
|
||||
@ -436,8 +443,15 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.def_static("make_tuple",
|
||||
[](const std::vector<PyLocalBuffer*> buffers,
|
||||
std::shared_ptr<PyLocalClient> client,
|
||||
std::shared_ptr<Device> device) {
|
||||
std::shared_ptr<Device> device)
|
||||
-> StatusOr<std::unique_ptr<PyLocalBuffer>> {
|
||||
CHECK(device != nullptr);
|
||||
auto iter = client->id_to_device().find(device->id());
|
||||
if (iter->second != device) {
|
||||
return InvalidArgument(
|
||||
"Cannot make tuple on device '%s' with '%s' backend",
|
||||
device->DebugString(), client->platform_name());
|
||||
}
|
||||
return PyLocalBuffer::MakeTuple(
|
||||
buffers, client, device->local_device_ordinal());
|
||||
})
|
||||
|
Loading…
Reference in New Issue
Block a user