From e72ecbdb7a84c5cc0801e85a9c38f6fd181ceef6 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 5 Dec 2017 11:38:19 -0800 Subject: [PATCH] Add ImportGraphDefOptions::uniquify_prefix. This option is necessary to mimic the Python import_graph_def method's behavior. PiperOrigin-RevId: 177986165 --- tensorflow/core/graph/graph_constructor.cc | 40 +++++++++++-------- tensorflow/core/graph/graph_constructor.h | 10 ++++- .../core/graph/graph_constructor_test.cc | 39 ++++++++++++------ 3 files changed, 60 insertions(+), 29 deletions(-) diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 0fb61fd9af2..6e72d739189 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -77,6 +77,7 @@ class GraphConstructor { ? in.prefix : in.prefix + "/"), uniquify_names(in.uniquify_names), + uniquify_prefix(in.uniquify_prefix), input_map(in.input_map), skip_mapped_nodes(in.skip_mapped_nodes), control_dependencies(in.control_dependencies), @@ -90,6 +91,7 @@ class GraphConstructor { string prefix; bool uniquify_names; + bool uniquify_prefix; std::map input_map; bool skip_mapped_nodes; std::vector control_dependencies; @@ -144,6 +146,7 @@ class GraphConstructor { library_(library), g_(g), original_versions_(g->versions()), + prefix_(opts.prefix), refiner_(refiner), return_tensors_(return_tensors), return_nodes_(return_nodes), @@ -227,6 +230,9 @@ class GraphConstructor { Graph* g_; const VersionDef original_versions_; + // A copy of opts_.prefix, possibly uniquified. + string prefix_; + ShapeRefiner* refiner_; // May be null. Not owned. @@ -348,7 +354,7 @@ Status GraphConstructor::EnsureNoNameCollisions() { } AddPrefixes(n->name(), &existing_prefixes_); } - if (opts_.prefix.empty() && opts_.importing && !opts_.uniquify_names) { + if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) { for (const NodeDef* n : node_defs_) { const string& name = n->name(); if (NameExistsInGraph(name)) { @@ -356,19 +362,22 @@ Status GraphConstructor::EnsureNoNameCollisions() { "' already exists in the Graph"); } } - } else if (!opts_.prefix.empty()) { - StringPiece prefix_no_slash(opts_.prefix); + } else if (!prefix_.empty()) { + StringPiece prefix_no_slash(prefix_); prefix_no_slash.remove_suffix(1); if (!IsValidNodeName(prefix_no_slash, false)) { - return errors::InvalidArgument("Imported node name prefix '", - opts_.prefix, + return errors::InvalidArgument("Imported node name prefix '", prefix_, "' would lead to invalid node names"); } if (NameExistsInGraph(prefix_no_slash)) { - return errors::InvalidArgument("Import node name prefix '", - prefix_no_slash, - "' conflicts with " - "name already used in the graph"); + if (opts_.uniquify_prefix) { + prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/"); + } else { + return errors::InvalidArgument("Import node name prefix '", + prefix_no_slash, + "' conflicts with " + "name already used in the graph"); + } } } return Status::OK(); @@ -740,8 +749,8 @@ void GraphConstructor::AddControlDependencies( void GraphConstructor::AddPrefixToNodeDef( const std::vector& input_already_exists, NodeDef* node_def) { - if (opts_.prefix.empty()) return; - node_def->set_name(strings::StrCat(opts_.prefix, node_def->name())); + if (prefix_.empty()) return; + node_def->set_name(strings::StrCat(prefix_, node_def->name())); // Update names of input nodes for (int i = 0; i < node_def->input_size(); ++i) { StringPiece input(node_def->input(i)); @@ -749,9 +758,9 @@ void GraphConstructor::AddPrefixToNodeDef( // imported). if (input_already_exists[i]) continue; if (input.Consume("^")) { - node_def->set_input(i, strings::StrCat("^", opts_.prefix, input)); + node_def->set_input(i, strings::StrCat("^", prefix_, input)); } else { - node_def->set_input(i, strings::StrCat(opts_.prefix, input)); + node_def->set_input(i, strings::StrCat(prefix_, input)); } } // Update names of colocation groups @@ -761,8 +770,7 @@ void GraphConstructor::AddPrefixToNodeDef( for (int i = 0; i < list->s_size(); ++i) { StringPiece v(list->s(i)); if (v.Consume(kColocationGroupPrefix)) { - list->set_s(i, - strings::StrCat(kColocationGroupPrefix, opts_.prefix, v)); + list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v)); } } } @@ -975,7 +983,7 @@ Status GraphConstructor::Convert() { Node* node; if (opts_.importing) { - if (!opts_.prefix.empty()) { + if (!prefix_.empty()) { AddPrefixToNodeDef(input_already_exists, &imported_node_def); } else if (opts_.uniquify_names) { UniquifyNames(input_already_exists, &imported_node_def); diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h index 4b418b86229..b4dd2ba51a6 100644 --- a/tensorflow/core/graph/graph_constructor.h +++ b/tensorflow/core/graph/graph_constructor.h @@ -54,7 +54,10 @@ extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, // Options for calling ImportGraphDef(). struct ImportGraphDefOptions { - ImportGraphDefOptions() : uniquify_names(false), skip_mapped_nodes(false) {} + ImportGraphDefOptions() + : uniquify_names(false), + uniquify_prefix(false), + skip_mapped_nodes(false) {} // Name prefix to use for nodes imported from the GraphDef. For example, if // prefix="animals" and GraphDef contains a node "bunny" then the node will be @@ -68,6 +71,11 @@ struct ImportGraphDefOptions { // will guarantee all node names are unique. bool uniquify_names; + // If true, `prefix` will be modified if it already exists as a node name or + // prefix in the graph. If false, a conflicting prefix will be treated as an + // error. This option has no effect if `prefix` isn't specified. + bool uniquify_prefix; + // Maps tensors in `gdef` to existing tensors in `g`. Inputs in `gdef` // corresponding to `input_map` keys will be remapped to the nodes in `g` // corresponding to the values. diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 83aba6c9be3..9be3de23881 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -1806,6 +1806,21 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) { EXPECT_EQ(results.return_nodes[1]->name(), "B_2"); EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_2:0"); + // Import with an already-used prefix + opts.prefix = "A"; + opts.uniquify_prefix = true; + results = ImportGraphDefResults(); + ExpectOK(graph_def_str, opts, &refiner, &results); + + ASSERT_EQ(results.return_nodes.size(), 2); + EXPECT_EQ(results.return_nodes[0]->name(), "A_3/A"); + EXPECT_EQ(results.return_nodes[1]->name(), "A_3/B"); + EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_3/A"); + + // Create B_3 node to keep the A/B numbering in sync + opts = ImportGraphDefOptions(); + ExpectOK("node { name: 'B_3' op: 'TestInput' }"); + // Import with existing de-duped node names opts = ImportGraphDefOptions(); opts.uniquify_names = true; @@ -1827,24 +1842,24 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) { opts = ImportGraphDefOptions(); opts.uniquify_names = true; opts.return_nodes.push_back("A"); - opts.return_nodes.push_back("A_3"); + opts.return_nodes.push_back("A_4"); opts.return_nodes.push_back("B"); - opts.return_nodes.push_back("B_3/B"); + opts.return_nodes.push_back("B_4/B"); results = ImportGraphDefResults(); ExpectOK( "node { name: 'A' op: 'TestInput' }" - "node { name: 'A_3' op: 'TestInput' }" + "node { name: 'A_4' op: 'TestInput' }" "node { name: 'B' op: 'TestOneInputTwoOutputs' input: ['A'] }" - "node { name: 'B_3/B' op: 'TestOneInputTwoOutputs' input: ['A_3'] }", + "node { name: 'B_4/B' op: 'TestOneInputTwoOutputs' input: ['A_4'] }", opts, &refiner, &results); ASSERT_EQ(results.return_nodes.size(), 4); - EXPECT_EQ(results.return_nodes[0]->name(), "A_4"); - EXPECT_EQ(results.return_nodes[1]->name(), "A_3"); - EXPECT_EQ(results.return_nodes[2]->name(), "B_4"); - EXPECT_EQ(results.return_nodes[2]->def().input(0), "A_4:0"); - EXPECT_EQ(results.return_nodes[3]->name(), "B_3/B"); - EXPECT_EQ(results.return_nodes[3]->def().input(0), "A_3"); + EXPECT_EQ(results.return_nodes[0]->name(), "A_5"); + EXPECT_EQ(results.return_nodes[1]->name(), "A_4"); + EXPECT_EQ(results.return_nodes[2]->name(), "B_5"); + EXPECT_EQ(results.return_nodes[2]->def().input(0), "A_5:0"); + EXPECT_EQ(results.return_nodes[3]->name(), "B_4/B"); + EXPECT_EQ(results.return_nodes[3]->def().input(0), "A_4"); // Create node with prefix and then import node with same name ExpectOK("node { name: 'foo/abc' op: 'ABC' }"); @@ -1895,8 +1910,8 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) { ExpectOK(graph_def_str, opts, &refiner, &results); ASSERT_EQ(results.return_nodes.size(), 2); - EXPECT_EQ(results.return_nodes[0]->name(), "A_5"); - EXPECT_EQ(results.return_nodes[1]->name(), "B_5"); + EXPECT_EQ(results.return_nodes[0]->name(), "A_6"); + EXPECT_EQ(results.return_nodes[1]->name(), "B_6"); EXPECT_EQ(results.return_nodes[1]->def().input(0), "A:0"); }