Add ImportGraphDefOptions::uniquify_prefix.

This option is necessary to mimic the Python import_graph_def method's
behavior.

PiperOrigin-RevId: 177986165
This commit is contained in:
Skye Wanderman-Milne 2017-12-05 11:38:19 -08:00 committed by TensorFlower Gardener
parent c72bb97541
commit e72ecbdb7a
3 changed files with 60 additions and 29 deletions

View File

@ -77,6 +77,7 @@ class GraphConstructor {
? in.prefix
: in.prefix + "/"),
uniquify_names(in.uniquify_names),
uniquify_prefix(in.uniquify_prefix),
input_map(in.input_map),
skip_mapped_nodes(in.skip_mapped_nodes),
control_dependencies(in.control_dependencies),
@ -90,6 +91,7 @@ class GraphConstructor {
string prefix;
bool uniquify_names;
bool uniquify_prefix;
std::map<TensorId, TensorId> input_map;
bool skip_mapped_nodes;
std::vector<string> control_dependencies;
@ -144,6 +146,7 @@ class GraphConstructor {
library_(library),
g_(g),
original_versions_(g->versions()),
prefix_(opts.prefix),
refiner_(refiner),
return_tensors_(return_tensors),
return_nodes_(return_nodes),
@ -227,6 +230,9 @@ class GraphConstructor {
Graph* g_;
const VersionDef original_versions_;
// A copy of opts_.prefix, possibly uniquified.
string prefix_;
ShapeRefiner* refiner_;
// May be null. Not owned.
@ -348,7 +354,7 @@ Status GraphConstructor::EnsureNoNameCollisions() {
}
AddPrefixes(n->name(), &existing_prefixes_);
}
if (opts_.prefix.empty() && opts_.importing && !opts_.uniquify_names) {
if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) {
for (const NodeDef* n : node_defs_) {
const string& name = n->name();
if (NameExistsInGraph(name)) {
@ -356,19 +362,22 @@ Status GraphConstructor::EnsureNoNameCollisions() {
"' already exists in the Graph");
}
}
} else if (!opts_.prefix.empty()) {
StringPiece prefix_no_slash(opts_.prefix);
} else if (!prefix_.empty()) {
StringPiece prefix_no_slash(prefix_);
prefix_no_slash.remove_suffix(1);
if (!IsValidNodeName(prefix_no_slash, false)) {
return errors::InvalidArgument("Imported node name prefix '",
opts_.prefix,
return errors::InvalidArgument("Imported node name prefix '", prefix_,
"' would lead to invalid node names");
}
if (NameExistsInGraph(prefix_no_slash)) {
return errors::InvalidArgument("Import node name prefix '",
prefix_no_slash,
"' conflicts with "
"name already used in the graph");
if (opts_.uniquify_prefix) {
prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
} else {
return errors::InvalidArgument("Import node name prefix '",
prefix_no_slash,
"' conflicts with "
"name already used in the graph");
}
}
}
return Status::OK();
@ -740,8 +749,8 @@ void GraphConstructor::AddControlDependencies(
void GraphConstructor::AddPrefixToNodeDef(
const std::vector<bool>& input_already_exists, NodeDef* node_def) {
if (opts_.prefix.empty()) return;
node_def->set_name(strings::StrCat(opts_.prefix, node_def->name()));
if (prefix_.empty()) return;
node_def->set_name(strings::StrCat(prefix_, node_def->name()));
// Update names of input nodes
for (int i = 0; i < node_def->input_size(); ++i) {
StringPiece input(node_def->input(i));
@ -749,9 +758,9 @@ void GraphConstructor::AddPrefixToNodeDef(
// imported).
if (input_already_exists[i]) continue;
if (input.Consume("^")) {
node_def->set_input(i, strings::StrCat("^", opts_.prefix, input));
node_def->set_input(i, strings::StrCat("^", prefix_, input));
} else {
node_def->set_input(i, strings::StrCat(opts_.prefix, input));
node_def->set_input(i, strings::StrCat(prefix_, input));
}
}
// Update names of colocation groups
@ -761,8 +770,7 @@ void GraphConstructor::AddPrefixToNodeDef(
for (int i = 0; i < list->s_size(); ++i) {
StringPiece v(list->s(i));
if (v.Consume(kColocationGroupPrefix)) {
list->set_s(i,
strings::StrCat(kColocationGroupPrefix, opts_.prefix, v));
list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v));
}
}
}
@ -975,7 +983,7 @@ Status GraphConstructor::Convert() {
Node* node;
if (opts_.importing) {
if (!opts_.prefix.empty()) {
if (!prefix_.empty()) {
AddPrefixToNodeDef(input_already_exists, &imported_node_def);
} else if (opts_.uniquify_names) {
UniquifyNames(input_already_exists, &imported_node_def);

View File

@ -54,7 +54,10 @@ extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
// Options for calling ImportGraphDef().
struct ImportGraphDefOptions {
ImportGraphDefOptions() : uniquify_names(false), skip_mapped_nodes(false) {}
ImportGraphDefOptions()
: uniquify_names(false),
uniquify_prefix(false),
skip_mapped_nodes(false) {}
// Name prefix to use for nodes imported from the GraphDef. For example, if
// prefix="animals" and GraphDef contains a node "bunny" then the node will be
@ -68,6 +71,11 @@ struct ImportGraphDefOptions {
// will guarantee all node names are unique.
bool uniquify_names;
// If true, `prefix` will be modified if it already exists as a node name or
// prefix in the graph. If false, a conflicting prefix will be treated as an
// error. This option has no effect if `prefix` isn't specified.
bool uniquify_prefix;
// Maps tensors in `gdef` to existing tensors in `g`. Inputs in `gdef`
// corresponding to `input_map` keys will be remapped to the nodes in `g`
// corresponding to the values.

View File

@ -1806,6 +1806,21 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) {
EXPECT_EQ(results.return_nodes[1]->name(), "B_2");
EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_2:0");
// Import with an already-used prefix
opts.prefix = "A";
opts.uniquify_prefix = true;
results = ImportGraphDefResults();
ExpectOK(graph_def_str, opts, &refiner, &results);
ASSERT_EQ(results.return_nodes.size(), 2);
EXPECT_EQ(results.return_nodes[0]->name(), "A_3/A");
EXPECT_EQ(results.return_nodes[1]->name(), "A_3/B");
EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_3/A");
// Create B_3 node to keep the A/B numbering in sync
opts = ImportGraphDefOptions();
ExpectOK("node { name: 'B_3' op: 'TestInput' }");
// Import with existing de-duped node names
opts = ImportGraphDefOptions();
opts.uniquify_names = true;
@ -1827,24 +1842,24 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) {
opts = ImportGraphDefOptions();
opts.uniquify_names = true;
opts.return_nodes.push_back("A");
opts.return_nodes.push_back("A_3");
opts.return_nodes.push_back("A_4");
opts.return_nodes.push_back("B");
opts.return_nodes.push_back("B_3/B");
opts.return_nodes.push_back("B_4/B");
results = ImportGraphDefResults();
ExpectOK(
"node { name: 'A' op: 'TestInput' }"
"node { name: 'A_3' op: 'TestInput' }"
"node { name: 'A_4' op: 'TestInput' }"
"node { name: 'B' op: 'TestOneInputTwoOutputs' input: ['A'] }"
"node { name: 'B_3/B' op: 'TestOneInputTwoOutputs' input: ['A_3'] }",
"node { name: 'B_4/B' op: 'TestOneInputTwoOutputs' input: ['A_4'] }",
opts, &refiner, &results);
ASSERT_EQ(results.return_nodes.size(), 4);
EXPECT_EQ(results.return_nodes[0]->name(), "A_4");
EXPECT_EQ(results.return_nodes[1]->name(), "A_3");
EXPECT_EQ(results.return_nodes[2]->name(), "B_4");
EXPECT_EQ(results.return_nodes[2]->def().input(0), "A_4:0");
EXPECT_EQ(results.return_nodes[3]->name(), "B_3/B");
EXPECT_EQ(results.return_nodes[3]->def().input(0), "A_3");
EXPECT_EQ(results.return_nodes[0]->name(), "A_5");
EXPECT_EQ(results.return_nodes[1]->name(), "A_4");
EXPECT_EQ(results.return_nodes[2]->name(), "B_5");
EXPECT_EQ(results.return_nodes[2]->def().input(0), "A_5:0");
EXPECT_EQ(results.return_nodes[3]->name(), "B_4/B");
EXPECT_EQ(results.return_nodes[3]->def().input(0), "A_4");
// Create node with prefix and then import node with same name
ExpectOK("node { name: 'foo/abc' op: 'ABC' }");
@ -1895,8 +1910,8 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) {
ExpectOK(graph_def_str, opts, &refiner, &results);
ASSERT_EQ(results.return_nodes.size(), 2);
EXPECT_EQ(results.return_nodes[0]->name(), "A_5");
EXPECT_EQ(results.return_nodes[1]->name(), "B_5");
EXPECT_EQ(results.return_nodes[0]->name(), "A_6");
EXPECT_EQ(results.return_nodes[1]->name(), "B_6");
EXPECT_EQ(results.return_nodes[1]->def().input(0), "A:0");
}