diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc index 8688a11870e..14dab634416 100644 --- a/tensorflow/core/util/device_name_utils.cc +++ b/tensorflow/core/util/device_name_utils.cc @@ -174,6 +174,11 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) { return true; } +bool DeviceNameUtils::ParseFullOrLocalName(StringPiece fullname, + ParsedName* p) { + return ParseFullName(fullname, p) || ParseLocalName(fullname, p); +} + namespace { void CompleteName(const DeviceNameUtils::ParsedName& parsed_basename, diff --git a/tensorflow/core/util/device_name_utils.h b/tensorflow/core/util/device_name_utils.h index a1515ba8508..3de7544a05e 100644 --- a/tensorflow/core/util/device_name_utils.h +++ b/tensorflow/core/util/device_name_utils.h @@ -89,6 +89,11 @@ class DeviceNameUtils { bool has_id = false; int id = 0; }; + + // Parses the device name, first as a full name, then, if it fails, as a + // global one. Returns `false` if both attempts fail. + static bool ParseFullOrLocalName(StringPiece fullname, ParsedName* parsed); + // Parses "fullname" into "*parsed". Returns true iff succeeds. // Legacy names like "/cpu:0" that don't contain "device", // are parsed to mean their current counterparts "/device:CPU:0". More diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc index 729d1ec3ae8..065fcfbf2ce 100644 --- a/tensorflow/core/util/device_name_utils_test.cc +++ b/tensorflow/core/util/device_name_utils_test.cc @@ -105,6 +105,8 @@ TEST(DeviceNameUtilsTest, Basic) { DeviceNameUtils::ParsedName p; EXPECT_TRUE(DeviceNameUtils::ParseFullName( "/job:foo_bar/replica:1/task:2/device:GPU:3", &p)); + EXPECT_TRUE(DeviceNameUtils::ParseFullOrLocalName( + "/job:foo_bar/replica:1/task:2/device:GPU:3", &p)); EXPECT_TRUE(p.has_job); EXPECT_TRUE(p.has_replica); EXPECT_TRUE(p.has_task); @@ -246,12 +248,14 @@ TEST(DeviceNameUtilsTest, Basic) { { DeviceNameUtils::ParsedName p; EXPECT_TRUE(DeviceNameUtils::ParseLocalName("CPU:10", &p)); + EXPECT_TRUE(DeviceNameUtils::ParseFullOrLocalName("CPU:10", &p)); EXPECT_EQ(p.type, "CPU"); EXPECT_EQ(p.id, 10); EXPECT_FALSE(DeviceNameUtils::ParseLocalName("cpu:abc", &p)); EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc:", &p)); EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc", &p)); EXPECT_FALSE(DeviceNameUtils::ParseLocalName("myspecialdevice", &p)); + EXPECT_FALSE(DeviceNameUtils::ParseFullOrLocalName("myspecialdevice", &p)); } // Test that all parts are round-tripped correctly. diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index d2ee9a4c57e..2ea7fd7008d 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -359,10 +359,8 @@ PYBIND11_MODULE(_pywrap_tfe, m) { tensorflow::InputTFE_Context(ctx))); tensorflow::DeviceNameUtils::ParsedName input_device_name; - if (!tensorflow::DeviceNameUtils::ParseFullName(device_name, - &input_device_name) && - !tensorflow::DeviceNameUtils::ParseLocalName(device_name, - &input_device_name)) { + if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName( + device_name, &input_device_name)) { tensorflow::ThrowValueError( absl::StrFormat("Failed parsing device name: '%s'", device_name) .c_str());