diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index 8b464a734ad..f9aaa1a676e 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -40,26 +40,33 @@ using absl::StrCat; namespace xla { -StatusOr DeviceAssignment::ReplicaIdForDevice( +StatusOr> DeviceAssignment::LogicalIdsForDevice( GlobalDeviceId device_id) const { - absl::optional replica_id; - for (int64 r = 0; r < replica_count(); ++r) { - for (int64 c = 0; c < computation_count(); ++c) { + absl::optional> logical_ids; + for (int r = 0; r < replica_count(); ++r) { + for (int c = 0; c < computation_count(); ++c) { if ((*this)(r, c) == device_id.value()) { - if (replica_id.has_value()) { + if (logical_ids.has_value()) { return InternalError( "Device %d appears twice in DeviceAssignment: %s", device_id.value(), ToString()); } - replica_id = r; + logical_ids.emplace(r, c); } } } - if (!replica_id.has_value()) { + if (logical_ids.has_value()) { + return *logical_ids; + } else { return InternalError("Device %d doesn't appear in DeviceAssignment: %s", device_id.value(), ToString()); } - return *replica_id; +} + +StatusOr DeviceAssignment::ReplicaIdForDevice( + GlobalDeviceId device_id) const { + TF_ASSIGN_OR_RETURN(auto logical_ids, LogicalIdsForDevice(device_id)); + return logical_ids.first; } Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index 72a5f935524..e6b591b7e23 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include "tensorflow/compiler/xla/array2d.h" @@ -48,6 +49,9 @@ class DeviceAssignment : public Array2D { int replica_count() const { return height(); } int computation_count() const { return width(); } + // Finds the (replica ID, computation ID) pair for the given device. + StatusOr> LogicalIdsForDevice( + GlobalDeviceId device_id) const; // Finds the replica ID for the given device. StatusOr ReplicaIdForDevice(GlobalDeviceId device_id) const;