[PJRT] Rename PjRtDevice::host_id and PjRtClient::host_id to ...::task_id instead.
Fixes a TODO in PJRT. In general, we may have multiple tasks per physical host, so "task_id" is more accurate.

PiperOrigin-RevId: 353984838
Change-Id: I59340c15181374a3db488b238ed81836a94003d0
parent 747ca958fa
commit 55f1184704
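For orientation, here is a minimal stand-alone sketch (simplified stand-in types, not the real PJRT headers and not part of this change) of the relationship the renamed accessors describe: a client may see devices from every task, but it treats a device as addressable only when the device's task_id matches its own, which is the comparison the diff rewrites from host_id to task_id.

#include <iostream>
#include <utility>
#include <vector>

// Stand-in types that mirror the renamed accessors.
class Device {
 public:
  Device(int id, int task_id) : id_(id), task_id_(task_id) {}
  int id() const { return id_; }
  int task_id() const { return task_id_; }  // formerly host_id()

 private:
  int id_;
  int task_id_;
};

class Client {
 public:
  Client(int task_id, std::vector<Device> devices)
      : task_id_(task_id), devices_(std::move(devices)) {}
  int task_id() const { return task_id_; }  // formerly host_id()

  // Mirrors the filtering in PjRtStreamExecutorClient::Compile and the
  // PyTpuClient constructor: skip devices that live in another task.
  std::vector<const Device*> AddressableDevices() const {
    std::vector<const Device*> result;
    for (const Device& d : devices_) {
      if (d.task_id() == task_id_) result.push_back(&d);
    }
    return result;
  }

 private:
  int task_id_;
  std::vector<Device> devices_;  // All devices, including other tasks'.
};

int main() {
  // Four devices spread over two tasks; the client running as task 0
  // addresses only the first two.
  Client client(/*task_id=*/0, {{0, 0}, {1, 0}, {2, 1}, {3, 1}});
  for (const Device* d : client.AddressableDevices()) {
    std::cout << "addressable device " << d->id() << " on task "
              << d->task_id() << "\n";
  }
  return 0;
}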
@@ -58,7 +58,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
   }

   return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
-      kCpuName, client, std::move(devices), /*host_id=*/0,
+      kCpuName, client, std::move(devices), /*task_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr));
@@ -52,7 +52,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
   devices.push_back(std::move(device));

   return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
-      "interpreter", client, std::move(devices), /*host_id=*/0,
+      "interpreter", client, std::move(devices), /*task_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr));
@@ -72,11 +72,11 @@ class PjRtDevice {
   // hosts' devices. This is the ID that should be used in a DeviceAssignment.
   virtual int id() const = 0;

-  // The task ID of this device according to TpuTopology. This is not the same
-  // as PjRtClient::host_id() in a multi-task setting, where each client can see
-  // devices from all tasks, but only a subset of them are addressable and have
-  // the same task_id as the client.
-  virtual int host_id() const = 0;
+  // The task ID of this device according to TpuTopology. This is not always
+  // identical to PjRtClient::task_id() in a multi-task setting, where each
+  // client can see devices from all tasks, but only a subset of them are
+  // addressable and have the same task_id as the client.
+  virtual int task_id() const = 0;

   // Opaque hardware ID, e.g., the CUDA device number, useful for identifying
   // which GPU when interacting with non-JAX code. In general, not guaranteed to
@@ -140,9 +140,8 @@ class PjRtClient {
  public:
   virtual ~PjRtClient() = default;

-  // TODO(zhangqiaorjc): Rename to task_id.
   // Return the task id of this client. In single-task setting, always 0.
-  virtual int host_id() const = 0;
+  virtual int task_id() const = 0;

   // Return the number of devices in the entire computation. In multi-headed
   // client setting, some are addressable by this client, some are not. In a
@@ -202,7 +202,7 @@ static int DefaultThreadPoolSize() {

 PjRtStreamExecutorClient::PjRtStreamExecutorClient(
     std::string platform_name, LocalClient* client,
-    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int task_id,
     std::unique_ptr<se::DeviceMemoryAllocator> allocator,
     std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
     bool should_stage_host_to_device_transfers,
@@ -212,7 +212,7 @@ PjRtStreamExecutorClient::PjRtStreamExecutorClient(
       client_(client),
       host_memory_allocator_(std::move(host_memory_allocator)),
       owned_devices_(std::move(devices)),
-      host_id_(host_id),
+      task_id_(task_id),
       owned_allocator_(std::move(allocator)),
       should_stage_host_to_device_transfers_(
           should_stage_host_to_device_transfers),
@@ -1815,7 +1815,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
     (*device_assignment)(0, 0) = device->id();
   }

-  CHECK_EQ(device->host_id(), client_->host_id());
+  CHECK_EQ(device->task_id(), client_->task_id());
   int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
                            ->local_device_state()
                            ->device_ordinal();
@@ -2093,7 +2093,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
     for (int partition = 0; partition < num_partitions; ++partition) {
       int device_id = (*device_assignment)(replica, partition);
       TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
-      if (device->host_id() != host_id()) {
+      if (device->task_id() != task_id()) {
         VLOG(3) << "Non-local device: " << device_id;
         continue;
       }
@@ -58,12 +58,12 @@ class PjRtStreamExecutorDevice : public PjRtDevice {
  public:
   explicit PjRtStreamExecutorDevice(
       int id, std::unique_ptr<LocalDeviceState> local_device_state,
-      std::string device_kind, int host_id = 0)
+      std::string device_kind, int task_id = 0)
       : id_(id),
         device_ordinal_(
             local_device_state ? local_device_state->device_ordinal() : -1),
         local_device_state_(std::move(local_device_state)),
-        host_id_(host_id),
+        task_id_(task_id),
         device_kind_(std::move(device_kind)) {}
   ~PjRtStreamExecutorDevice() override {}

@@ -73,7 +73,7 @@ class PjRtStreamExecutorDevice : public PjRtDevice {
     client_ = client;
   }

-  int host_id() const override { return host_id_; }
+  int task_id() const override { return task_id_; }

   // Return `platform_id` from client.
   PjRtPlatformId platform_id() const;
@@ -113,7 +113,7 @@ class PjRtStreamExecutorDevice : public PjRtDevice {
   const int id_;
   const int device_ordinal_;  // -1 means not local.
   const std::unique_ptr<LocalDeviceState> local_device_state_;
-  const int host_id_;
+  const int task_id_;
   const std::string device_kind_;
   PjRtClient* client_ = nullptr;
 };
@@ -124,13 +124,13 @@ class PjRtStreamExecutorClient : public PjRtClient {
   explicit PjRtStreamExecutorClient(
       std::string platform_name, LocalClient* client,
       std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
-      int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
+      int task_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
       std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
       bool should_stage_host_to_device_transfers,
       std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
   ~PjRtStreamExecutorClient() override = default;

-  int host_id() const override { return host_id_; }
+  int task_id() const override { return task_id_; }

   int device_count() const override { return devices_.size(); }
   int addressable_device_count() const override {
@@ -255,7 +255,7 @@ class PjRtStreamExecutorClient : public PjRtClient {
   std::map<int, PjRtDevice*> id_to_device_;
   // Local devices indexed by local device ordinal.
   std::vector<PjRtDevice*> addressable_devices_;
-  int host_id_;
+  int task_id_;

   se::DeviceMemoryAllocator* allocator_;
   std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
@@ -77,7 +77,7 @@ class PjRtTpuClient : public PjRtStreamExecutorClient {
  public:
   PjRtTpuClient(LocalClient* client,
                 std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
-                int host_id);
+                int task_id);

   StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
       int num_replicas, int num_partitions) const override;
@@ -90,8 +90,8 @@ class PjRtTpuClient : public PjRtStreamExecutorClient {

 PjRtTpuClient::PjRtTpuClient(
     LocalClient* client,
-    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id)
-    : PjRtStreamExecutorClient(kTpuName, client, std::move(devices), host_id,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int task_id)
+    : PjRtStreamExecutorClient(kTpuName, client, std::move(devices), task_id,
                                /*allocator=*/nullptr,
                                /*host_memory_allocator=*/nullptr,
                                /*should_stage_host_to_device_transfers=*/false,
@@ -155,7 +155,7 @@ StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
     auto it = core_id_to_device_ordinal.find(core.Id());
     int device_ordinal =
         (it != core_id_to_device_ordinal.end()) ? it->second : -1;
-    int host_id = topology.IdForHost(core.host_coordinates());
+    int task_id = topology.IdForHost(core.host_coordinates());
     const tf_tpu::TpuDimensionsExternal coords = core.chip_coordinates();
     std::array<int, 3> coords_array = {coords.x, coords.y, coords.z};
     std::unique_ptr<LocalDeviceState> local_device_state;
@@ -163,7 +163,7 @@ StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
       local_device_state = std::move(local_device_states[device_ordinal]);
     }
     auto device = absl::make_unique<PjRtTpuDevice>(
-        core, std::move(local_device_state), host_id, coords_array,
+        core, std::move(local_device_state), task_id, coords_array,
         std::string(tf_tpu::TpuVersionEnumToString(topology.version())));
     devices.push_back(std::move(device));
   }
@@ -214,10 +214,10 @@ StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient(

   TF_ASSIGN_OR_RETURN(auto devices,
                       GetTpuDevices(client, std::move(local_device_states)));
-  int host_id = platform->GetTpuHostLocation().Id();
+  int task_id = platform->GetTpuHostLocation().Id();

   return std::shared_ptr<PjRtClient>(
-      absl::make_unique<PjRtTpuClient>(client, std::move(devices), host_id));
+      absl::make_unique<PjRtTpuClient>(client, std::move(devices), task_id));
 }

 }  // namespace xla
@@ -30,10 +30,10 @@ class PjRtTpuDevice : public PjRtStreamExecutorDevice {
  public:
   PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core,
                 std::unique_ptr<LocalDeviceState> local_device_state,
-                int host_id, const std::array<int, 3>& coords,
+                int task_id, const std::array<int, 3>& coords,
                 std::string device_kind)
       : PjRtStreamExecutorDevice(core.Id(), std::move(local_device_state),
-                                 std::move(device_kind), host_id),
+                                 std::move(device_kind), task_id),
         core_(core),
         coords_(coords) {}

@@ -42,7 +42,7 @@ class PjRtTpuDevice : public PjRtStreamExecutorDevice {
   const tensorflow::tpu::TpuCoreLocationExternal core() const { return core_; }

   std::string DebugString() const override {
-    return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), host_id(),
+    return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), task_id(),
                            coords_[0], coords_[1], coords_[2], core_.index());
   }

@@ -98,7 +98,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClientWithNonLocalDevice() {
   devices.push_back(absl::make_unique<CpuDevice>(1, nullptr));

   return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
-      kCpuName, client, std::move(devices), /*host_id=*/0,
+      kCpuName, client, std::move(devices), /*task_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr));
@@ -101,7 +101,7 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
     return pjrt_client_->addressable_device_count();
   }
   int device_count() const { return pjrt_client_->device_count(); }
-  int host_id() const { return pjrt_client_->host_id(); }
+  int task_id() const { return pjrt_client_->task_id(); }

   std::vector<ClientAndPtr<PjRtDevice>> Devices();
   std::vector<ClientAndPtr<PjRtDevice>> LocalDevices();
@@ -35,15 +35,15 @@ limitations under the License.

 namespace xla {

-TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
+TpuDevice::TpuDevice(int id, int task_id, const std::array<int, 3>& coords,
                      int core_on_chip)
     : id_(id),
-      host_id_(host_id),
+      task_id_(task_id),
       coords_(coords),
       core_on_chip_(core_on_chip) {}

 std::string TpuDevice::DebugString() const {
-  return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), host_id(),
+  return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), task_id(),
                          coords_[0], coords_[1], coords_[2], core_on_chip_);
 }

@@ -53,10 +53,10 @@ TpuDevice::GetTpuDevices(const tpu_driver::SystemInfo& system_info) {
   for (const auto& chip : system_info.tpu_chip()) {
     auto& coord = chip.chip_coord();
     std::array<int, 3> coords_array = {coord.x(), coord.y(), coord.z()};
-    int host_id = chip.host_id();
+    int task_id = chip.host_id();
     for (const auto& core : chip.core()) {
       auto device = std::make_shared<TpuDevice>(
-          core.id(), host_id, coords_array, core.core_on_chip_index());
+          core.id(), task_id, coords_array, core.core_on_chip_index());
       devices.push_back(device);
     }
   }
@@ -89,17 +89,17 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
 PyTpuClient::PyTpuClient(std::string platform_name,
                          std::unique_ptr<tpu_driver::TpuDriver> driver,
                          std::vector<std::shared_ptr<PjRtDevice>> devices,
-                         int host_id)
+                         int task_id)
     : platform_name_(std::move(platform_name)),
       driver_(std::move(driver)),
      devices_(std::move(devices)),
-      host_id_(host_id) {
+      task_id_(task_id) {
   for (const std::shared_ptr<PjRtDevice>& device : devices_) {
     CHECK(id_to_device_.insert({device->id(), device}).second)
         << "Duplicate device id: " << device->id();

-    if (device->host_id() == host_id_) {
-      LOG(INFO) << "Detected local device, host id: " << host_id_
+    if (device->task_id() == task_id_) {
+      LOG(INFO) << "Detected local device, host id: " << task_id_
                 << ". device id: " << device->id();
       local_devices_.push_back(device);
     } else {
@@ -522,7 +522,7 @@ PyTpuExecutable::PyTpuExecutable(
     for (int partition = 0; partition < num_partitions; ++partition) {
       int device_id = device_assignment_(replica, partition);
       std::shared_ptr<PjRtDevice> device = LookupDevice(*client_, device_id);
-      if (device->host_id() != client_->host_id()) {
+      if (device->task_id() != client_->task_id()) {
         VLOG(3) << "Non-local device: " << device_id;
         continue;
       }
@@ -547,7 +547,7 @@ PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper(
     int partition, const RunId& run_id) {
   const int device_id = device_assignment_(replica, partition);
   std::shared_ptr<PjRtDevice> device = LookupDevice(*client_, device_id);
-  CHECK_EQ(device->host_id(), client_->host_id());
+  CHECK_EQ(device->task_id(), client_->task_id());
   tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute");
   VLOG(3) << "Replica " << replica << ", partition " << partition
           << " mapped to device id for execution: " << device_id;
@@ -42,7 +42,7 @@ constexpr char kTpuPlatform[] = "tpu";

 class TpuDevice : public PjRtDevice {
  public:
-  TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
+  TpuDevice(int id, int task_id, const std::array<int, 3>& coords,
             int core_on_chip);

   const std::array<int, 3>& coords() const { return coords_; }
@@ -59,7 +59,7 @@ class TpuDevice : public PjRtDevice {

   int id() const override { return id_; }

-  int host_id() const override { return host_id_; }
+  int task_id() const override { return task_id_; }

   int local_hardware_id() const override { return -1; }

@@ -75,7 +75,7 @@ class TpuDevice : public PjRtDevice {

  private:
   const int id_;
-  const int host_id_;
+  const int task_id_;
   const std::array<int, 3> coords_;
   const std::string device_kind_ = "Cloud TPU";
   // Index of the core of the same chip.
@@ -92,7 +92,7 @@ class PyTpuClient {
   explicit PyTpuClient(std::string platform_name,
                        std::unique_ptr<tpu_driver::TpuDriver> driver,
                        std::vector<std::shared_ptr<PjRtDevice>> devices,
-                       int host_id);
+                       int task_id);
   virtual ~PyTpuClient() = default;

   PyTpuClient(const PyTpuClient&) = delete;
@@ -115,7 +115,7 @@ class PyTpuClient {
   const std::map<int, std::shared_ptr<PjRtDevice>>& id_to_device() const {
     return id_to_device_;
   }
-  int host_id() const { return host_id_; }
+  int task_id() const { return task_id_; }
   const std::string& platform_name() const { return platform_name_; }

   StatusOr<Shape> ChooseCompactLayoutForShape(Shape subshape) {
@@ -140,7 +140,7 @@ class PyTpuClient {
   std::map<int, std::shared_ptr<PjRtDevice>> id_to_device_;
   // Local devices indexed by local device ordinal.
   std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
-  int host_id_;
+  int task_id_;

   // A thread pool for scheduling core executions in parallel.
   std::unique_ptr<tensorflow::thread::ThreadPool> pool_;
@@ -32,7 +32,8 @@ PYBIND11_MODULE(tpu_client_extension, m) {
       .def("local_device_count", &PyTpuClient::local_device_count)
       .def("devices", &PyTpuClient::devices)
       .def("local_devices", &PyTpuClient::local_devices)
-      .def("host_id", &PyTpuClient::host_id)
+      .def("host_id", &PyTpuClient::task_id)
+      .def("task_id", &PyTpuClient::task_id)
       .def("get_default_device_assignment",
            [](PyTpuClient* client, int num_replicas, int num_partitions)
                -> StatusOr<
@@ -213,7 +214,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
       .def("__repr__", [](const TpuDevice& device) {
         return absl::StrFormat(
             "TpuDevice(id=%i, host_id=%i, coords=(%i,%i,%i), core_on_chip=%i)",
-            device.id(), device.host_id(), device.coords()[0],
+            device.id(), device.task_id(), device.coords()[0],
             device.coords()[1], device.coords()[2], device.core_on_chip());
       });
 }  // NOLINT(readability/fn_size)
@@ -55,6 +55,8 @@ limitations under the License.
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/python/lib/core/bfloat16.h"

+// TODO(phawkins): remove host_id properties after JAX is update to avoid them.
+
 namespace xla {
 namespace {

@@ -124,9 +126,12 @@ PYBIND11_MODULE(xla_extension, m) {
           "id", &PjRtDevice::id,
           "Integer ID of this device.\n\nUnique across all available devices "
           "of this type, including remote devices on multi-host platforms.")
-      .def_property_readonly("host_id", &PjRtDevice::host_id,
-                             "Integer ID of this device's host.\n\n"
-                             "This is always 0 except on multi-host platforms.")
+      .def_property_readonly("host_id", &PjRtDevice::task_id,
+                             "Integer ID of this device's task.\n\n"
+                             "This is always 0 except on multi-task platforms.")
+      .def_property_readonly("task_id", &PjRtDevice::task_id,
+                             "Integer ID of this device's task.\n\n"
+                             "This is always 0 except on multi-task platforms.")
       .def_property_readonly("platform",
                              [](const PjRtDevice& device) {
                                return device.client()->platform_name();
@@ -174,7 +179,8 @@ PYBIND11_MODULE(xla_extension, m) {

   py::class_<PjRtTpuDevice, PjRtDevice, ClientAndPtr<PjRtTpuDevice>>(
       m, "TpuDevice")
-      .def_property_readonly("host_id", &PjRtTpuDevice::host_id)
+      .def_property_readonly("host_id", &PjRtTpuDevice::task_id)
+      .def_property_readonly("task_id", &PjRtTpuDevice::task_id)
       .def_property_readonly(
           "coords",
           [](const PjRtTpuDevice& device) -> pybind11::tuple {
@@ -187,7 +193,7 @@ PYBIND11_MODULE(xla_extension, m) {
      .def("__repr__", [](const PjRtTpuDevice& device) {
        return absl::StrFormat(
            "TpuDevice(id=%i, host=%i, coords=(%s), core_on_chip=%i)",
-            device.id(), device.host_id(), absl::StrJoin(device.coords(), ","),
+            device.id(), device.task_id(), absl::StrJoin(device.coords(), ","),
            device.core_on_chip());
      });

@@ -217,7 +223,8 @@ PYBIND11_MODULE(xla_extension, m) {
       .def("devices", &PyClient::Devices)
       .def("local_devices", &PyClient::LocalDevices)
       .def("live_buffers", &PyClient::LiveBuffers)
-      .def("host_id", &PyClient::host_id)
+      .def("host_id", &PyClient::task_id)
+      .def("task_id", &PyClient::task_id)
       .def("get_default_device_assignment",
            &PyClient::GetDefaultDeviceAssignment)
       // TODO(skye): delete after all callers can handle 2D output