Add DeviceAssignment::LogicalIdsForDevice
method.
PiperOrigin-RevId: 344493673 Change-Id: I04283f13a442f84764aae57f91d821a581d25fff
This commit is contained in:
parent
d2c44948da
commit
1900d79a7e
@ -40,26 +40,33 @@ using absl::StrCat;
|
||||
|
||||
namespace xla {
|
||||
|
||||
StatusOr<int> DeviceAssignment::ReplicaIdForDevice(
|
||||
StatusOr<std::pair<int, int>> DeviceAssignment::LogicalIdsForDevice(
|
||||
GlobalDeviceId device_id) const {
|
||||
absl::optional<int> replica_id;
|
||||
for (int64 r = 0; r < replica_count(); ++r) {
|
||||
for (int64 c = 0; c < computation_count(); ++c) {
|
||||
absl::optional<std::pair<int, int>> logical_ids;
|
||||
for (int r = 0; r < replica_count(); ++r) {
|
||||
for (int c = 0; c < computation_count(); ++c) {
|
||||
if ((*this)(r, c) == device_id.value()) {
|
||||
if (replica_id.has_value()) {
|
||||
if (logical_ids.has_value()) {
|
||||
return InternalError(
|
||||
"Device %d appears twice in DeviceAssignment: %s",
|
||||
device_id.value(), ToString());
|
||||
}
|
||||
replica_id = r;
|
||||
logical_ids.emplace(r, c);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!replica_id.has_value()) {
|
||||
if (logical_ids.has_value()) {
|
||||
return *logical_ids;
|
||||
} else {
|
||||
return InternalError("Device %d doesn't appear in DeviceAssignment: %s",
|
||||
device_id.value(), ToString());
|
||||
}
|
||||
return *replica_id;
|
||||
}
|
||||
|
||||
StatusOr<int> DeviceAssignment::ReplicaIdForDevice(
|
||||
GlobalDeviceId device_id) const {
|
||||
TF_ASSIGN_OR_RETURN(auto logical_ids, LogicalIdsForDevice(device_id));
|
||||
return logical_ids.first;
|
||||
}
|
||||
|
||||
Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/array2d.h"
|
||||
@ -48,6 +49,9 @@ class DeviceAssignment : public Array2D<int> {
|
||||
int replica_count() const { return height(); }
|
||||
int computation_count() const { return width(); }
|
||||
|
||||
// Finds the (replica ID, computation ID) pair for the given device.
|
||||
StatusOr<std::pair<int, int>> LogicalIdsForDevice(
|
||||
GlobalDeviceId device_id) const;
|
||||
// Finds the replica ID for the given device.
|
||||
StatusOr<int> ReplicaIdForDevice(GlobalDeviceId device_id) const;
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user