import_graph_def: support "absolute" names with the C API enabled.
Passing a name with a trailing '/' to import_graph_def causes that name to be used as-is (i.e. it is not appended to the existing name scope and not de-duped with any existing name scopes. This is in order to re-use an existing name scope). This didn't work with the C API enabled because it was set to always have the C API uniquify the prefix. The fix is to not uniquify the prefix, since calling name_scope in import_graph_def already has the logic to uniquify the prefix if necessary. I'm not sure why I thought we needed the C API to do this to being with. In addition, this changes the graph_constructor.cc logic to uniquify names if the prefix cannot be guaranteed unique (see the new test case in graph_constructor_test.cc for why/when this is necessary). PiperOrigin-RevId: 185215326
This commit is contained in:
parent
816f59e6ab
commit
5b71a126c4
@ -374,15 +374,8 @@ Status GraphConstructor::EnsureNoNameCollisions() {
|
|||||||
return errors::InvalidArgument("Imported node name prefix '", prefix_,
|
return errors::InvalidArgument("Imported node name prefix '", prefix_,
|
||||||
"' would lead to invalid node names");
|
"' would lead to invalid node names");
|
||||||
}
|
}
|
||||||
if (NameExistsInGraph(prefix_no_slash)) {
|
if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) {
|
||||||
if (opts_.uniquify_prefix) {
|
prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
|
||||||
prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
|
|
||||||
} else {
|
|
||||||
return errors::InvalidArgument("Import node name prefix '",
|
|
||||||
prefix_no_slash,
|
|
||||||
"' conflicts with "
|
|
||||||
"name already used in the graph");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -990,7 +983,10 @@ Status GraphConstructor::Convert() {
|
|||||||
if (opts_.importing) {
|
if (opts_.importing) {
|
||||||
if (!prefix_.empty()) {
|
if (!prefix_.empty()) {
|
||||||
AddPrefixToNodeDef(input_already_exists, &imported_node_def);
|
AddPrefixToNodeDef(input_already_exists, &imported_node_def);
|
||||||
} else if (opts_.uniquify_names) {
|
}
|
||||||
|
// Note: no need to uniquify names if the prefix already guarantees
|
||||||
|
// uniqueness
|
||||||
|
if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) {
|
||||||
UniquifyNames(input_already_exists, &imported_node_def);
|
UniquifyNames(input_already_exists, &imported_node_def);
|
||||||
}
|
}
|
||||||
TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&imported_node_def));
|
TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&imported_node_def));
|
||||||
|
@ -1834,7 +1834,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) {
|
|||||||
EXPECT_EQ(results.return_nodes[1]->name(), "B_2");
|
EXPECT_EQ(results.return_nodes[1]->name(), "B_2");
|
||||||
EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_2:0");
|
EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_2:0");
|
||||||
|
|
||||||
// Import with an already-used prefix
|
// Import with an already-used prefix and uniquify_prefix = true
|
||||||
opts.prefix = "A";
|
opts.prefix = "A";
|
||||||
opts.uniquify_prefix = true;
|
opts.uniquify_prefix = true;
|
||||||
results = ImportGraphDefResults();
|
results = ImportGraphDefResults();
|
||||||
@ -1846,9 +1846,27 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) {
|
|||||||
EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_3/A");
|
EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_3/A");
|
||||||
|
|
||||||
// Create B_3 node to keep the A/B numbering in sync
|
// Create B_3 node to keep the A/B numbering in sync
|
||||||
opts = ImportGraphDefOptions();
|
|
||||||
ExpectOK("node { name: 'B_3' op: 'TestInput' }");
|
ExpectOK("node { name: 'B_3' op: 'TestInput' }");
|
||||||
|
|
||||||
|
// Import with an already-used prefix and uniquify_prefix = false
|
||||||
|
opts.uniquify_prefix = false;
|
||||||
|
results = ImportGraphDefResults();
|
||||||
|
ExpectOK(graph_def_str, opts, &refiner, &results);
|
||||||
|
|
||||||
|
ASSERT_EQ(results.return_nodes.size(), 2);
|
||||||
|
EXPECT_EQ(results.return_nodes[0]->name(), "A/A");
|
||||||
|
EXPECT_EQ(results.return_nodes[1]->name(), "A/B");
|
||||||
|
EXPECT_EQ(results.return_nodes[1]->def().input(0), "A/A");
|
||||||
|
|
||||||
|
// Repeat the same import
|
||||||
|
results = ImportGraphDefResults();
|
||||||
|
ExpectOK(graph_def_str, opts, &refiner, &results);
|
||||||
|
|
||||||
|
ASSERT_EQ(results.return_nodes.size(), 2);
|
||||||
|
EXPECT_EQ(results.return_nodes[0]->name(), "A/A_1");
|
||||||
|
EXPECT_EQ(results.return_nodes[1]->name(), "A/B_1");
|
||||||
|
EXPECT_EQ(results.return_nodes[1]->def().input(0), "A/A_1:0");
|
||||||
|
|
||||||
// Import with existing de-duped node names
|
// Import with existing de-duped node names
|
||||||
opts = ImportGraphDefOptions();
|
opts = ImportGraphDefOptions();
|
||||||
opts.uniquify_names = true;
|
opts.uniquify_names = true;
|
||||||
|
@ -270,7 +270,6 @@ def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
|
|||||||
"""Populates the TF_ImportGraphDefOptions `options`."""
|
"""Populates the TF_ImportGraphDefOptions `options`."""
|
||||||
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
|
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
|
||||||
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)
|
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)
|
||||||
c_api.TF_ImportGraphDefOptionsSetUniquifyPrefix(options, True)
|
|
||||||
|
|
||||||
for input_src, input_dst in input_map.items():
|
for input_src, input_dst in input_map.items():
|
||||||
input_src = compat.as_str(input_src)
|
input_src = compat.as_str(input_src)
|
||||||
|
@ -154,6 +154,25 @@ class ImportGraphDefTest(test.TestCase):
|
|||||||
self.assertEqual(b3.name, "A_3/B")
|
self.assertEqual(b3.name, "A_3/B")
|
||||||
self.assertEqual(list(b3.inputs), [a3.outputs[0]])
|
self.assertEqual(list(b3.inputs), [a3.outputs[0]])
|
||||||
|
|
||||||
|
# Import with an already-used name but with a '/' to indicate an
|
||||||
|
# "absolute" name scope (see the Graph.name_scope docstring).
|
||||||
|
a_a, a_b = importer.import_graph_def(
|
||||||
|
graph_def,
|
||||||
|
return_elements=["A", "B"],
|
||||||
|
name="A/")
|
||||||
|
self.assertEqual(a_a.name, "A/A")
|
||||||
|
self.assertEqual(a_b.name, "A/B")
|
||||||
|
self.assertEqual(list(a_b.inputs), [a_a.outputs[0]])
|
||||||
|
|
||||||
|
# Repeat the same import.
|
||||||
|
a_a1, a_b1 = importer.import_graph_def(
|
||||||
|
graph_def,
|
||||||
|
return_elements=["A", "B"],
|
||||||
|
name="A/")
|
||||||
|
self.assertEqual(a_a1.name, "A/A_1")
|
||||||
|
self.assertEqual(a_b1.name, "A/B_1")
|
||||||
|
self.assertEqual(list(a_b1.inputs), [a_a1.outputs[0]])
|
||||||
|
|
||||||
# Import with existing de-duped node names
|
# Import with existing de-duped node names
|
||||||
a1_1, b1_1 = importer.import_graph_def(
|
a1_1, b1_1 = importer.import_graph_def(
|
||||||
self._MakeGraphDef("""
|
self._MakeGraphDef("""
|
||||||
|
Loading…
Reference in New Issue
Block a user