diff --git a/tensorflow/core/common_runtime/inline_function_utils.cc b/tensorflow/core/common_runtime/inline_function_utils.cc index 5a07573a430..362e4f2e0bc 100644 --- a/tensorflow/core/common_runtime/inline_function_utils.cc +++ b/tensorflow/core/common_runtime/inline_function_utils.cc @@ -587,6 +587,10 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, // // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we // remember 'y' in node_map[x->id()]. + std::unordered_set fn_nodes; + for (Node* n : fbody->graph->op_nodes()) { + fn_nodes.insert(n->name()); + } std::vector node_map(fbody->graph->num_node_ids()); for (Node* n : fbody->graph->op_nodes()) { NodeDef ndef = n->def(); @@ -605,6 +609,8 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, const string prefix = strings::StrCat(caller->name(), "/"); TF_RETURN_IF_ERROR(AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &ndef, options.uniquify_frame_names)); + TF_RETURN_IF_ERROR( + MaybeAddPrefixToColocationConstraints(fn_nodes, prefix, &ndef)); Status added_node; Node* clone = g->AddNode(ndef, &added_node); diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index be98c7cedfe..1146b02ed1c 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -795,6 +795,8 @@ bool IsValidControlInputName(StringPiece sp) { } } +const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix); + } // namespace Status ValidateOpInput(const string& input_name, bool* is_control_input) { @@ -924,17 +926,27 @@ Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix, attr.set_s(frame_name); } - // Update colocation constraints. - constexpr char kClassAttr[] = "_class"; - auto class_attr = node_def->mutable_attr()->find(kClassAttr); - if (class_attr != node_def->mutable_attr()->end()) { - AttrValue new_value; - new_value.mutable_list()->add_s( - strings::StrCat(prefix, class_attr->second.s())); - node_def->mutable_attr()->erase(kClassAttr); - node_def->mutable_attr()->insert({kClassAttr, new_value}); - } + return Status::OK(); +} +Status MaybeAddPrefixToColocationConstraints( + const std::unordered_set& match, StringPiece prefix, + NodeDef* node_def) { + auto attr = node_def->mutable_attr()->find(kColocationAttrName); + if (attr == node_def->mutable_attr()->end()) { + return Status::OK(); + } + auto constraints_list = attr->second.mutable_list(); + auto constraints_size = constraints_list->s_size(); + for (size_t i = 0; i < constraints_size; ++i) { + StringPiece original(constraints_list->s(i)); + if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) { + if (match.find(string(original)) != match.end()) { + (*constraints_list->mutable_s(i)) = + strings::StrCat(kColocationGroupPrefix, prefix, original); + } + } + } return Status::OK(); } diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index d1a7c9aebba..d774d1cf414 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_ #include +#include #include #include "tensorflow/core/framework/attr_value_util.h" @@ -391,6 +392,13 @@ Status AttachDef(const Status& status, const NodeDef& node_def, Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix, NodeDef* node_def, bool uniquify_frame_name = true); + +// Appends the given prefix to the colocation group name if the name exists +// in `to_match`. +Status MaybeAddPrefixToColocationConstraints( + const std::unordered_set& match, StringPiece prefix, + NodeDef* node_def); + } // namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_ diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc index 2fc000d4e3c..b79b738353c 100644 --- a/tensorflow/core/framework/node_def_util_test.cc +++ b/tensorflow/core/framework/node_def_util_test.cc @@ -615,6 +615,39 @@ TEST(AddPrefixAndSuffixToNode, Enter) { EXPECT_EQ("prefix/test_frame/suffix", frame_name); } +TEST(MaybeAddPrefixToColocationConstraints, Basic) { + NodeDef node_def; + node_def.set_name("Identity"); + node_def.set_op("Identity"); + AddNodeAttr(kColocationAttrName, + {strings::StrCat(kColocationGroupPrefix, "Node1"), + strings::StrCat(kColocationGroupPrefix, "Node2"), + strings::StrCat(kColocationGroupPrefix, "Node3")}, + &node_def); + + std::unordered_set match; + match.insert("Node1"); + match.insert("Node3"); + TF_ASSERT_OK(MaybeAddPrefixToColocationConstraints(match, "fn/", &node_def)); + std::vector coloc_constraints; + TF_ASSERT_OK(GetNodeAttr(node_def, kColocationAttrName, &coloc_constraints)); + EXPECT_EQ( + coloc_constraints, + std::vector({"loc:@fn/Node1", "loc:@Node2", "loc:@fn/Node3"})); +} + +TEST(MaybeAddPrefixToColocationConstraints, NoConstraints) { + NodeDef node_def; + node_def.set_name("Identity"); + node_def.set_op("Identity"); + + std::unordered_set match; + match.insert("Node1"); + match.insert("Node3"); + TF_ASSERT_OK(MaybeAddPrefixToColocationConstraints(match, "fn/", &node_def)); + EXPECT_FALSE(HasNodeAttr(node_def, kColocationAttrName)); +} + TEST(FormatNodeForErrorTest, Node) { Graph g(OpRegistry::Global()); Node* node;