[XLA][NFC] Add a struct to describe the LogicalID of a device.

- Add a DeviceAssignment::LogicalID struct with replica_id and computation_id members
  to describe the logical-id of a device (instead of using a pair of ints).
- Change LogicalIdsForDevice to LogicalIdForDevice since the function now returns a
  single logical id.
- This now allows the result variable to be declared directly inside the
  TF_ASSIGN_OR_RETURN() invocation instead of in a separate statement beforehand.

PiperOrigin-RevId: 358956636
Change-Id: I7430588bc88c8b9d066c49651ee418fc33e08d28
This commit is contained in:
Rahul Joshi 2021-02-22 19:36:55 -08:00 committed by TensorFlower Gardener
parent b64965ce3d
commit ac3abfbaa1
5 changed files with 29 additions and 23 deletions

View File

@ -120,11 +120,10 @@ StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
int replica_count = device_assignment.replica_count();
int partition_count = device_assignment.computation_count();
std::pair<int, int> logical_ids;
TF_ASSIGN_OR_RETURN(logical_ids,
device_assignment.LogicalIdsForDevice(device_id));
int current_replica_id = logical_ids.first;
int current_partition_id = logical_ids.second;
TF_ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID logical_id,
device_assignment.LogicalIdForDevice(device_id));
int current_replica_id = logical_id.replica_id;
int current_partition_id = logical_id.computation_id;
std::vector<GlobalDeviceId> participants;
switch (group_mode) {

View File

@ -42,23 +42,23 @@ using absl::StrCat;
namespace xla {
StatusOr<std::pair<int, int>> DeviceAssignment::LogicalIdsForDevice(
StatusOr<DeviceAssignment::LogicalID> DeviceAssignment::LogicalIdForDevice(
GlobalDeviceId device_id) const {
absl::optional<std::pair<int, int>> logical_ids;
absl::optional<DeviceAssignment::LogicalID> logical_id;
for (int r = 0; r < replica_count(); ++r) {
for (int c = 0; c < computation_count(); ++c) {
if ((*this)(r, c) == device_id.value()) {
if (logical_ids.has_value()) {
if (logical_id.has_value()) {
return InternalError(
"Device %d appears twice in DeviceAssignment: %s",
device_id.value(), ToString());
}
logical_ids.emplace(r, c);
logical_id.emplace(DeviceAssignment::LogicalID{r, c});
}
}
}
if (logical_ids.has_value()) {
return *logical_ids;
if (logical_id.has_value()) {
return *logical_id;
} else {
return InternalError("Device %d doesn't appear in DeviceAssignment: %s",
device_id.value(), ToString());
@ -67,8 +67,9 @@ StatusOr<std::pair<int, int>> DeviceAssignment::LogicalIdsForDevice(
StatusOr<int> DeviceAssignment::ReplicaIdForDevice(
GlobalDeviceId device_id) const {
TF_ASSIGN_OR_RETURN(auto logical_ids, LogicalIdsForDevice(device_id));
return logical_ids.first;
TF_ASSIGN_OR_RETURN(const LogicalID logical_id,
LogicalIdForDevice(device_id));
return logical_id.replica_id;
}
Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {

View File

@ -50,9 +50,14 @@ class DeviceAssignment : public Array2D<int> {
int replica_count() const { return height(); }
int computation_count() const { return width(); }
// The logical ID of a device is its (replica ID, computation ID) pair.
// It is the value type returned by LogicalIdForDevice().
struct LogicalID {
// Row index of the device in this assignment; ranges over
// [0, replica_count()).
int replica_id;
// Column index of the device in this assignment; ranges over
// [0, computation_count()).
int computation_id;
};
// Finds the (replica ID, computation ID) pair for the given device.
StatusOr<std::pair<int, int>> LogicalIdsForDevice(
GlobalDeviceId device_id) const;
StatusOr<LogicalID> LogicalIdForDevice(GlobalDeviceId device_id) const;
// Finds the replica ID for the given device.
StatusOr<int> ReplicaIdForDevice(GlobalDeviceId device_id) const;

View File

@ -139,12 +139,12 @@ Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
params.run_id, std::move(participants), local_participants.size(),
config().collective_op_kind, config().op_id);
if (VLOG_IS_ON(2)) {
std::pair<int, int> logical_ids;
TF_ASSIGN_OR_RETURN(
logical_ids, params.device_assn->LogicalIdsForDevice(global_device_id));
DeviceAssignment::LogicalID logical_id,
params.device_assn->LogicalIdForDevice(global_device_id));
VLOG(2) << "global device " << global_device_id << ", (r"
<< logical_ids.first << ", p" << logical_ids.second << ") key "
<< rendezvous_key.ToString() << "\n";
<< logical_id.replica_id << ", p" << logical_id.computation_id
<< ") key " << rendezvous_key.ToString() << "\n";
}
int device_ordinal = params.stream->parent()->device_ordinal();

View File

@ -24,11 +24,12 @@ Status ReplicaOrPartitionIdThunk::ExecuteOnStream(const ExecuteParams& params) {
auto dest_addr = params.buffer_allocations->GetDeviceAddress(dest_);
TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
TF_ASSIGN_OR_RETURN(const GlobalDeviceId global_device_id,
params.GetGlobalDeviceId());
TF_ASSIGN_OR_RETURN(auto logical_ids, params.device_assn->LogicalIdsForDevice(
global_device_id));
int id = kind() == Kind::kReplicaId ? logical_ids.first : logical_ids.second;
TF_ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID logical_id,
params.device_assn->LogicalIdForDevice(global_device_id));
int id = kind() == Kind::kReplicaId ? logical_id.replica_id
: logical_id.computation_id;
params.stream->ThenMemset32(&dest_addr, id, /*size=*/4);
return Status::OK();
}