Introduce helper ParseFullOrLocalName
PiperOrigin-RevId: 329785171 Change-Id: I67a9502c0c7581aea07efd14c68aa013257499b3
This commit is contained in:
parent
67dfaf3a87
commit
6bac741e7a
@ -174,6 +174,11 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool DeviceNameUtils::ParseFullOrLocalName(StringPiece fullname,
|
||||||
|
ParsedName* p) {
|
||||||
|
return ParseFullName(fullname, p) || ParseLocalName(fullname, p);
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void CompleteName(const DeviceNameUtils::ParsedName& parsed_basename,
|
void CompleteName(const DeviceNameUtils::ParsedName& parsed_basename,
|
||||||
|
@ -89,6 +89,11 @@ class DeviceNameUtils {
|
|||||||
bool has_id = false;
|
bool has_id = false;
|
||||||
int id = 0;
|
int id = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Parses the device name, first as a full name, then, if it fails, as a
|
||||||
|
// global one. Returns `false` if both attempts fail.
|
||||||
|
static bool ParseFullOrLocalName(StringPiece fullname, ParsedName* parsed);
|
||||||
|
|
||||||
// Parses "fullname" into "*parsed". Returns true iff succeeds.
|
// Parses "fullname" into "*parsed". Returns true iff succeeds.
|
||||||
// Legacy names like "/cpu:0" that don't contain "device",
|
// Legacy names like "/cpu:0" that don't contain "device",
|
||||||
// are parsed to mean their current counterparts "/device:CPU:0". More
|
// are parsed to mean their current counterparts "/device:CPU:0". More
|
||||||
|
@ -105,6 +105,8 @@ TEST(DeviceNameUtilsTest, Basic) {
|
|||||||
DeviceNameUtils::ParsedName p;
|
DeviceNameUtils::ParsedName p;
|
||||||
EXPECT_TRUE(DeviceNameUtils::ParseFullName(
|
EXPECT_TRUE(DeviceNameUtils::ParseFullName(
|
||||||
"/job:foo_bar/replica:1/task:2/device:GPU:3", &p));
|
"/job:foo_bar/replica:1/task:2/device:GPU:3", &p));
|
||||||
|
EXPECT_TRUE(DeviceNameUtils::ParseFullOrLocalName(
|
||||||
|
"/job:foo_bar/replica:1/task:2/device:GPU:3", &p));
|
||||||
EXPECT_TRUE(p.has_job);
|
EXPECT_TRUE(p.has_job);
|
||||||
EXPECT_TRUE(p.has_replica);
|
EXPECT_TRUE(p.has_replica);
|
||||||
EXPECT_TRUE(p.has_task);
|
EXPECT_TRUE(p.has_task);
|
||||||
@ -246,12 +248,14 @@ TEST(DeviceNameUtilsTest, Basic) {
|
|||||||
{
|
{
|
||||||
DeviceNameUtils::ParsedName p;
|
DeviceNameUtils::ParsedName p;
|
||||||
EXPECT_TRUE(DeviceNameUtils::ParseLocalName("CPU:10", &p));
|
EXPECT_TRUE(DeviceNameUtils::ParseLocalName("CPU:10", &p));
|
||||||
|
EXPECT_TRUE(DeviceNameUtils::ParseFullOrLocalName("CPU:10", &p));
|
||||||
EXPECT_EQ(p.type, "CPU");
|
EXPECT_EQ(p.type, "CPU");
|
||||||
EXPECT_EQ(p.id, 10);
|
EXPECT_EQ(p.id, 10);
|
||||||
EXPECT_FALSE(DeviceNameUtils::ParseLocalName("cpu:abc", &p));
|
EXPECT_FALSE(DeviceNameUtils::ParseLocalName("cpu:abc", &p));
|
||||||
EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc:", &p));
|
EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc:", &p));
|
||||||
EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc", &p));
|
EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc", &p));
|
||||||
EXPECT_FALSE(DeviceNameUtils::ParseLocalName("myspecialdevice", &p));
|
EXPECT_FALSE(DeviceNameUtils::ParseLocalName("myspecialdevice", &p));
|
||||||
|
EXPECT_FALSE(DeviceNameUtils::ParseFullOrLocalName("myspecialdevice", &p));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test that all parts are round-tripped correctly.
|
// Test that all parts are round-tripped correctly.
|
||||||
|
@ -359,10 +359,8 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
|||||||
tensorflow::InputTFE_Context(ctx)));
|
tensorflow::InputTFE_Context(ctx)));
|
||||||
|
|
||||||
tensorflow::DeviceNameUtils::ParsedName input_device_name;
|
tensorflow::DeviceNameUtils::ParsedName input_device_name;
|
||||||
if (!tensorflow::DeviceNameUtils::ParseFullName(device_name,
|
if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(
|
||||||
&input_device_name) &&
|
device_name, &input_device_name)) {
|
||||||
!tensorflow::DeviceNameUtils::ParseLocalName(device_name,
|
|
||||||
&input_device_name)) {
|
|
||||||
tensorflow::ThrowValueError(
|
tensorflow::ThrowValueError(
|
||||||
absl::StrFormat("Failed parsing device name: '%s'", device_name)
|
absl::StrFormat("Failed parsing device name: '%s'", device_name)
|
||||||
.c_str());
|
.c_str());
|
||||||
|
Loading…
Reference in New Issue
Block a user