Introduce helper ParseFullOrLocalName

PiperOrigin-RevId: 329785171
Change-Id: I67a9502c0c7581aea07efd14c68aa013257499b3
This commit is contained in:
George Karpenkov 2020-09-02 14:11:33 -07:00 committed by TensorFlower Gardener
parent 67dfaf3a87
commit 6bac741e7a
4 changed files with 16 additions and 4 deletions

View File

@ -174,6 +174,11 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
return true;
}
bool DeviceNameUtils::ParseFullOrLocalName(StringPiece fullname,
ParsedName* p) {
return ParseFullName(fullname, p) || ParseLocalName(fullname, p);
}
namespace {
void CompleteName(const DeviceNameUtils::ParsedName& parsed_basename,

View File

@ -89,6 +89,11 @@ class DeviceNameUtils {
bool has_id = false;
int id = 0;
};
// Parses the device name, first as a full name, then, if it fails, as a
// global one. Returns `false` if both attempts fail.
static bool ParseFullOrLocalName(StringPiece fullname, ParsedName* parsed);
// Parses "fullname" into "*parsed". Returns true iff succeeds.
// Legacy names like "/cpu:0" that don't contain "device",
// are parsed to mean their current counterparts "/device:CPU:0". More

View File

@ -105,6 +105,8 @@ TEST(DeviceNameUtilsTest, Basic) {
DeviceNameUtils::ParsedName p;
EXPECT_TRUE(DeviceNameUtils::ParseFullName(
"/job:foo_bar/replica:1/task:2/device:GPU:3", &p));
EXPECT_TRUE(DeviceNameUtils::ParseFullOrLocalName(
"/job:foo_bar/replica:1/task:2/device:GPU:3", &p));
EXPECT_TRUE(p.has_job);
EXPECT_TRUE(p.has_replica);
EXPECT_TRUE(p.has_task);
@ -246,12 +248,14 @@ TEST(DeviceNameUtilsTest, Basic) {
{
DeviceNameUtils::ParsedName p;
EXPECT_TRUE(DeviceNameUtils::ParseLocalName("CPU:10", &p));
EXPECT_TRUE(DeviceNameUtils::ParseFullOrLocalName("CPU:10", &p));
EXPECT_EQ(p.type, "CPU");
EXPECT_EQ(p.id, 10);
EXPECT_FALSE(DeviceNameUtils::ParseLocalName("cpu:abc", &p));
EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc:", &p));
EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc", &p));
EXPECT_FALSE(DeviceNameUtils::ParseLocalName("myspecialdevice", &p));
EXPECT_FALSE(DeviceNameUtils::ParseFullOrLocalName("myspecialdevice", &p));
}
// Test that all parts are round-tripped correctly.

View File

@ -359,10 +359,8 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
tensorflow::InputTFE_Context(ctx)));
tensorflow::DeviceNameUtils::ParsedName input_device_name;
if (!tensorflow::DeviceNameUtils::ParseFullName(device_name,
&input_device_name) &&
!tensorflow::DeviceNameUtils::ParseLocalName(device_name,
&input_device_name)) {
if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(
device_name, &input_device_name)) {
tensorflow::ThrowValueError(
absl::StrFormat("Failed parsing device name: '%s'", device_name)
.c_str());