Enable TPU POD for JAX/1VM by creating devices and local_devices topology.

PiperOrigin-RevId: 289761259
Change-Id: Icbdcca91fd37ea0a04ad16df82aede52e3281ed9
This commit is contained in:
Henry Tan 2020-01-14 17:01:13 -08:00 committed by TensorFlower Gardener
parent 3dec91764c
commit 79cea35560
4 changed files with 79 additions and 31 deletions

View File

@ -34,14 +34,34 @@ limitations under the License.
namespace xla {
constexpr char kTpuPlatform[] = "tpu";
TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
int core_on_chip)
: xla::Device(id, /*local_device_state=*/nullptr, kTpuPlatform, host_id),
coords_(coords),
core_on_chip_(core_on_chip) {}
std::string TpuDevice::DebugString() const {
return absl::StrCat("TPU_", id());
return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), host_id(),
coords_[0], coords_[1], coords_[2], core_on_chip_);
}
static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
int id) {
CHECK_EQ(platform_name, "tpu");
return std::make_shared<TpuDevice>(id, /*local_device_state=*/nullptr, "tpu");
xla::StatusOr<std::vector<std::shared_ptr<xla::Device>>>
TpuDevice::GetTpuDevices(const tpu_driver::SystemInfo& system_info) {
std::vector<std::shared_ptr<Device>> devices;
for (const auto& chip : system_info.tpu_chip()) {
auto& coord = chip.chip_coord();
std::array<int, 3> coords_array = {coord.x(), coord.y(), coord.z()};
int host_id = chip.host_id();
for (const auto& core : chip.core()) {
auto device = std::make_shared<TpuDevice>(
core.id(), host_id, coords_array, core.core_on_chip_index());
devices.push_back(device);
}
}
return devices;
}
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
@ -49,7 +69,6 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
tpu_driver::TpuDriverConfig driver_config;
driver_config.set_worker(worker);
auto client_status = tpu_driver::TpuDriverRegistry::Open(driver_config);
if (!client_status.ok()) {
return client_status.status();
}
@ -58,19 +77,13 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
tpu_driver::SystemInfo system_info;
client->QuerySystemInfo(&system_info);
int num_cores =
system_info.tpu_chip_size() * system_info.tpu_chip(0).core_size();
std::vector<std::shared_ptr<Device>> devices;
CHECK_GE(num_cores, 1);
LOG(INFO) << "Creating " << num_cores << " TPU device(s).";
devices.reserve(num_cores);
for (int i = 0; i < num_cores; ++i) {
devices.push_back(MakeDevice("tpu", i));
}
TF_ASSIGN_OR_RETURN(std::vector<std::shared_ptr<Device>> devices,
TpuDevice::GetTpuDevices(system_info));
return std::make_shared<PyTpuClient>("tpu", std::move(client),
std::move(devices), /*host_id=*/0);
return std::make_shared<PyTpuClient>(kTpuPlatform, std::move(client),
std::move(devices),
system_info.host_id());
}
PyTpuClient::PyTpuClient(std::string platform_name,
@ -81,18 +94,21 @@ PyTpuClient::PyTpuClient(std::string platform_name,
driver_(std::move(driver)),
devices_(std::move(devices)),
host_id_(host_id) {
local_devices_.resize(devices_.size());
for (const std::shared_ptr<Device>& device : devices_) {
CHECK(id_to_device_.insert({device->id(), device}).second)
<< "Duplicate device id: " << device->id();
if (device->id() != -1) {
int idx = device->id();
CHECK(local_devices_[idx] == nullptr) << idx;
CHECK_LT(idx, local_devices_.size());
local_devices_[idx] = device;
if (device->host_id() == host_id_) {
LOG(INFO) << "Detected local device, host-id: " << host_id_
<< ". core-id: " << device->id();
local_devices_.push_back(device);
} else {
VLOG(2) << "Other devices, id: " << device->id();
}
}
CHECK_GE(local_devices_.size(), 1);
LOG(INFO) << "Creating " << local_devices_.size() << " TPU device(s).";
for (int idx = 0; idx < local_devices_.size(); ++idx) {
CHECK(local_devices_[idx] != nullptr) << idx;
}
@ -217,8 +233,8 @@ StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::MakeTuple(
std::shared_ptr<TpuSharedBuffer> child_device_buffer =
child_buffer->DeviceBuffer();
// Merge all definition events from all children, so that anyone using this
// tuple must wait for all its children to finish receiving transfers.
// This works recursively up a nested tuple tree as well.
// tuple must wait for all its children to finish receiving transfers. This
// works recursively up a nested tuple tree as well.
for (std::shared_ptr<tpu_driver::Event> child_event :
child_device_buffer->wait_for_use) {
child_events.push_back(std::move(child_event));

View File

@ -38,8 +38,21 @@ namespace xla {
class TpuDevice : public Device {
public:
using Device::Device;
TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
int core_on_chip);
const std::array<int, 3>& coords() const { return coords_; }
int core_on_chip() const { return core_on_chip_; }
std::string DebugString() const override;
static xla::StatusOr<std::vector<std::shared_ptr<xla::Device>>> GetTpuDevices(
const tpu_driver::SystemInfo& system_info);
private:
const std::array<int, 3> coords_;
// Index of the core of the same chip.
int core_on_chip_;
};
// Encapsulates the state of Python session with XLA.
@ -50,7 +63,7 @@ class PyTpuClient {
static StatusOr<std::shared_ptr<PyTpuClient>> Get(const std::string& worker);
explicit PyTpuClient(std::string platform_name,
std::unique_ptr<tpu_driver::TpuDriver> client,
std::unique_ptr<tpu_driver::TpuDriver> driver,
std::vector<std::shared_ptr<Device>> devices,
int host_id);
virtual ~PyTpuClient() = default;

View File

@ -206,8 +206,13 @@ PYBIND11_MODULE(tpu_client_extension, m) {
py::call_guard<py::gil_scoped_release>(), py::arg("arguments"));
py::class_<TpuDevice, Device, std::shared_ptr<TpuDevice>>(m, "TpuDevice")
.def_property_readonly("coords", &TpuDevice::coords)
.def_property_readonly("core_on_chip", &TpuDevice::core_on_chip)
.def("__repr__", [](const TpuDevice& device) {
return absl::StrFormat("TpuDevice(id=%i)", device.id());
return absl::StrFormat(
"TpuDevice(id=%i, host_id=%i, coords=(%i,%i,%i), core_on_chip=%i)",
device.id(), device.host_id(), device.coords()[0],
device.coords()[1], device.coords()[2], device.core_on_chip());
});
} // NOLINT(readability/fn_size)

View File

@ -19,15 +19,24 @@ package tpu_driver;
enum MemoryRegion { HBM = 1; }
message ChipCoordinate {
required int32 x = 1;
required int32 y = 2;
required int32 z = 3;
}
message TpuCoreInfo {
required int32 id = 1;
required int64 hbm_bytes_available = 100;
required int64 hbm_bytes_allocatable = 101;
optional int32 core_on_chip_index = 2;
optional int32 core_on_host_index = 3;
optional int64 hbm_bytes_available = 100;
optional int64 hbm_bytes_allocatable = 101;
}
message TpuChipInfo {
repeated TpuCoreInfo core = 1;
optional int32 host_id = 2;
optional ChipCoordinate chip_coord = 3;
}
message CpuInfo {
@ -40,6 +49,11 @@ message CpuInfo {
message SystemInfo {
repeated TpuChipInfo tpu_chip = 1;
required CpuInfo cpu = 2;
repeated TpuCoreInfo local_core = 3;
optional int32 host_id = 4;
optional int32 host_count = 5;
optional int32 chip_count = 6;
optional int32 core_count = 7;
}
message TpuDriverConfig {