Supports lookup devices by fullname either in the canonical form or the
legacy form. This makes DeviceSet behaves the same as DeviceMgr's FindDevice method. PiperOrigin-RevId: 163300346
This commit is contained in:
parent
631a364cd1
commit
dd1f0cdddb
tensorflow/core
@ -30,22 +30,9 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
|
||||
devices_.push_back(d);
|
||||
|
||||
// Register under the (1) full name, (2) canonical name, and (3) local name.
|
||||
string full_name = d->name();
|
||||
device_map_[CopyToBackingStore(full_name)] = d;
|
||||
|
||||
// TODO(b/62909072): Upgrade device_map_ to a better data structure.
|
||||
DeviceNameUtils::ParsedName parsed_name = d->parsed_name();
|
||||
if (parsed_name.has_job && parsed_name.has_replica &&
|
||||
parsed_name.has_task && parsed_name.has_type && parsed_name.has_id) {
|
||||
string canonical_name = DeviceNameUtils::FullName(
|
||||
parsed_name.job, parsed_name.replica, parsed_name.task,
|
||||
parsed_name.type, parsed_name.id);
|
||||
device_map_[CopyToBackingStore(canonical_name)] = d;
|
||||
|
||||
string legacy_name = DeviceNameUtils::LegacyName(
|
||||
parsed_name.job, parsed_name.replica, parsed_name.task,
|
||||
parsed_name.type, parsed_name.id);
|
||||
device_map_[CopyToBackingStore(legacy_name)] = d;
|
||||
for (const string& name :
|
||||
DeviceNameUtils::GetNamesForDeviceMappings(d->parsed_name())) {
|
||||
device_map_[CopyToBackingStore(name)] = d;
|
||||
}
|
||||
string lname = DeviceNameUtils::LocalName(d->name());
|
||||
device_map_[CopyToBackingStore(lname)] = d;
|
||||
|
@ -32,7 +32,10 @@ DeviceSet::~DeviceSet() {}
|
||||
|
||||
void DeviceSet::AddDevice(Device* device) {
|
||||
devices_.push_back(device);
|
||||
device_by_name_.insert({device->name(), device});
|
||||
for (const string& name :
|
||||
DeviceNameUtils::GetNamesForDeviceMappings(device->parsed_name())) {
|
||||
device_by_name_.insert({name, device});
|
||||
}
|
||||
}
|
||||
|
||||
void DeviceSet::FindMatchingDevices(const DeviceNameUtils::ParsedName& spec,
|
||||
|
@ -387,4 +387,16 @@ bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task,
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<string> DeviceNameUtils::GetNamesForDeviceMappings(
|
||||
const ParsedName& pn) {
|
||||
if (pn.has_job && pn.has_replica && pn.has_task && pn.has_type && pn.has_id) {
|
||||
return {
|
||||
DeviceNameUtils::FullName(pn.job, pn.replica, pn.task, pn.type, pn.id),
|
||||
DeviceNameUtils::LegacyName(pn.job, pn.replica, pn.task, pn.type,
|
||||
pn.id)};
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -152,6 +152,11 @@ class DeviceNameUtils {
|
||||
static bool SplitDeviceName(StringPiece name, string* task, string* device);
|
||||
|
||||
static string ParsedNameToString(const ParsedName& pn);
|
||||
|
||||
// Returns canonical and legacy full names for the given parsed
|
||||
// device name 'pn'. The returned string names are often useful to
|
||||
// lookup devices from a mapping.
|
||||
static std::vector<string> GetNamesForDeviceMappings(const ParsedName& pn);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -443,6 +443,16 @@ TEST(DeviceNameUtilsTest, MergeDevNamesAllowSoftPlacement) {
|
||||
MergeDevNamesHelperAllowSoftPlacement("/gpu:1", "/gpu:2", "/gpu:*");
|
||||
}
|
||||
|
||||
TEST(DeviceNameUtilsTest, GetNamesForDeviceMappings) {
|
||||
DeviceNameUtils::ParsedName p = Name("/job:foo/replica:10/task:0/gpu:1");
|
||||
EXPECT_EQ(str_util::Join(DeviceNameUtils::GetNamesForDeviceMappings(p), ","),
|
||||
"/job:foo/replica:10/task:0/device:GPU:1,"
|
||||
"/job:foo/replica:10/task:0/gpu:1");
|
||||
p.has_task = false;
|
||||
EXPECT_EQ(str_util::Join(DeviceNameUtils::GetNamesForDeviceMappings(p), ","),
|
||||
"");
|
||||
}
|
||||
|
||||
static void BM_ParseFullName(int iters) {
|
||||
DeviceNameUtils::ParsedName p;
|
||||
while (iters--) {
|
||||
|
Loading…
Reference in New Issue
Block a user