Introduce helper ParseFullOrLocalName
PiperOrigin-RevId: 329785171 Change-Id: I67a9502c0c7581aea07efd14c68aa013257499b3
This commit is contained in:
parent
67dfaf3a87
commit
6bac741e7a
@ -174,6 +174,11 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DeviceNameUtils::ParseFullOrLocalName(StringPiece fullname,
|
||||
ParsedName* p) {
|
||||
return ParseFullName(fullname, p) || ParseLocalName(fullname, p);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
void CompleteName(const DeviceNameUtils::ParsedName& parsed_basename,
|
||||
|
@ -89,6 +89,11 @@ class DeviceNameUtils {
|
||||
bool has_id = false;
|
||||
int id = 0;
|
||||
};
|
||||
|
||||
// Parses the device name, first as a full name, then, if it fails, as a
|
||||
// global one. Returns `false` if both attempts fail.
|
||||
static bool ParseFullOrLocalName(StringPiece fullname, ParsedName* parsed);
|
||||
|
||||
// Parses "fullname" into "*parsed". Returns true iff succeeds.
|
||||
// Legacy names like "/cpu:0" that don't contain "device",
|
||||
// are parsed to mean their current counterparts "/device:CPU:0". More
|
||||
|
@ -105,6 +105,8 @@ TEST(DeviceNameUtilsTest, Basic) {
|
||||
DeviceNameUtils::ParsedName p;
|
||||
EXPECT_TRUE(DeviceNameUtils::ParseFullName(
|
||||
"/job:foo_bar/replica:1/task:2/device:GPU:3", &p));
|
||||
EXPECT_TRUE(DeviceNameUtils::ParseFullOrLocalName(
|
||||
"/job:foo_bar/replica:1/task:2/device:GPU:3", &p));
|
||||
EXPECT_TRUE(p.has_job);
|
||||
EXPECT_TRUE(p.has_replica);
|
||||
EXPECT_TRUE(p.has_task);
|
||||
@ -246,12 +248,14 @@ TEST(DeviceNameUtilsTest, Basic) {
|
||||
{
|
||||
DeviceNameUtils::ParsedName p;
|
||||
EXPECT_TRUE(DeviceNameUtils::ParseLocalName("CPU:10", &p));
|
||||
EXPECT_TRUE(DeviceNameUtils::ParseFullOrLocalName("CPU:10", &p));
|
||||
EXPECT_EQ(p.type, "CPU");
|
||||
EXPECT_EQ(p.id, 10);
|
||||
EXPECT_FALSE(DeviceNameUtils::ParseLocalName("cpu:abc", &p));
|
||||
EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc:", &p));
|
||||
EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc", &p));
|
||||
EXPECT_FALSE(DeviceNameUtils::ParseLocalName("myspecialdevice", &p));
|
||||
EXPECT_FALSE(DeviceNameUtils::ParseFullOrLocalName("myspecialdevice", &p));
|
||||
}
|
||||
|
||||
// Test that all parts are round-tripped correctly.
|
||||
|
@ -359,10 +359,8 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
tensorflow::InputTFE_Context(ctx)));
|
||||
|
||||
tensorflow::DeviceNameUtils::ParsedName input_device_name;
|
||||
if (!tensorflow::DeviceNameUtils::ParseFullName(device_name,
|
||||
&input_device_name) &&
|
||||
!tensorflow::DeviceNameUtils::ParseLocalName(device_name,
|
||||
&input_device_name)) {
|
||||
if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(
|
||||
device_name, &input_device_name)) {
|
||||
tensorflow::ThrowValueError(
|
||||
absl::StrFormat("Failed parsing device name: '%s'", device_name)
|
||||
.c_str());
|
||||
|
Loading…
Reference in New Issue
Block a user