[XLA:Python] Add Device.platform field.

This can be used to associate a Device with a particular Backend. We
use the platform string instead of having a reference to the Backend
itself since Backend in defined in Python and Device is defined in
C++.

The motivation for this is to allow JAX to derive the correct Backend
to use from a Device.

PiperOrigin-RevId: 281858391
Change-Id: If329f5723e6eefcef1bafddc53b8a58150168cf7
This commit is contained in:
Skye Wanderman-Milne 2019-11-21 16:26:23 -08:00 committed by TensorFlower Gardener
parent 2b4b547e8d
commit 72e362bba7
5 changed files with 39 additions and 7 deletions

View File

@ -157,10 +157,10 @@ static StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
static std::shared_ptr<Device> MakeDevice(const std::string& platform_name, static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
int id, int local_device_ordinal) { int id, int local_device_ordinal) {
if (platform_name == "cpu") { if (platform_name == "cpu") {
return std::make_shared<CpuDevice>(id, local_device_ordinal); return std::make_shared<CpuDevice>(id, local_device_ordinal, platform_name);
} else { } else {
CHECK_EQ(platform_name, "gpu"); CHECK_EQ(platform_name, "gpu");
return std::make_shared<GpuDevice>(id, local_device_ordinal); return std::make_shared<GpuDevice>(id, local_device_ordinal, platform_name);
} }
} }

View File

@ -43,10 +43,12 @@ class PyLocalExecutable;
class Device { class Device {
public: public:
explicit Device(int id, int local_device_ordinal, int host_id = 0) explicit Device(int id, int local_device_ordinal,
absl::string_view platform_name, int host_id = 0)
: id_(id), : id_(id),
local_device_ordinal_(local_device_ordinal), local_device_ordinal_(local_device_ordinal),
host_id_(host_id) {} host_id_(host_id),
platform_name_(platform_name) {}
virtual ~Device() {} virtual ~Device() {}
// The ID of this device. IDs are unique among devices of this type // The ID of this device. IDs are unique among devices of this type
@ -65,12 +67,15 @@ class Device {
// The ID of this device's host. This is always 0 on single-host platforms. // The ID of this device's host. This is always 0 on single-host platforms.
int host_id() const { return host_id_; } int host_id() const { return host_id_; }
const std::string& platform_name() const { return platform_name_; }
virtual std::string DebugString() const = 0; virtual std::string DebugString() const = 0;
private: private:
const int id_; const int id_;
const int local_device_ordinal_; const int local_device_ordinal_;
const int host_id_; const int host_id_;
const std::string platform_name_;
}; };
class CpuDevice : public Device { class CpuDevice : public Device {

View File

@ -42,7 +42,7 @@ static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
int id, int local_device_ordinal) { int id, int local_device_ordinal) {
CHECK_EQ(platform_name, "tpu"); CHECK_EQ(platform_name, "tpu");
CHECK_EQ(id, local_device_ordinal); // Every device must be local for now. CHECK_EQ(id, local_device_ordinal); // Every device must be local for now.
return std::make_shared<TpuDevice>(id, local_device_ordinal); return std::make_shared<TpuDevice>(id, local_device_ordinal, "tpu");
} }
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get( StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(

View File

@ -61,6 +61,12 @@ PYBIND11_MODULE(tpu_client_extension, m) {
std::shared_ptr<Device> device) std::shared_ptr<Device> device)
-> StatusOr<std::unique_ptr<PyTpuBuffer>> { -> StatusOr<std::unique_ptr<PyTpuBuffer>> {
CHECK(device != nullptr); CHECK(device != nullptr);
auto iter = client->id_to_device().find(device->id());
if (iter->second != device) {
return InvalidArgument(
"Cannot copy value to device '%s' with '%s' backend",
device->DebugString(), client->platform_name());
}
GlobalPyRefManager()->CollectGarbage(); GlobalPyRefManager()->CollectGarbage();
TF_ASSIGN_OR_RETURN(PythonBufferTree tree, TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
GetPythonBufferTree(argument)); GetPythonBufferTree(argument));
@ -105,8 +111,15 @@ PYBIND11_MODULE(tpu_client_extension, m) {
.def_static("make_tuple", .def_static("make_tuple",
[](const std::vector<PyTpuBuffer*> buffers, [](const std::vector<PyTpuBuffer*> buffers,
std::shared_ptr<PyTpuClient> client, std::shared_ptr<PyTpuClient> client,
std::shared_ptr<Device> device) { std::shared_ptr<Device> device)
-> StatusOr<std::unique_ptr<PyTpuBuffer>> {
CHECK(device != nullptr); CHECK(device != nullptr);
auto iter = client->id_to_device().find(device->id());
if (iter->second != device) {
return InvalidArgument(
"Cannot make tuple on device '%s' with '%s' backend",
device->DebugString(), client->platform_name());
}
return PyTpuBuffer::MakeTuple( return PyTpuBuffer::MakeTuple(
buffers, client, device->local_device_ordinal()); buffers, client, device->local_device_ordinal());
}) })

View File

@ -318,6 +318,7 @@ PYBIND11_MODULE(xla_extension, m) {
.def_property_readonly("host_id", &Device::host_id, .def_property_readonly("host_id", &Device::host_id,
"Integer ID of this device's host.\n\n" "Integer ID of this device's host.\n\n"
"This is always 0 except on multi-host platforms.") "This is always 0 except on multi-host platforms.")
.def_property_readonly("platform", &Device::platform_name)
.def("__str__", &Device::DebugString); .def("__str__", &Device::DebugString);
py::class_<CpuDevice, Device, std::shared_ptr<CpuDevice>>(m, "CpuDevice") py::class_<CpuDevice, Device, std::shared_ptr<CpuDevice>>(m, "CpuDevice")
@ -391,6 +392,12 @@ PYBIND11_MODULE(xla_extension, m) {
std::shared_ptr<Device> device) std::shared_ptr<Device> device)
-> StatusOr<std::unique_ptr<PyLocalBuffer>> { -> StatusOr<std::unique_ptr<PyLocalBuffer>> {
CHECK(device != nullptr); CHECK(device != nullptr);
auto iter = client->id_to_device().find(device->id());
if (iter->second != device) {
return InvalidArgument(
"Cannot copy value to device '%s' with '%s' backend",
device->DebugString(), client->platform_name());
}
GlobalPyRefManager()->CollectGarbage(); GlobalPyRefManager()->CollectGarbage();
TF_ASSIGN_OR_RETURN(PythonBufferTree tree, TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
GetPythonBufferTree(argument)); GetPythonBufferTree(argument));
@ -436,8 +443,15 @@ PYBIND11_MODULE(xla_extension, m) {
.def_static("make_tuple", .def_static("make_tuple",
[](const std::vector<PyLocalBuffer*> buffers, [](const std::vector<PyLocalBuffer*> buffers,
std::shared_ptr<PyLocalClient> client, std::shared_ptr<PyLocalClient> client,
std::shared_ptr<Device> device) { std::shared_ptr<Device> device)
-> StatusOr<std::unique_ptr<PyLocalBuffer>> {
CHECK(device != nullptr); CHECK(device != nullptr);
auto iter = client->id_to_device().find(device->id());
if (iter->second != device) {
return InvalidArgument(
"Cannot make tuple on device '%s' with '%s' backend",
device->DebugString(), client->platform_name());
}
return PyLocalBuffer::MakeTuple( return PyLocalBuffer::MakeTuple(
buffers, client, device->local_device_ordinal()); buffers, client, device->local_device_ordinal());
}) })