Add DeviceAssignment::LogicalIdsForDevice method.

PiperOrigin-RevId: 344493673
Change-Id: I04283f13a442f84764aae57f91d821a581d25fff
This commit is contained in:
Chris Jones 2020-11-27 00:04:10 -08:00 committed by TensorFlower Gardener
parent d2c44948da
commit 1900d79a7e
2 changed files with 19 additions and 8 deletions

View File

@ -40,26 +40,33 @@ using absl::StrCat;
namespace xla {
StatusOr<int> DeviceAssignment::ReplicaIdForDevice(
StatusOr<std::pair<int, int>> DeviceAssignment::LogicalIdsForDevice(
GlobalDeviceId device_id) const {
absl::optional<int> replica_id;
for (int64 r = 0; r < replica_count(); ++r) {
for (int64 c = 0; c < computation_count(); ++c) {
absl::optional<std::pair<int, int>> logical_ids;
for (int r = 0; r < replica_count(); ++r) {
for (int c = 0; c < computation_count(); ++c) {
if ((*this)(r, c) == device_id.value()) {
if (replica_id.has_value()) {
if (logical_ids.has_value()) {
return InternalError(
"Device %d appears twice in DeviceAssignment: %s",
device_id.value(), ToString());
}
replica_id = r;
logical_ids.emplace(r, c);
}
}
}
if (!replica_id.has_value()) {
if (logical_ids.has_value()) {
return *logical_ids;
} else {
return InternalError("Device %d doesn't appear in DeviceAssignment: %s",
device_id.value(), ToString());
}
return *replica_id;
}
StatusOr<int> DeviceAssignment::ReplicaIdForDevice(
GlobalDeviceId device_id) const {
TF_ASSIGN_OR_RETURN(auto logical_ids, LogicalIdsForDevice(device_id));
return logical_ids.first;
}
Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <map>
#include <memory>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/array2d.h"
@ -48,6 +49,9 @@ class DeviceAssignment : public Array2D<int> {
int replica_count() const { return height(); }
int computation_count() const { return width(); }
// Finds the (replica ID, computation ID) pair for the given device.
StatusOr<std::pair<int, int>> LogicalIdsForDevice(
GlobalDeviceId device_id) const;
// Finds the replica ID for the given device.
StatusOr<int> ReplicaIdForDevice(GlobalDeviceId device_id) const;