From 55f1184704b7eaa30238b5f2f3173ad11869daa4 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Tue, 26 Jan 2021 17:14:02 -0800
Subject: [PATCH] [PJRT] Rename PjRtDevice::host_id and PjRtClient::host_id to ...::task_id instead.

Fixes a TODO in PJRT. In general, we may have multiple tasks per physical
host, so "task_id" is more accurate.

PiperOrigin-RevId: 353984838
Change-Id: I59340c15181374a3db488b238ed81836a94003d0
---
 tensorflow/compiler/xla/pjrt/cpu_device.cc     |  2 +-
 .../compiler/xla/pjrt/interpreter_device.cc    |  2 +-
 tensorflow/compiler/xla/pjrt/pjrt_client.h     | 13 +++++------
 .../xla/pjrt/pjrt_stream_executor_client.cc    |  8 +++----
 .../xla/pjrt/pjrt_stream_executor_client.h     | 14 ++++++------
 tensorflow/compiler/xla/pjrt/tpu_client.cc     | 14 ++++++------
 tensorflow/compiler/xla/pjrt/tpu_client.h      |  6 ++---
 .../xla/python/outfeed_receiver_test.cc        |  2 +-
 tensorflow/compiler/xla/python/py_client.h     |  2 +-
 .../python/tpu_driver/client/tpu_client.cc     | 22 +++++++++----------
 .../xla/python/tpu_driver/client/tpu_client.h  | 12 +++++-----
 .../tpu_driver/client/tpu_client_extension.cc  |  5 +++--
 tensorflow/compiler/xla/python/xla.cc          | 19 +++++++++++-----
 13 files changed, 64 insertions(+), 57 deletions(-)

diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc
index 72da2d2b0dd..930b10179be 100644
--- a/tensorflow/compiler/xla/pjrt/cpu_device.cc
+++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc
@@ -58,7 +58,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
   }
 
   return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
-      kCpuName, client, std::move(devices), /*host_id=*/0,
+      kCpuName, client, std::move(devices), /*task_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr));
diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.cc b/tensorflow/compiler/xla/pjrt/interpreter_device.cc
index 818740ca105..599ea9296c7 100644
--- a/tensorflow/compiler/xla/pjrt/interpreter_device.cc
+++ b/tensorflow/compiler/xla/pjrt/interpreter_device.cc
@@ -52,7 +52,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
   devices.push_back(std::move(device));
 
   return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
-      "interpreter", client, std::move(devices), /*host_id=*/0,
+      "interpreter", client, std::move(devices), /*task_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr));
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h
index d0afcf356ff..34d0ac25ae3 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_client.h
+++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h
@@ -72,11 +72,11 @@ class PjRtDevice {
   // hosts' devices. This is the ID that should be used in a DeviceAssignment.
   virtual int id() const = 0;
 
-  // The task ID of this device according to TpuTopology. This is not the same
-  // as PjRtClient::host_id() in a multi-task setting, where each client can see
-  // devices from all tasks, but only a subset of them are addressable and have
-  // the same task_id as the client.
-  virtual int host_id() const = 0;
+  // The task ID of this device according to TpuTopology. This is not always
+  // identical to PjRtClient::task_id() in a multi-task setting, where each
+  // client can see devices from all tasks, but only a subset of them are
+  // addressable and have the same task_id as the client.
+  virtual int task_id() const = 0;
 
   // Opaque hardware ID, e.g., the CUDA device number, useful for identifying
   // which GPU when interacting with non-JAX code. In general, not guaranteed to
@@ -140,9 +140,8 @@ class PjRtClient {
  public:
   virtual ~PjRtClient() = default;
 
-  // TODO(zhangqiaorjc): Rename to task_id.
   // Return the task id of this client. In single-task setting, always 0.
-  virtual int host_id() const = 0;
+  virtual int task_id() const = 0;
 
   // Return the number of devices in the entire computation. In multi-headed
   // client setting, some are addressable by this client, some are not. In a
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
index cfe2303aade..d7f1eb42c63 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
+++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
@@ -202,7 +202,7 @@ static int DefaultThreadPoolSize() {
 
 PjRtStreamExecutorClient::PjRtStreamExecutorClient(
     std::string platform_name, LocalClient* client,
-    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int task_id,
     std::unique_ptr<se::DeviceMemoryAllocator> allocator,
     std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
     bool should_stage_host_to_device_transfers,
@@ -212,7 +212,7 @@ PjRtStreamExecutorClient::PjRtStreamExecutorClient(
       client_(client),
       host_memory_allocator_(std::move(host_memory_allocator)),
       owned_devices_(std::move(devices)),
-      host_id_(host_id),
+      task_id_(task_id),
       owned_allocator_(std::move(allocator)),
       should_stage_host_to_device_transfers_(
           should_stage_host_to_device_transfers),
@@ -1815,7 +1815,7 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
     (*device_assignment)(0, 0) = device->id();
   }
 
-  CHECK_EQ(device->host_id(), client_->host_id());
+  CHECK_EQ(device->task_id(), client_->task_id());
   int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
                            ->local_device_state()
                            ->device_ordinal();
@@ -2093,7 +2093,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
     for (int partition = 0; partition < num_partitions; ++partition) {
       int device_id = (*device_assignment)(replica, partition);
       TF_ASSIGN_OR_RETURN(PjRtDevice * device, LookupDevice(device_id));
-      if (device->host_id() != host_id()) {
+      if (device->task_id() != task_id()) {
         VLOG(3) << "Non-local device: " << device_id;
         continue;
       }
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h
index d87f77f402a..d01112a0296 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h
+++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h
@@ -58,12 +58,12 @@ class PjRtStreamExecutorDevice : public PjRtDevice {
  public:
   explicit PjRtStreamExecutorDevice(
       int id, std::unique_ptr<LocalDeviceState> local_device_state,
-      std::string device_kind, int host_id = 0)
+      std::string device_kind, int task_id = 0)
       : id_(id),
         device_ordinal_(
             local_device_state ? local_device_state->device_ordinal() : -1),
         local_device_state_(std::move(local_device_state)),
-        host_id_(host_id),
+        task_id_(task_id),
         device_kind_(std::move(device_kind)) {}
   ~PjRtStreamExecutorDevice() override {}
 
@@ -73,7 +73,7 @@ class PjRtStreamExecutorDevice : public PjRtDevice {
     client_ = client;
   }
 
-  int host_id() const override { return host_id_; }
+  int task_id() const override { return task_id_; }
 
   // Return `platform_id` from client.
   PjRtPlatformId platform_id() const;
@@ -113,7 +113,7 @@ class PjRtStreamExecutorDevice : public PjRtDevice {
   const int id_;
   const int device_ordinal_;  // -1 means not local.
   const std::unique_ptr<LocalDeviceState> local_device_state_;
-  const int host_id_;
+  const int task_id_;
   const std::string device_kind_;
   PjRtClient* client_ = nullptr;
 };
@@ -124,13 +124,13 @@ class PjRtStreamExecutorClient : public PjRtClient {
   explicit PjRtStreamExecutorClient(
       std::string platform_name, LocalClient* client,
       std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
-      int host_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
+      int task_id, std::unique_ptr<se::DeviceMemoryAllocator> allocator,
       std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
       bool should_stage_host_to_device_transfers,
       std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options);
   ~PjRtStreamExecutorClient() override = default;
 
-  int host_id() const override { return host_id_; }
+  int task_id() const override { return task_id_; }
 
   int device_count() const override { return devices_.size(); }
   int addressable_device_count() const override {
@@ -255,7 +255,7 @@ class PjRtStreamExecutorClient : public PjRtClient {
   std::map<int, PjRtDevice*> id_to_device_;
   // Local devices indexed by local device ordinal.
   std::vector<PjRtDevice*> addressable_devices_;
-  int host_id_;
+  int task_id_;
 
   se::DeviceMemoryAllocator* allocator_;
   std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.cc b/tensorflow/compiler/xla/pjrt/tpu_client.cc
index 068bd73a0d0..617b97e7359 100644
--- a/tensorflow/compiler/xla/pjrt/tpu_client.cc
+++ b/tensorflow/compiler/xla/pjrt/tpu_client.cc
@@ -77,7 +77,7 @@ class PjRtTpuClient : public PjRtStreamExecutorClient {
  public:
   PjRtTpuClient(LocalClient* client,
                 std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
-                int host_id);
+                int task_id);
 
   StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
       int num_replicas, int num_partitions) const override;
@@ -90,8 +90,8 @@ class PjRtTpuClient : public PjRtStreamExecutorClient {
 
 PjRtTpuClient::PjRtTpuClient(
     LocalClient* client,
-    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id)
-    : PjRtStreamExecutorClient(kTpuName, client, std::move(devices), host_id,
+    std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int task_id)
+    : PjRtStreamExecutorClient(kTpuName, client, std::move(devices), task_id,
                                /*allocator=*/nullptr,
                                /*host_memory_allocator=*/nullptr,
                                /*should_stage_host_to_device_transfers=*/false,
@@ -155,7 +155,7 @@ StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
     auto it = core_id_to_device_ordinal.find(core.Id());
     int device_ordinal =
         (it != core_id_to_device_ordinal.end()) ? it->second : -1;
-    int host_id = topology.IdForHost(core.host_coordinates());
+    int task_id = topology.IdForHost(core.host_coordinates());
     const tf_tpu::TpuDimensionsExternal coords = core.chip_coordinates();
     std::array<int, 3> coords_array = {coords.x, coords.y, coords.z};
     std::unique_ptr<LocalDeviceState> local_device_state;
@@ -163,7 +163,7 @@ StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
       local_device_state = std::move(local_device_states[device_ordinal]);
     }
     auto device = absl::make_unique<PjRtTpuDevice>(
-        core, std::move(local_device_state), host_id, coords_array,
+        core, std::move(local_device_state), task_id, coords_array,
        std::string(tf_tpu::TpuVersionEnumToString(topology.version())));
     devices.push_back(std::move(device));
   }
@@ -214,10 +214,10 @@ StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient(
 
   TF_ASSIGN_OR_RETURN(auto devices,
                       GetTpuDevices(client, std::move(local_device_states)));
-  int host_id = platform->GetTpuHostLocation().Id();
+  int task_id = platform->GetTpuHostLocation().Id();
 
   return std::shared_ptr<PjRtClient>(
-      absl::make_unique<PjRtTpuClient>(client, std::move(devices), host_id));
+      absl::make_unique<PjRtTpuClient>(client, std::move(devices), task_id));
 }
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.h b/tensorflow/compiler/xla/pjrt/tpu_client.h
index d9847bb85de..ead33b0fe16 100644
--- a/tensorflow/compiler/xla/pjrt/tpu_client.h
+++ b/tensorflow/compiler/xla/pjrt/tpu_client.h
@@ -30,10 +30,10 @@ class PjRtTpuDevice : public PjRtStreamExecutorDevice {
 public:
   PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core,
                 std::unique_ptr<LocalDeviceState> local_device_state,
-                int host_id, const std::array<int, 3>& coords,
+                int task_id, const std::array<int, 3>& coords,
                 std::string device_kind)
       : PjRtStreamExecutorDevice(core.Id(), std::move(local_device_state),
-                                 std::move(device_kind), host_id),
+                                 std::move(device_kind), task_id),
         core_(core),
         coords_(coords) {}
 
@@ -42,7 +42,7 @@ class PjRtTpuDevice : public PjRtStreamExecutorDevice {
   const tensorflow::tpu::TpuCoreLocationExternal core() const { return core_; }
 
   std::string DebugString() const override {
-    return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), host_id(),
+    return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), task_id(),
                            coords_[0], coords_[1], coords_[2], core_.index());
   }
 
diff --git a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc
index 8e60553d290..b47be147236 100644
--- a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc
+++ b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc
@@ -98,7 +98,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClientWithNonLocalDevice() {
   devices.push_back(absl::make_unique<CpuDevice>(1, nullptr));
 
   return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
-      kCpuName, client, std::move(devices), /*host_id=*/0,
+      kCpuName, client, std::move(devices), /*task_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
       /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr));
diff --git a/tensorflow/compiler/xla/python/py_client.h b/tensorflow/compiler/xla/python/py_client.h
index 61a3bcb7619..d959756ce69 100644
--- a/tensorflow/compiler/xla/python/py_client.h
+++ b/tensorflow/compiler/xla/python/py_client.h
@@ -101,7 +101,7 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
     return pjrt_client_->addressable_device_count();
   }
   int device_count() const { return pjrt_client_->device_count(); }
-  int host_id() const { return pjrt_client_->host_id(); }
+  int task_id() const { return pjrt_client_->task_id(); }
 
   std::vector<ClientAndPtr<PjRtDevice>> Devices();
   std::vector<ClientAndPtr<PjRtDevice>> LocalDevices();
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc
index 4656e9dfa7a..0f10463fdd0 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc
@@ -35,15 +35,15 @@ limitations under the License.
 
 namespace xla {
 
-TpuDevice::TpuDevice(int id, int host_id, const std::array<int, 3>& coords,
+TpuDevice::TpuDevice(int id, int task_id, const std::array<int, 3>& coords,
                      int core_on_chip)
     : id_(id),
-      host_id_(host_id),
+      task_id_(task_id),
       coords_(coords),
       core_on_chip_(core_on_chip) {}
 
 std::string TpuDevice::DebugString() const {
-  return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), host_id(),
+  return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), task_id(),
                          coords_[0], coords_[1], coords_[2], core_on_chip_);
 }
 
@@ -53,10 +53,10 @@ TpuDevice::GetTpuDevices(const tpu_driver::SystemInfo& system_info) {
   for (const auto& chip : system_info.tpu_chip()) {
     auto& coord = chip.chip_coord();
     std::array<int, 3> coords_array = {coord.x(), coord.y(), coord.z()};
-    int host_id = chip.host_id();
+    int task_id = chip.host_id();
     for (const auto& core : chip.core()) {
       auto device = std::make_shared<TpuDevice>(
-          core.id(), host_id, coords_array, core.core_on_chip_index());
+          core.id(), task_id, coords_array, core.core_on_chip_index());
       devices.push_back(device);
     }
   }
@@ -89,17 +89,17 @@ StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
 PyTpuClient::PyTpuClient(std::string platform_name,
                          std::unique_ptr<tpu_driver::TpuDriver> driver,
                          std::vector<std::shared_ptr<PjRtDevice>> devices,
-                         int host_id)
+                         int task_id)
     : platform_name_(std::move(platform_name)),
       driver_(std::move(driver)),
       devices_(std::move(devices)),
-      host_id_(host_id),
+      task_id_(task_id),
   for (const std::shared_ptr<PjRtDevice>& device : devices_) {
     CHECK(id_to_device_.insert({device->id(), device}).second)
         << "Duplicate device id: " << device->id();
 
-    if (device->host_id() == host_id_) {
-      LOG(INFO) << "Detected local device, host id: " << host_id_
+    if (device->task_id() == task_id_) {
+      LOG(INFO) << "Detected local device, host id: " << task_id_
device id: " << device->id(); local_devices_.push_back(device); } else { @@ -522,7 +522,7 @@ PyTpuExecutable::PyTpuExecutable( for (int partition = 0; partition < num_partitions; ++partition) { int device_id = device_assignment_(replica, partition); std::shared_ptr device = LookupDevice(*client_, device_id); - if (device->host_id() != client_->host_id()) { + if (device->task_id() != client_->task_id()) { VLOG(3) << "Non-local device: " << device_id; continue; } @@ -547,7 +547,7 @@ PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper( int partition, const RunId& run_id) { const int device_id = device_assignment_(replica, partition); std::shared_ptr device = LookupDevice(*client_, device_id); - CHECK_EQ(device->host_id(), client_->host_id()); + CHECK_EQ(device->task_id(), client_->task_id()); tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute"); VLOG(3) << "Replica " << replica << ", partition " << partition << " mapped to device id for execution: " << device_id; diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index 73ac6dc40f4..decd4f852a6 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -42,7 +42,7 @@ constexpr char kTpuPlatform[] = "tpu"; class TpuDevice : public PjRtDevice { public: - TpuDevice(int id, int host_id, const std::array& coords, + TpuDevice(int id, int task_id, const std::array& coords, int core_on_chip); const std::array& coords() const { return coords_; } @@ -59,7 +59,7 @@ class TpuDevice : public PjRtDevice { int id() const override { return id_; } - int host_id() const override { return host_id_; } + int task_id() const override { return task_id_; } int local_hardware_id() const override { return -1; } @@ -75,7 +75,7 @@ class TpuDevice : public PjRtDevice { private: const int id_; - const int host_id_; + const int task_id_; const std::array coords_; const std::string device_kind_ = "Cloud TPU"; // Index of the core of the same chip. @@ -92,7 +92,7 @@ class PyTpuClient { explicit PyTpuClient(std::string platform_name, std::unique_ptr driver, std::vector> devices, - int host_id); + int task_id); virtual ~PyTpuClient() = default; PyTpuClient(const PyTpuClient&) = delete; @@ -115,7 +115,7 @@ class PyTpuClient { const std::map>& id_to_device() const { return id_to_device_; } - int host_id() const { return host_id_; } + int task_id() const { return task_id_; } const std::string& platform_name() const { return platform_name_; } StatusOr ChooseCompactLayoutForShape(Shape subshape) { @@ -140,7 +140,7 @@ class PyTpuClient { std::map> id_to_device_; // Local devices indexed by local device ordinal. std::vector> local_devices_; - int host_id_; + int task_id_; // A thread pool for scheduling core executions in parallel. 
   std::unique_ptr<tensorflow::thread::ThreadPool> pool_;
 
diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc
index 92ee1825e56..24559cc117e 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc
@@ -32,7 +32,8 @@ PYBIND11_MODULE(tpu_client_extension, m) {
       .def("local_device_count", &PyTpuClient::local_device_count)
       .def("devices", &PyTpuClient::devices)
      .def("local_devices", &PyTpuClient::local_devices)
-      .def("host_id", &PyTpuClient::host_id)
+      .def("host_id", &PyTpuClient::task_id)
+      .def("task_id", &PyTpuClient::task_id)
      .def("get_default_device_assignment",
           [](PyTpuClient* client, int num_replicas, int num_partitions)
               -> StatusOr<
@@ -213,7 +214,7 @@ PYBIND11_MODULE(tpu_client_extension, m) {
       .def("__repr__", [](const TpuDevice& device) {
         return absl::StrFormat(
             "TpuDevice(id=%i, host_id=%i, coords=(%i,%i,%i), core_on_chip=%i)",
-            device.id(), device.host_id(), device.coords()[0],
+            device.id(), device.task_id(), device.coords()[0],
             device.coords()[1], device.coords()[2], device.core_on_chip());
       });
 }  // NOLINT(readability/fn_size)
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index e409e68e94e..530132f3e25 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -55,6 +55,8 @@ limitations under the License.
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/python/lib/core/bfloat16.h"
 
+// TODO(phawkins): remove host_id properties after JAX is updated to avoid them.
+
 namespace xla {
 
 namespace {
@@ -124,9 +126,12 @@ PYBIND11_MODULE(xla_extension, m) {
           "id", &PjRtDevice::id,
           "Integer ID of this device.\n\nUnique across all available devices "
          "of this type, including remote devices on multi-host platforms.")
-      .def_property_readonly("host_id", &PjRtDevice::host_id,
-                             "Integer ID of this device's host.\n\n"
-                             "This is always 0 except on multi-host platforms.")
+      .def_property_readonly("host_id", &PjRtDevice::task_id,
+                             "Integer ID of this device's task.\n\n"
+                             "This is always 0 except on multi-task platforms.")
+      .def_property_readonly("task_id", &PjRtDevice::task_id,
+                             "Integer ID of this device's task.\n\n"
+                             "This is always 0 except on multi-task platforms.")
      .def_property_readonly("platform",
                             [](const PjRtDevice& device) {
                               return device.client()->platform_name();
@@ -174,7 +179,8 @@ PYBIND11_MODULE(xla_extension, m) {
 
   py::class_<PjRtTpuDevice, PjRtDevice, ClientAndPtr<PjRtTpuDevice>>(
       m, "TpuDevice")
-      .def_property_readonly("host_id", &PjRtTpuDevice::host_id)
+      .def_property_readonly("host_id", &PjRtTpuDevice::task_id)
+      .def_property_readonly("task_id", &PjRtTpuDevice::task_id)
      .def_property_readonly(
          "coords",
          [](const PjRtTpuDevice& device) -> pybind11::tuple {
@@ -187,7 +193,7 @@ PYBIND11_MODULE(xla_extension, m) {
      .def("__repr__", [](const PjRtTpuDevice& device) {
        return absl::StrFormat(
            "TpuDevice(id=%i, host=%i, coords=(%s), core_on_chip=%i)",
-            device.id(), device.host_id(), absl::StrJoin(device.coords(), ","),
+            device.id(), device.task_id(), absl::StrJoin(device.coords(), ","),
            device.core_on_chip());
      });
 
@@ -217,7 +223,8 @@ PYBIND11_MODULE(xla_extension, m) {
      .def("devices", &PyClient::Devices)
      .def("local_devices", &PyClient::LocalDevices)
      .def("live_buffers", &PyClient::LiveBuffers)
-      .def("host_id", &PyClient::host_id)
+      .def("host_id", &PyClient::task_id)
+      .def("task_id", &PyClient::task_id)
      .def("get_default_device_assignment",
           &PyClient::GetDefaultDeviceAssignment)
      // TODO(skye): delete after all callers can handle 2D output
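
Reviewer note (illustrative, not part of the patch): below is a minimal sketch of
how C++ calling code migrates after this rename, assuming the PjRtClient and
PjRtDevice interfaces as modified above (in particular the devices() accessor of
this revision). The helper name AddressableDevicesForTask is hypothetical and is
used only for illustration.

// Sketch only; assumes tensorflow/compiler/xla/pjrt/pjrt_client.h as modified
// by this patch. The helper below is hypothetical, not part of the commit.
#include <vector>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

namespace xla {

// Returns the devices that belong to the same task as `client`, i.e. the
// devices this client can address directly. Before this patch the same check
// was spelled device->host_id() == client->host_id().
std::vector<PjRtDevice*> AddressableDevicesForTask(PjRtClient* client) {
  std::vector<PjRtDevice*> result;
  for (PjRtDevice* device : client->devices()) {
    if (device->task_id() == client->task_id()) {
      result.push_back(device);
    }
  }
  return result;
}

}  // namespace xla

On the Python side, the xla.cc and tpu_client_extension.cc changes above keep
host_id registered as an alias of task_id, so existing JAX callers continue to
work until the TODO in xla.cc removes the deprecated properties.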