[XLA:Python] Add Device.platform field.

This can be used to associate a Device with a particular Backend. We
use the platform string instead of having a reference to the Backend
itself since Backend in defined in Python and Device is defined in
C++.

The motivation for this is to allow JAX to derive the correct Backend
to use from a Device.

PiperOrigin-RevId: 281858391
Change-Id: If329f5723e6eefcef1bafddc53b8a58150168cf7
This commit is contained in:
Skye Wanderman-Milne 2019-11-21 16:26:23 -08:00 committed by TensorFlower Gardener
parent 2b4b547e8d
commit 72e362bba7
5 changed files with 39 additions and 7 deletions

View File

@ -157,10 +157,10 @@ static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
int id, int local_device_ordinal) {
if (platform_name == "cpu") {
return std::make_shared<CpuDevice>(id, local_device_ordinal);
return std::make_shared<CpuDevice>(id, local_device_ordinal, platform_name);
} else {
CHECK_EQ(platform_name, "gpu");
return std::make_shared<GpuDevice>(id, local_device_ordinal);
return std::make_shared<GpuDevice>(id, local_device_ordinal, platform_name);
}
}

View File

@ -43,10 +43,12 @@ class PyLocalExecutable;
class Device {
public:
explicit Device(int id, int local_device_ordinal, int host_id = 0)
explicit Device(int id, int local_device_ordinal,
absl::string_view platform_name, int host_id = 0)
: id_(id),
local_device_ordinal_(local_device_ordinal),
host_id_(host_id) {}
host_id_(host_id),
platform_name_(platform_name) {}
virtual ~Device() {}
// The ID of this device. IDs are unique among devices of this type
@ -65,12 +67,15 @@ class Device {
// The ID of this device's host. This is always 0 on single-host platforms.
int host_id() const { return host_id_; }
const std::string& platform_name() const { return platform_name_; }
virtual std::string DebugString() const = 0;
private:
const int id_;
const int local_device_ordinal_;
const int host_id_;
const std::string platform_name_;
};
class CpuDevice : public Device {

View File

@ -42,7 +42,7 @@ static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
int id, int local_device_ordinal) {
CHECK_EQ(platform_name, "tpu");
CHECK_EQ(id, local_device_ordinal); // Every device must be local for now.
return std::make_shared<TpuDevice>(id, local_device_ordinal);
return std::make_shared<TpuDevice>(id, local_device_ordinal, "tpu");
}
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(

View File

@ -61,6 +61,12 @@ PYBIND11_MODULE(tpu_client_extension, m) {
std::shared_ptr<Device> device)
-> StatusOr<std::unique_ptr<PyTpuBuffer>> {
CHECK(device != nullptr);
auto iter = client->id_to_device().find(device->id());
if (iter->second != device) {
return InvalidArgument(
"Cannot copy value to device '%s' with '%s' backend",
device->DebugString(), client->platform_name());
}
GlobalPyRefManager()->CollectGarbage();
TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
GetPythonBufferTree(argument));
@ -105,8 +111,15 @@ PYBIND11_MODULE(tpu_client_extension, m) {
.def_static("make_tuple",
[](const std::vector<PyTpuBuffer*> buffers,
std::shared_ptr<PyTpuClient> client,
std::shared_ptr<Device> device) {
std::shared_ptr<Device> device)
-> StatusOr<std::unique_ptr<PyTpuBuffer>> {
CHECK(device != nullptr);
auto iter = client->id_to_device().find(device->id());
if (iter->second != device) {
return InvalidArgument(
"Cannot make tuple on device '%s' with '%s' backend",
device->DebugString(), client->platform_name());
}
return PyTpuBuffer::MakeTuple(
buffers, client, device->local_device_ordinal());
})

View File

@ -318,6 +318,7 @@ PYBIND11_MODULE(xla_extension, m) {
.def_property_readonly("host_id", &Device::host_id,
"Integer ID of this device's host.\n\n"
"This is always 0 except on multi-host platforms.")
.def_property_readonly("platform", &Device::platform_name)
.def("__str__", &Device::DebugString);
py::class_<CpuDevice, Device, std::shared_ptr<CpuDevice>>(m, "CpuDevice")
@ -391,6 +392,12 @@ PYBIND11_MODULE(xla_extension, m) {
std::shared_ptr<Device> device)
-> StatusOr<std::unique_ptr<PyLocalBuffer>> {
CHECK(device != nullptr);
auto iter = client->id_to_device().find(device->id());
if (iter->second != device) {
return InvalidArgument(
"Cannot copy value to device '%s' with '%s' backend",
device->DebugString(), client->platform_name());
}
GlobalPyRefManager()->CollectGarbage();
TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
GetPythonBufferTree(argument));
@ -436,8 +443,15 @@ PYBIND11_MODULE(xla_extension, m) {
.def_static("make_tuple",
[](const std::vector<PyLocalBuffer*> buffers,
std::shared_ptr<PyLocalClient> client,
std::shared_ptr<Device> device) {
std::shared_ptr<Device> device)
-> StatusOr<std::unique_ptr<PyLocalBuffer>> {
CHECK(device != nullptr);
auto iter = client->id_to_device().find(device->id());
if (iter->second != device) {
return InvalidArgument(
"Cannot make tuple on device '%s' with '%s' backend",
device->DebugString(), client->platform_name());
}
return PyLocalBuffer::MakeTuple(
buffers, client, device->local_device_ordinal());
})