Add ImportGraphDefOptions::uniquify_prefix.
This option is necessary to mimic the Python import_graph_def method's behavior. PiperOrigin-RevId: 177986165
This commit is contained in:
parent
c72bb97541
commit
e72ecbdb7a
@ -77,6 +77,7 @@ class GraphConstructor {
|
||||
? in.prefix
|
||||
: in.prefix + "/"),
|
||||
uniquify_names(in.uniquify_names),
|
||||
uniquify_prefix(in.uniquify_prefix),
|
||||
input_map(in.input_map),
|
||||
skip_mapped_nodes(in.skip_mapped_nodes),
|
||||
control_dependencies(in.control_dependencies),
|
||||
@ -90,6 +91,7 @@ class GraphConstructor {
|
||||
|
||||
string prefix;
|
||||
bool uniquify_names;
|
||||
bool uniquify_prefix;
|
||||
std::map<TensorId, TensorId> input_map;
|
||||
bool skip_mapped_nodes;
|
||||
std::vector<string> control_dependencies;
|
||||
@ -144,6 +146,7 @@ class GraphConstructor {
|
||||
library_(library),
|
||||
g_(g),
|
||||
original_versions_(g->versions()),
|
||||
prefix_(opts.prefix),
|
||||
refiner_(refiner),
|
||||
return_tensors_(return_tensors),
|
||||
return_nodes_(return_nodes),
|
||||
@ -227,6 +230,9 @@ class GraphConstructor {
|
||||
Graph* g_;
|
||||
const VersionDef original_versions_;
|
||||
|
||||
// A copy of opts_.prefix, possibly uniquified.
|
||||
string prefix_;
|
||||
|
||||
ShapeRefiner* refiner_;
|
||||
|
||||
// May be null. Not owned.
|
||||
@ -348,7 +354,7 @@ Status GraphConstructor::EnsureNoNameCollisions() {
|
||||
}
|
||||
AddPrefixes(n->name(), &existing_prefixes_);
|
||||
}
|
||||
if (opts_.prefix.empty() && opts_.importing && !opts_.uniquify_names) {
|
||||
if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) {
|
||||
for (const NodeDef* n : node_defs_) {
|
||||
const string& name = n->name();
|
||||
if (NameExistsInGraph(name)) {
|
||||
@ -356,19 +362,22 @@ Status GraphConstructor::EnsureNoNameCollisions() {
|
||||
"' already exists in the Graph");
|
||||
}
|
||||
}
|
||||
} else if (!opts_.prefix.empty()) {
|
||||
StringPiece prefix_no_slash(opts_.prefix);
|
||||
} else if (!prefix_.empty()) {
|
||||
StringPiece prefix_no_slash(prefix_);
|
||||
prefix_no_slash.remove_suffix(1);
|
||||
if (!IsValidNodeName(prefix_no_slash, false)) {
|
||||
return errors::InvalidArgument("Imported node name prefix '",
|
||||
opts_.prefix,
|
||||
return errors::InvalidArgument("Imported node name prefix '", prefix_,
|
||||
"' would lead to invalid node names");
|
||||
}
|
||||
if (NameExistsInGraph(prefix_no_slash)) {
|
||||
return errors::InvalidArgument("Import node name prefix '",
|
||||
prefix_no_slash,
|
||||
"' conflicts with "
|
||||
"name already used in the graph");
|
||||
if (opts_.uniquify_prefix) {
|
||||
prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
|
||||
} else {
|
||||
return errors::InvalidArgument("Import node name prefix '",
|
||||
prefix_no_slash,
|
||||
"' conflicts with "
|
||||
"name already used in the graph");
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
@ -740,8 +749,8 @@ void GraphConstructor::AddControlDependencies(
|
||||
|
||||
void GraphConstructor::AddPrefixToNodeDef(
|
||||
const std::vector<bool>& input_already_exists, NodeDef* node_def) {
|
||||
if (opts_.prefix.empty()) return;
|
||||
node_def->set_name(strings::StrCat(opts_.prefix, node_def->name()));
|
||||
if (prefix_.empty()) return;
|
||||
node_def->set_name(strings::StrCat(prefix_, node_def->name()));
|
||||
// Update names of input nodes
|
||||
for (int i = 0; i < node_def->input_size(); ++i) {
|
||||
StringPiece input(node_def->input(i));
|
||||
@ -749,9 +758,9 @@ void GraphConstructor::AddPrefixToNodeDef(
|
||||
// imported).
|
||||
if (input_already_exists[i]) continue;
|
||||
if (input.Consume("^")) {
|
||||
node_def->set_input(i, strings::StrCat("^", opts_.prefix, input));
|
||||
node_def->set_input(i, strings::StrCat("^", prefix_, input));
|
||||
} else {
|
||||
node_def->set_input(i, strings::StrCat(opts_.prefix, input));
|
||||
node_def->set_input(i, strings::StrCat(prefix_, input));
|
||||
}
|
||||
}
|
||||
// Update names of colocation groups
|
||||
@ -761,8 +770,7 @@ void GraphConstructor::AddPrefixToNodeDef(
|
||||
for (int i = 0; i < list->s_size(); ++i) {
|
||||
StringPiece v(list->s(i));
|
||||
if (v.Consume(kColocationGroupPrefix)) {
|
||||
list->set_s(i,
|
||||
strings::StrCat(kColocationGroupPrefix, opts_.prefix, v));
|
||||
list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -975,7 +983,7 @@ Status GraphConstructor::Convert() {
|
||||
|
||||
Node* node;
|
||||
if (opts_.importing) {
|
||||
if (!opts_.prefix.empty()) {
|
||||
if (!prefix_.empty()) {
|
||||
AddPrefixToNodeDef(input_already_exists, &imported_node_def);
|
||||
} else if (opts_.uniquify_names) {
|
||||
UniquifyNames(input_already_exists, &imported_node_def);
|
||||
|
@ -54,7 +54,10 @@ extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
|
||||
|
||||
// Options for calling ImportGraphDef().
|
||||
struct ImportGraphDefOptions {
|
||||
ImportGraphDefOptions() : uniquify_names(false), skip_mapped_nodes(false) {}
|
||||
ImportGraphDefOptions()
|
||||
: uniquify_names(false),
|
||||
uniquify_prefix(false),
|
||||
skip_mapped_nodes(false) {}
|
||||
|
||||
// Name prefix to use for nodes imported from the GraphDef. For example, if
|
||||
// prefix="animals" and GraphDef contains a node "bunny" then the node will be
|
||||
@ -68,6 +71,11 @@ struct ImportGraphDefOptions {
|
||||
// will guarantee all node names are unique.
|
||||
bool uniquify_names;
|
||||
|
||||
// If true, `prefix` will be modified if it already exists as a node name or
|
||||
// prefix in the graph. If false, a conflicting prefix will be treated as an
|
||||
// error. This option has no effect if `prefix` isn't specified.
|
||||
bool uniquify_prefix;
|
||||
|
||||
// Maps tensors in `gdef` to existing tensors in `g`. Inputs in `gdef`
|
||||
// corresponding to `input_map` keys will be remapped to the nodes in `g`
|
||||
// corresponding to the values.
|
||||
|
@ -1806,6 +1806,21 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) {
|
||||
EXPECT_EQ(results.return_nodes[1]->name(), "B_2");
|
||||
EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_2:0");
|
||||
|
||||
// Import with an already-used prefix
|
||||
opts.prefix = "A";
|
||||
opts.uniquify_prefix = true;
|
||||
results = ImportGraphDefResults();
|
||||
ExpectOK(graph_def_str, opts, &refiner, &results);
|
||||
|
||||
ASSERT_EQ(results.return_nodes.size(), 2);
|
||||
EXPECT_EQ(results.return_nodes[0]->name(), "A_3/A");
|
||||
EXPECT_EQ(results.return_nodes[1]->name(), "A_3/B");
|
||||
EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_3/A");
|
||||
|
||||
// Create B_3 node to keep the A/B numbering in sync
|
||||
opts = ImportGraphDefOptions();
|
||||
ExpectOK("node { name: 'B_3' op: 'TestInput' }");
|
||||
|
||||
// Import with existing de-duped node names
|
||||
opts = ImportGraphDefOptions();
|
||||
opts.uniquify_names = true;
|
||||
@ -1827,24 +1842,24 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) {
|
||||
opts = ImportGraphDefOptions();
|
||||
opts.uniquify_names = true;
|
||||
opts.return_nodes.push_back("A");
|
||||
opts.return_nodes.push_back("A_3");
|
||||
opts.return_nodes.push_back("A_4");
|
||||
opts.return_nodes.push_back("B");
|
||||
opts.return_nodes.push_back("B_3/B");
|
||||
opts.return_nodes.push_back("B_4/B");
|
||||
results = ImportGraphDefResults();
|
||||
ExpectOK(
|
||||
"node { name: 'A' op: 'TestInput' }"
|
||||
"node { name: 'A_3' op: 'TestInput' }"
|
||||
"node { name: 'A_4' op: 'TestInput' }"
|
||||
"node { name: 'B' op: 'TestOneInputTwoOutputs' input: ['A'] }"
|
||||
"node { name: 'B_3/B' op: 'TestOneInputTwoOutputs' input: ['A_3'] }",
|
||||
"node { name: 'B_4/B' op: 'TestOneInputTwoOutputs' input: ['A_4'] }",
|
||||
opts, &refiner, &results);
|
||||
|
||||
ASSERT_EQ(results.return_nodes.size(), 4);
|
||||
EXPECT_EQ(results.return_nodes[0]->name(), "A_4");
|
||||
EXPECT_EQ(results.return_nodes[1]->name(), "A_3");
|
||||
EXPECT_EQ(results.return_nodes[2]->name(), "B_4");
|
||||
EXPECT_EQ(results.return_nodes[2]->def().input(0), "A_4:0");
|
||||
EXPECT_EQ(results.return_nodes[3]->name(), "B_3/B");
|
||||
EXPECT_EQ(results.return_nodes[3]->def().input(0), "A_3");
|
||||
EXPECT_EQ(results.return_nodes[0]->name(), "A_5");
|
||||
EXPECT_EQ(results.return_nodes[1]->name(), "A_4");
|
||||
EXPECT_EQ(results.return_nodes[2]->name(), "B_5");
|
||||
EXPECT_EQ(results.return_nodes[2]->def().input(0), "A_5:0");
|
||||
EXPECT_EQ(results.return_nodes[3]->name(), "B_4/B");
|
||||
EXPECT_EQ(results.return_nodes[3]->def().input(0), "A_4");
|
||||
|
||||
// Create node with prefix and then import node with same name
|
||||
ExpectOK("node { name: 'foo/abc' op: 'ABC' }");
|
||||
@ -1895,8 +1910,8 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) {
|
||||
ExpectOK(graph_def_str, opts, &refiner, &results);
|
||||
|
||||
ASSERT_EQ(results.return_nodes.size(), 2);
|
||||
EXPECT_EQ(results.return_nodes[0]->name(), "A_5");
|
||||
EXPECT_EQ(results.return_nodes[1]->name(), "B_5");
|
||||
EXPECT_EQ(results.return_nodes[0]->name(), "A_6");
|
||||
EXPECT_EQ(results.return_nodes[1]->name(), "B_6");
|
||||
EXPECT_EQ(results.return_nodes[1]->def().input(0), "A:0");
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user