[PJRT] Rename PjRtDevice::host_id and PjRtClient::host_id to ...::task_id instead.

Fixes a TODO in PJRT.

In general, we may have multiple tasks per physical host, so "task_id" is more accurate.

PiperOrigin-RevId: 353984838
Change-Id: I59340c15181374a3db488b238ed81836a94003d0
Peter Hawkins 2021-01-26 17:14:02 -08:00 committed by TensorFlower Gardener
parent 747ca958fa
commit 55f1184704
13 changed files with 64 additions and 57 deletions
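
For context (not part of this change), a minimal standalone sketch of the distinction the rename captures: two client tasks may be colocated on one physical host, so identifying a client by task rather than by host is the accurate notion. The types below are hypothetical stand-ins, not the real PJRT classes.

#include <cstdio>

// Hypothetical stand-in used only to illustrate the task-vs-host distinction;
// this is not a real PJRT type.
struct TaskInfo {
  int task_id;     // logical task (process) index, the notion PJRT now uses
  int host_index;  // physical host the task happens to run on
};

int main() {
  // Two tasks colocated on the same physical host: their host index is equal,
  // but their task IDs differ, which is why "task_id" is the precise name.
  TaskInfo a{/*task_id=*/0, /*host_index=*/0};
  TaskInfo b{/*task_id=*/1, /*host_index=*/0};
  std::printf("task %d and task %d share host %d\n", a.task_id, b.task_id,
              a.host_index);
  return 0;
}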


@ -58,7 +58,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
}
return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
kCpuName, client, std::move(devices), /*host_id=*/0,
kCpuName, client, std::move(devices), /*task_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr));


@ -52,7 +52,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
devices.push_back(std::move(device));
return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
"interpreter", client, std::move(devices), /*host_id=*/0,
"interpreter", client, std::move(devices), /*task_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr));


@ -72,11 +72,11 @@ class PjRtDevice {
// hosts' devices. This is the ID that should be used in a DeviceAssignment.
virtual int id() const = 0;
// The task ID of this device according to TpuTopology. This is not the same
// as PjRtClient::host_id() in a multi-task setting, where each client can see
// devices from all tasks, but only a subset of them are addressable and have
// the same task_id as the client.
virtual int host_id() const = 0;
// The task ID of this device according to TpuTopology. This is not always
// identical to PjRtClient::task_id() in a multi-task setting, where each
// client can see devices from all tasks, but only a subset of them are
// addressable and have the same task_id as the client.
virtual int task_id() const = 0;
// Opaque hardware ID, e.g., the CUDA device number, useful for identifying
// which GPU when interacting with non-JAX code. In general, not guaranteed to
@ -140,9 +140,8 @@ class PjRtClient {
public:
virtual ~PjRtClient() = default;
// TODO(zhangqiaorjc): Rename to task_id.
// Return the task id of this client. In single-task setting, always 0.
virtual int host_id() const = 0;
virtual int task_id() const = 0;
// Return the number of devices in the entire computation. In multi-headed
// client setting, some are addressable by this client, some are not. In a
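
As a side note (not part of the diff), here is a minimal sketch of the relationship the new comment describes: a client sees devices from every task, but only those whose task_id matches the client's are addressable. The simplified Device/Client types are hypothetical stand-ins, not the real PjRtDevice/PjRtClient.

#include <cstdio>
#include <vector>

// Hypothetical, simplified stand-ins used only to illustrate how
// addressability follows from matching task IDs.
struct Device {
  int id;       // globally unique device id
  int task_id;  // task that can address this device
};

struct Client {
  int task_id;                  // the task this client runs in
  std::vector<Device> devices;  // global view: devices from every task

  // Devices addressable by this client are exactly those whose task_id
  // matches the client's task_id.
  std::vector<Device> addressable_devices() const {
    std::vector<Device> out;
    for (const Device& d : devices) {
      if (d.task_id == task_id) out.push_back(d);
    }
    return out;
  }
};

int main() {
  // Four devices spread across two tasks; this client runs in task 0.
  Client client{/*task_id=*/0, {{0, 0}, {1, 0}, {2, 1}, {3, 1}}};
  std::printf("client in task %d addresses %zu of %zu devices\n",
              client.task_id, client.addressable_devices().size(),
              client.devices.size());
  return 0;
}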


@ -202,7 +202,7 @@ static int DefaultThreadPoolSize() {
PjRtStreamExecutorClient::PjRtStreamExecutorClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int task_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
@ -212,7 +212,7 @@ PjRtStreamExecutorClient::PjRtStreamExecutorClient(
client_(client),
host_memory_allocator_(std::move(host_memory_allocator)),
owned_devices_(std::move(devices)),
host_id_(host_id),
task_id_(task_id),
owned_allocator_(std::move(allocator)),
should_stage_host_to_device_transfers_(
should_stage_host_to_device_transfers),
@ -1815,7 +1815,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
(*device_assignment)(0, 0) = device->id();
}
CHECK_EQ(device->host_id(), client_->host_id());
CHECK_EQ(device->task_id(), client_->task_id());
int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
->local_device_state()
->device_ordinal();
@ -2093,7 +2093,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
for (int partition = 0; partition < num_partitions; ++partition) {
int device_id = (*device_assignment)(replica, partition);
TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
if (device->host_id() != host_id()) {
if (device->task_id() != task_id()) {
VLOG(3) << "Non-local device: " << device_id;
continue;
}
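
For illustration only, a sketch of the filtering step shown above: walk a replica-by-partition device assignment and keep only the entries whose device belongs to this client's task. The nested vector standing in for xla::DeviceAssignment and the device_to_task map are hypothetical simplifications.

#include <cstdio>
#include <map>
#include <vector>

int main() {
  // Hypothetical 2x2 assignment: rows are replicas, columns are partitions,
  // entries are device ids (a stand-in for xla::DeviceAssignment).
  std::vector<std::vector<int>> assignment = {{0, 1}, {2, 3}};
  std::map<int, int> device_to_task = {{0, 0}, {1, 0}, {2, 1}, {3, 1}};
  const int client_task_id = 0;

  // Keep only the (replica, partition) pairs whose device belongs to this
  // client's task, mirroring the skip of non-local devices in Compile().
  for (int replica = 0; replica < (int)assignment.size(); ++replica) {
    for (int partition = 0; partition < (int)assignment[replica].size();
         ++partition) {
      int device_id = assignment[replica][partition];
      if (device_to_task[device_id] != client_task_id) continue;  // non-local
      std::printf("local: replica=%d partition=%d device=%d\n", replica,
                  partition, device_id);
    }
  }
  return 0;
}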


@ -58,12 +58,12 @@ class PjRtStreamExecutorDevice : public PjRtDevice {
public:
explicit PjRtStreamExecutorDevice(
int id, std::unique_ptr<LocalDeviceState> local_device_state,
std::string device_kind, int host_id = 0)
std::string device_kind, int task_id = 0)
: id_(id),
device_ordinal_(
local_device_state ? local_device_state->device_ordinal() : -1),
local_device_state_(std::move(local_device_state)),
host_id_(host_id),
task_id_(task_id),
device_kind_(std::move(device_kind)) {}
~PjRtStreamExecutorDevice() override {}
@ -73,7 +73,7 @@ class PjRtStreamExecutorDevice : public PjRtDevice {
client_ = client;
}
int host_id() const override { return host_id_; }
int task_id() const override { return task_id_; }
// Return `platform_id` from client.
PjRtPlatformId platform_id() const;
@ -113,7 +113,7 @@ class PjRtStreamExecutorDevice : public PjRtDevice {
const int id_;
const int device_ordinal_; // -1 means not local.
const std::unique_ptr<LocalDeviceState> local_device_state_;
const int host_id_;
const int task_id_;
const std::string device_kind_;
PjRtClient* client_ = nullptr;
};
@ -124,13 +124,13 @@ class PjRtStreamExecutorClient : public PjRtClient {
explicit PjRtStreamExecutorClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
int task_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
~PjRtStreamExecutorClient() override = default;
int host_id() const override { return host_id_; }
int task_id() const override { return task_id_; }
int device_count() const override { return devices_.size(); }
int addressable_device_count() const override {
@ -255,7 +255,7 @@ class PjRtStreamExecutorClient : public PjRtClient {
std::map<int, PjRtDevice*> id_to_device_;
// Local devices indexed by local device ordinal.
std::vector<PjRtDevice*> addressable_devices_;
int host_id_;
int task_id_;
se::DeviceMemoryAllocator* allocator_;
std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;


@ -77,7 +77,7 @@ class PjRtTpuClient : public PjRtStreamExecutorClient {
public:
PjRtTpuClient(LocalClient* client,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
int host_id);
int task_id);
StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const override;
@ -90,8 +90,8 @@ class PjRtTpuClient : public PjRtStreamExecutorClient {
PjRtTpuClient::PjRtTpuClient(
LocalClient* client,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id)
: PjRtStreamExecutorClient(kTpuName, client, std::move(devices), host_id,
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int task_id)
: PjRtStreamExecutorClient(kTpuName, client, std::move(devices), task_id,
/*allocator=*/nullptr,
/*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
@ -155,7 +155,7 @@ StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
auto it = core_id_to_device_ordinal.find(core.Id());
int device_ordinal =
(it != core_id_to_device_ordinal.end()) ? it->second : -1;
int host_id = topology.IdForHost(core.host_coordinates());
int task_id = topology.IdForHost(core.host_coordinates());
const tf_tpu::TpuDimensionsExternal coords = core.chip_coordinates();
std::array<int, 3> coords_array = {coords.x, coords.y, coords.z};
std::unique_ptr<LocalDeviceState> local_device_state;
@ -163,7 +163,7 @@ StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
local_device_state = std::move(local_device_states[device_ordinal]);
}
auto device = absl::make_unique<PjRtTpuDevice>(
core, std::move(local_device_state), host_id, coords_array,
core, std::move(local_device_state), task_id, coords_array,
std::string(tf_tpu::TpuVersionEnumToString(topology.version())));
devices.push_back(std::move(device));
}
@ -214,10 +214,10 @@ StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient(
TF_ASSIGN_OR_RETURN(auto devices,
GetTpuDevices(client, std::move(local_device_states)));
int host_id = platform->GetTpuHostLocation().Id();
int task_id = platform->GetTpuHostLocation().Id();
return std::shared_ptr<PjRtClient>(
absl::make_unique<PjRtTpuClient>(client, std::move(devices), host_id));
absl::make_unique<PjRtTpuClient>(client, std::move(devices), task_id));
}
} // namespace xla


@ -30,10 +30,10 @@ class PjRtTpuDevice : public PjRtStreamExecutorDevice {
public:
PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core,
std::unique_ptr<LocalDeviceState> local_device_state,
int host_id, const std::array<int, 3>& coords,
int task_id, const std::array<int, 3>& coords,
std::string device_kind)
: PjRtStreamExecutorDevice(core.Id(), std::move(local_device_state),
std::move(device_kind), host_id),
std::move(device_kind), task_id),
core_(core),
coords_(coords) {}
@ -42,7 +42,7 @@ class PjRtTpuDevice : public PjRtStreamExecutorDevice {
const tensorflow::tpu::TpuCoreLocationExternal core() const { return core_; }
std::string DebugString() const override {
return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), host_id(),
return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), task_id(),
coords_[0], coords_[1], coords_[2], core_.index());
}


@ -98,7 +98,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClientWithNonLocalDevice() {
devices.push_back(absl::make_unique<CpuDevice>(1, nullptr));
return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
kCpuName, client, std::move(devices), /*host_id=*/0,
kCpuName, client, std::move(devices), /*task_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr));


@ -101,7 +101,7 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
return pjrt_client_->addressable_device_count();
}
int device_count() const { return pjrt_client_->device_count(); }
int host_id() const { return pjrt_client_->host_id(); }
int task_id() const { return pjrt_client_->task_id(); }
std::vector<ClientAndPtr<PjRtDevice>> Devices();
std::vector<ClientAndPtr<PjRtDevice>> LocalDevices();


@ -35,15 +35,15 @@ limitations under the License.
namespace xla {
TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
TpuDevice::TpuDevice(int id, int task_id, const std::array<int, 3>& coords,
int core_on_chip)
: id_(id),
host_id_(host_id),
task_id_(task_id),
coords_(coords),
core_on_chip_(core_on_chip) {}
std::string TpuDevice::DebugString() const {
return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), host_id(),
return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), task_id(),
coords_[0], coords_[1], coords_[2], core_on_chip_);
}
@ -53,10 +53,10 @@ TpuDevice::GetTpuDevices(const tpu_driver::SystemInfo& system_info) {
for (const auto& chip : system_info.tpu_chip()) {
auto& coord = chip.chip_coord();
std::array<int, 3> coords_array = {coord.x(), coord.y(), coord.z()};
int host_id = chip.host_id();
int task_id = chip.host_id();
for (const auto& core : chip.core()) {
auto device = std::make_shared<TpuDevice>(
core.id(), host_id, coords_array, core.core_on_chip_index());
core.id(), task_id, coords_array, core.core_on_chip_index());
devices.push_back(device);
}
}
@ -89,17 +89,17 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
PyTpuClient::PyTpuClient(std::string platform_name,
std::unique_ptr<tpu_driver::TpuDriver> driver,
std::vector<std::shared_ptr<PjRtDevice>> devices,
int host_id)
int task_id)
: platform_name_(std::move(platform_name)),
driver_(std::move(driver)),
devices_(std::move(devices)),
host_id_(host_id) {
task_id_(task_id) {
for (const std::shared_ptr<PjRtDevice>& device : devices_) {
CHECK(id_to_device_.insert({device->id(), device}).second)
<< "Duplicate device id: " << device->id();
if (device->host_id() == host_id_) {
LOG(INFO) << "Detected local device, host id: " << host_id_
if (device->task_id() == task_id_) {
LOG(INFO) << "Detected local device, host id: " << task_id_
<< ". device id: " << device->id();
local_devices_.push_back(device);
} else {
@ -522,7 +522,7 @@ PyTpuExecutable::PyTpuExecutable(
for (int partition = 0; partition < num_partitions; ++partition) {
int device_id = device_assignment_(replica, partition);
std::shared_ptr<PjRtDevice> device = LookupDevice(*client_, device_id);
if (device->host_id() != client_->host_id()) {
if (device->task_id() != client_->task_id()) {
VLOG(3) << "Non-local device: " << device_id;
continue;
}
@ -547,7 +547,7 @@ PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper(
int partition, const RunId& run_id) {
const int device_id = device_assignment_(replica, partition);
std::shared_ptr<PjRtDevice> device = LookupDevice(*client_, device_id);
CHECK_EQ(device->host_id(), client_->host_id());
CHECK_EQ(device->task_id(), client_->task_id());
tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute");
VLOG(3) << "Replica " << replica << ", partition " << partition
<< " mapped to device id for execution: " << device_id;


@ -42,7 +42,7 @@ constexpr char kTpuPlatform[] = "tpu";
class TpuDevice : public PjRtDevice {
public:
TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
TpuDevice(int id, int task_id, const std::array<int, 3>& coords,
int core_on_chip);
const std::array<int, 3>& coords() const { return coords_; }
@ -59,7 +59,7 @@ class TpuDevice : public PjRtDevice {
int id() const override { return id_; }
int host_id() const override { return host_id_; }
int task_id() const override { return task_id_; }
int local_hardware_id() const override { return -1; }
@ -75,7 +75,7 @@ class TpuDevice : public PjRtDevice {
private:
const int id_;
const int host_id_;
const int task_id_;
const std::array<int, 3> coords_;
const std::string device_kind_ = "Cloud TPU";
// Index of the core of the same chip.
@ -92,7 +92,7 @@ class PyTpuClient {
explicit PyTpuClient(std::string platform_name,
std::unique_ptr<tpu_driver::TpuDriver> driver,
std::vector<std::shared_ptr<PjRtDevice>> devices,
int host_id);
int task_id);
virtual ~PyTpuClient() = default;
PyTpuClient(const PyTpuClient&) = delete;
@ -115,7 +115,7 @@ class PyTpuClient {
const std::map<int, std::shared_ptr<PjRtDevice>>& id_to_device() const {
return id_to_device_;
}
int host_id() const { return host_id_; }
int task_id() const { return task_id_; }
const std::string& platform_name() const { return platform_name_; }
StatusOr<Shape> ChooseCompactLayoutForShape(Shape subshape) {
@ -140,7 +140,7 @@ class PyTpuClient {
std::map<int, std::shared_ptr<PjRtDevice>> id_to_device_;
// Local devices indexed by local device ordinal.
std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
int host_id_;
int task_id_;
// A thread pool for scheduling core executions in parallel.
std::unique_ptr<tensorflow::thread::ThreadPool> pool_;


@ -32,7 +32,8 @@ PYBIND11_MODULE(tpu_client_extension, m) {
.def("local_device_count", &PyTpuClient::local_device_count)
.def("devices", &PyTpuClient::devices)
.def("local_devices", &PyTpuClient::local_devices)
.def("host_id", &PyTpuClient::host_id)
.def("host_id", &PyTpuClient::task_id)
.def("task_id", &PyTpuClient::task_id)
.def("get_default_device_assignment",
[](PyTpuClient* client, int num_replicas, int num_partitions)
-> StatusOr<
@ -213,7 +214,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
.def("__repr__", [](const TpuDevice& device) {
return absl::StrFormat(
"TpuDevice(id=%i, host_id=%i, coords=(%i,%i,%i), core_on_chip=%i)",
device.id(), device.host_id(), device.coords()[0],
device.id(), device.task_id(), device.coords()[0],
device.coords()[1], device.coords()[2], device.core_on_chip());
});
} // NOLINT(readability/fn_size)
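
For what it's worth, the backward-compatible binding above (exposing the renamed C++ accessor under both the old and the new Python name) boils down to the pybind11 pattern below; the Widget class and module name are made up for illustration and are not part of this change.

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical class whose accessor was renamed from host_id() to task_id().
struct Widget {
  int task_id() const { return 0; }
};

PYBIND11_MODULE(widget_ext, m) {
  py::class_<Widget>(m, "Widget")
      .def(py::init<>())
      // Old name kept as an alias so existing Python callers keep working...
      .def("host_id", &Widget::task_id)
      // ...while new callers use the accurate name.
      .def("task_id", &Widget::task_id);
}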


@ -55,6 +55,8 @@ limitations under the License.
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/python/lib/core/bfloat16.h"
// TODO(phawkins): remove host_id properties after JAX is updated to avoid them.
namespace xla {
namespace {
@ -124,9 +126,12 @@ PYBIND11_MODULE(xla_extension, m) {
"id", &PjRtDevice::id,
"Integer ID of this device.\n\nUnique across all available devices "
"of this type, including remote devices on multi-host platforms.")
.def_property_readonly("host_id", &PjRtDevice::host_id,
"Integer ID of this device's host.\n\n"
"This is always 0 except on multi-host platforms.")
.def_property_readonly("host_id", &PjRtDevice::task_id,
"Integer ID of this device's task.\n\n"
"This is always 0 except on multi-task platforms.")
.def_property_readonly("task_id", &PjRtDevice::task_id,
"Integer ID of this device's task.\n\n"
"This is always 0 except on multi-task platforms.")
.def_property_readonly("platform",
[](const PjRtDevice& device) {
return device.client()->platform_name();
@ -174,7 +179,8 @@ PYBIND11_MODULE(xla_extension, m) {
py::class_<PjRtTpuDevice, PjRtDevice, ClientAndPtr<PjRtTpuDevice>>(
m, "TpuDevice")
.def_property_readonly("host_id", &PjRtTpuDevice::host_id)
.def_property_readonly("host_id", &PjRtTpuDevice::task_id)
.def_property_readonly("task_id", &PjRtTpuDevice::task_id)
.def_property_readonly(
"coords",
[](const PjRtTpuDevice& device) -> pybind11::tuple {
@ -187,7 +193,7 @@ PYBIND11_MODULE(xla_extension, m) {
.def("__repr__", [](const PjRtTpuDevice& device) {
return absl::StrFormat(
"TpuDevice(id=%i, host=%i, coords=(%s), core_on_chip=%i)",
device.id(), device.host_id(), absl::StrJoin(device.coords(), ","),
device.id(), device.task_id(), absl::StrJoin(device.coords(), ","),
device.core_on_chip());
});
@ -217,7 +223,8 @@ PYBIND11_MODULE(xla_extension, m) {
.def("devices", &PyClient::Devices)
.def("local_devices", &PyClient::LocalDevices)
.def("live_buffers", &PyClient::LiveBuffers)
.def("host_id", &PyClient::host_id)
.def("host_id", &PyClient::task_id)
.def("task_id", &PyClient::task_id)
.def("get_default_device_assignment",
&PyClient::GetDefaultDeviceAssignment)
// TODO(skye): delete after all callers can handle 2D output