Enable TPU POD for JAX/1VM by creating devices and local_devices topology.

PiperOrigin-RevId: 289761259
Change-Id: Icbdcca91fd37ea0a04ad16df82aede52e3281ed9
This commit is contained in:
Henry Tan 2020-01-14 17:01:13 -08:00 committed by TensorFlower Gardener
parent 3dec91764c
commit 79cea35560
4 changed files with 79 additions and 31 deletions

View File

@ -34,14 +34,34 @@ limitations under the License.
namespace xla { namespace xla {
constexpr char kTpuPlatform[] = "tpu";
TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
int core_on_chip)
: xla::Device(id, /*local_device_state=*/nullptr, kTpuPlatform, host_id),
coords_(coords),
core_on_chip_(core_on_chip) {}
std::string TpuDevice::DebugString() const { std::string TpuDevice::DebugString() const {
return absl::StrCat("TPU_", id()); return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), host_id(),
coords_[0], coords_[1], coords_[2], core_on_chip_);
} }
static std::shared_ptr<Device> MakeDevice(const std::string& platform_name, xla::StatusOr<std::vector<std::shared_ptr<xla::Device>>>
int id) { TpuDevice::GetTpuDevices(const tpu_driver::SystemInfo& system_info) {
CHECK_EQ(platform_name, "tpu"); std::vector<std::shared_ptr<Device>> devices;
return std::make_shared<TpuDevice>(id, /*local_device_state=*/nullptr, "tpu"); for (const auto& chip : system_info.tpu_chip()) {
auto& coord = chip.chip_coord();
std::array<int, 3> coords_array = {coord.x(), coord.y(), coord.z()};
int host_id = chip.host_id();
for (const auto& core : chip.core()) {
auto device = std::make_shared<TpuDevice>(
core.id(), host_id, coords_array, core.core_on_chip_index());
devices.push_back(device);
}
}
return devices;
} }
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get( StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
@ -49,7 +69,6 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
tpu_driver::TpuDriverConfig driver_config; tpu_driver::TpuDriverConfig driver_config;
driver_config.set_worker(worker); driver_config.set_worker(worker);
auto client_status = tpu_driver::TpuDriverRegistry::Open(driver_config); auto client_status = tpu_driver::TpuDriverRegistry::Open(driver_config);
if (!client_status.ok()) { if (!client_status.ok()) {
return client_status.status(); return client_status.status();
} }
@ -58,19 +77,13 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
tpu_driver::SystemInfo system_info; tpu_driver::SystemInfo system_info;
client->QuerySystemInfo(&system_info); client->QuerySystemInfo(&system_info);
int num_cores =
system_info.tpu_chip_size() * system_info.tpu_chip(0).core_size();
std::vector<std::shared_ptr<Device>> devices; TF_ASSIGN_OR_RETURN(std::vector<std::shared_ptr<Device>> devices,
CHECK_GE(num_cores, 1); TpuDevice::GetTpuDevices(system_info));
LOG(INFO) << "Creating " << num_cores << " TPU device(s).";
devices.reserve(num_cores);
for (int i = 0; i < num_cores; ++i) {
devices.push_back(MakeDevice("tpu", i));
}
return std::make_shared<PyTpuClient>("tpu", std::move(client), return std::make_shared<PyTpuClient>(kTpuPlatform, std::move(client),
std::move(devices), /*host_id=*/0); std::move(devices),
system_info.host_id());
} }
PyTpuClient::PyTpuClient(std::string platform_name, PyTpuClient::PyTpuClient(std::string platform_name,
@ -81,18 +94,21 @@ PyTpuClient::PyTpuClient(std::string platform_name,
driver_(std::move(driver)), driver_(std::move(driver)),
devices_(std::move(devices)), devices_(std::move(devices)),
host_id_(host_id) { host_id_(host_id) {
local_devices_.resize(devices_.size());
for (const std::shared_ptr<Device>& device : devices_) { for (const std::shared_ptr<Device>& device : devices_) {
CHECK(id_to_device_.insert({device->id(), device}).second) CHECK(id_to_device_.insert({device->id(), device}).second)
<< "Duplicate device id: " << device->id(); << "Duplicate device id: " << device->id();
if (device->id() != -1) { if (device->host_id() == host_id_) {
int idx = device->id(); LOG(INFO) << "Detected local device, host-id: " << host_id_
CHECK(local_devices_[idx] == nullptr) << idx; << ". core-id: " << device->id();
CHECK_LT(idx, local_devices_.size()); local_devices_.push_back(device);
local_devices_[idx] = device; } else {
VLOG(2) << "Other devices, id: " << device->id();
} }
} }
CHECK_GE(local_devices_.size(), 1);
LOG(INFO) << "Creating " << local_devices_.size() << " TPU device(s).";
for (int idx = 0; idx < local_devices_.size(); ++idx) { for (int idx = 0; idx < local_devices_.size(); ++idx) {
CHECK(local_devices_[idx] != nullptr) << idx; CHECK(local_devices_[idx] != nullptr) << idx;
} }
@ -217,8 +233,8 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::MakeTuple(
std::shared_ptr<TpuSharedBuffer> child_device_buffer = std::shared_ptr<TpuSharedBuffer> child_device_buffer =
child_buffer->DeviceBuffer(); child_buffer->DeviceBuffer();
// Merge all definition events from all children, so that anyone using this // Merge all definition events from all children, so that anyone using this
// tuple must wait for all its children to finish receiving transfers. // tuple must wait for all its children to finish receiving transfers. This
// This works recursively up a nested tuple tree as well. // works recursively up a nested tuple tree as well.
for (std::shared_ptr<tpu_driver::Event> child_event : for (std::shared_ptr<tpu_driver::Event> child_event :
child_device_buffer->wait_for_use) { child_device_buffer->wait_for_use) {
child_events.push_back(std::move(child_event)); child_events.push_back(std::move(child_event));

View File

@ -38,8 +38,21 @@ namespace xla {
class TpuDevice : public Device { class TpuDevice : public Device {
public: public:
using Device::Device; TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
int core_on_chip);
const std::array<int, 3>& coords() const { return coords_; }
int core_on_chip() const { return core_on_chip_; }
std::string DebugString() const override; std::string DebugString() const override;
static xla::StatusOr<std::vector<std::shared_ptr<xla::Device>>> GetTpuDevices(
const tpu_driver::SystemInfo& system_info);
private:
const std::array<int, 3> coords_;
// Index of the core of the same chip.
int core_on_chip_;
}; };
// Encapsulates the state of Python session with XLA. // Encapsulates the state of Python session with XLA.
@ -50,7 +63,7 @@ class PyTpuClient {
static StatusOr<std::shared_ptr<PyTpuClient>> Get(const std::string& worker); static StatusOr<std::shared_ptr<PyTpuClient>> Get(const std::string& worker);
explicit PyTpuClient(std::string platform_name, explicit PyTpuClient(std::string platform_name,
std::unique_ptr<tpu_driver::TpuDriver> client, std::unique_ptr<tpu_driver::TpuDriver> driver,
std::vector<std::shared_ptr<Device>> devices, std::vector<std::shared_ptr<Device>> devices,
int host_id); int host_id);
virtual ~PyTpuClient() = default; virtual ~PyTpuClient() = default;

View File

@ -206,8 +206,13 @@ PYBIND11_MODULE(tpu_client_extension, m) {
py::call_guard<py::gil_scoped_release>(), py::arg("arguments")); py::call_guard<py::gil_scoped_release>(), py::arg("arguments"));
py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice") py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
.def_property_readonly("coords", &TpuDevice::coords)
.def_property_readonly("core_on_chip", &TpuDevice::core_on_chip)
.def("__repr__", [](const TpuDevice& device) { .def("__repr__", [](const TpuDevice& device) {
return absl::StrFormat("TpuDevice(id=%i)", device.id()); return absl::StrFormat(
"TpuDevice(id=%i, host_id=%i, coords=(%i,%i,%i), core_on_chip=%i)",
device.id(), device.host_id(), device.coords()[0],
device.coords()[1], device.coords()[2], device.core_on_chip());
}); });
} // NOLINT(readability/fn_size) } // NOLINT(readability/fn_size)

View File

@ -19,15 +19,24 @@ package tpu_driver;
enum MemoryRegion { HBM = 1; } enum MemoryRegion { HBM = 1; }
message ChipCoordinate {
required int32 x = 1;
required int32 y = 2;
required int32 z = 3;
}
message TpuCoreInfo { message TpuCoreInfo {
required int32 id = 1; required int32 id = 1;
optional int32 core_on_chip_index = 2;
required int64 hbm_bytes_available = 100; optional int32 core_on_host_index = 3;
required int64 hbm_bytes_allocatable = 101; optional int64 hbm_bytes_available = 100;
optional int64 hbm_bytes_allocatable = 101;
} }
message TpuChipInfo { message TpuChipInfo {
repeated TpuCoreInfo core = 1; repeated TpuCoreInfo core = 1;
optional int32 host_id = 2;
optional ChipCoordinate chip_coord = 3;
} }
message CpuInfo { message CpuInfo {
@ -40,6 +49,11 @@ message CpuInfo {
message SystemInfo { message SystemInfo {
repeated TpuChipInfo tpu_chip = 1; repeated TpuChipInfo tpu_chip = 1;
required CpuInfo cpu = 2; required CpuInfo cpu = 2;
repeated TpuCoreInfo local_core = 3;
optional int32 host_id = 4;
optional int32 host_count = 5;
optional int32 chip_count = 6;
optional int32 core_count = 7;
} }
message TpuDriverConfig { message TpuDriverConfig {