Fix colocation in function inlining
PiperOrigin-RevId: 332305597 Change-Id: I28c18160bc82d86cc771f0ba6e051ab1442cfe8f
This commit is contained in:
parent
5d80f5900e
commit
2826431775
@ -587,6 +587,10 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
|
||||
//
|
||||
// If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
|
||||
// remember 'y' in node_map[x->id()].
|
||||
std::unordered_set<string> fn_nodes;
|
||||
for (Node* n : fbody->graph->op_nodes()) {
|
||||
fn_nodes.insert(n->name());
|
||||
}
|
||||
std::vector<Node*> node_map(fbody->graph->num_node_ids());
|
||||
for (Node* n : fbody->graph->op_nodes()) {
|
||||
NodeDef ndef = n->def();
|
||||
@ -605,6 +609,8 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
|
||||
const string prefix = strings::StrCat(caller->name(), "/");
|
||||
TF_RETURN_IF_ERROR(AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &ndef,
|
||||
options.uniquify_frame_names));
|
||||
TF_RETURN_IF_ERROR(
|
||||
MaybeAddPrefixToColocationConstraints(fn_nodes, prefix, &ndef));
|
||||
|
||||
Status added_node;
|
||||
Node* clone = g->AddNode(ndef, &added_node);
|
||||
|
@ -795,6 +795,8 @@ bool IsValidControlInputName(StringPiece sp) {
|
||||
}
|
||||
}
|
||||
|
||||
const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
|
||||
|
||||
} // namespace
|
||||
|
||||
Status ValidateOpInput(const string& input_name, bool* is_control_input) {
|
||||
@ -924,17 +926,27 @@ Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
|
||||
attr.set_s(frame_name);
|
||||
}
|
||||
|
||||
// Update colocation constraints.
|
||||
constexpr char kClassAttr[] = "_class";
|
||||
auto class_attr = node_def->mutable_attr()->find(kClassAttr);
|
||||
if (class_attr != node_def->mutable_attr()->end()) {
|
||||
AttrValue new_value;
|
||||
new_value.mutable_list()->add_s(
|
||||
strings::StrCat(prefix, class_attr->second.s()));
|
||||
node_def->mutable_attr()->erase(kClassAttr);
|
||||
node_def->mutable_attr()->insert({kClassAttr, new_value});
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MaybeAddPrefixToColocationConstraints(
|
||||
const std::unordered_set<string>& match, StringPiece prefix,
|
||||
NodeDef* node_def) {
|
||||
auto attr = node_def->mutable_attr()->find(kColocationAttrName);
|
||||
if (attr == node_def->mutable_attr()->end()) {
|
||||
return Status::OK();
|
||||
}
|
||||
auto constraints_list = attr->second.mutable_list();
|
||||
auto constraints_size = constraints_list->s_size();
|
||||
for (size_t i = 0; i < constraints_size; ++i) {
|
||||
StringPiece original(constraints_list->s(i));
|
||||
if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) {
|
||||
if (match.find(string(original)) != match.end()) {
|
||||
(*constraints_list->mutable_s(i)) =
|
||||
strings::StrCat(kColocationGroupPrefix, prefix, original);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
@ -391,6 +392,13 @@ Status AttachDef(const Status& status, const NodeDef& node_def,
|
||||
Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
|
||||
NodeDef* node_def,
|
||||
bool uniquify_frame_name = true);
|
||||
|
||||
// Appends the given prefix to the colocation group name if the name exists
|
||||
// in `to_match`.
|
||||
Status MaybeAddPrefixToColocationConstraints(
|
||||
const std::unordered_set<string>& match, StringPiece prefix,
|
||||
NodeDef* node_def);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
|
||||
|
@ -615,6 +615,39 @@ TEST(AddPrefixAndSuffixToNode, Enter) {
|
||||
EXPECT_EQ("prefix/test_frame/suffix", frame_name);
|
||||
}
|
||||
|
||||
TEST(MaybeAddPrefixToColocationConstraints, Basic) {
|
||||
NodeDef node_def;
|
||||
node_def.set_name("Identity");
|
||||
node_def.set_op("Identity");
|
||||
AddNodeAttr(kColocationAttrName,
|
||||
{strings::StrCat(kColocationGroupPrefix, "Node1"),
|
||||
strings::StrCat(kColocationGroupPrefix, "Node2"),
|
||||
strings::StrCat(kColocationGroupPrefix, "Node3")},
|
||||
&node_def);
|
||||
|
||||
std::unordered_set<string> match;
|
||||
match.insert("Node1");
|
||||
match.insert("Node3");
|
||||
TF_ASSERT_OK(MaybeAddPrefixToColocationConstraints(match, "fn/", &node_def));
|
||||
std::vector<string> coloc_constraints;
|
||||
TF_ASSERT_OK(GetNodeAttr(node_def, kColocationAttrName, &coloc_constraints));
|
||||
EXPECT_EQ(
|
||||
coloc_constraints,
|
||||
std::vector<string>({"loc:@fn/Node1", "loc:@Node2", "loc:@fn/Node3"}));
|
||||
}
|
||||
|
||||
TEST(MaybeAddPrefixToColocationConstraints, NoConstraints) {
|
||||
NodeDef node_def;
|
||||
node_def.set_name("Identity");
|
||||
node_def.set_op("Identity");
|
||||
|
||||
std::unordered_set<string> match;
|
||||
match.insert("Node1");
|
||||
match.insert("Node3");
|
||||
TF_ASSERT_OK(MaybeAddPrefixToColocationConstraints(match, "fn/", &node_def));
|
||||
EXPECT_FALSE(HasNodeAttr(node_def, kColocationAttrName));
|
||||
}
|
||||
|
||||
TEST(FormatNodeForErrorTest, Node) {
|
||||
Graph g(OpRegistry::Global());
|
||||
Node* node;
|
||||
|
Loading…
Reference in New Issue
Block a user