Fix colocation in function inlining

PiperOrigin-RevId: 332305597
Change-Id: I28c18160bc82d86cc771f0ba6e051ab1442cfe8f
This commit is contained in:
Rachel Lim 2020-09-17 13:37:54 -07:00 committed by TensorFlower Gardener
parent 5d80f5900e
commit 2826431775
4 changed files with 69 additions and 10 deletions

View File

@ -587,6 +587,10 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
//
// If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
// remember 'y' in node_map[x->id()].
std::unordered_set<string> fn_nodes;
for (Node* n : fbody->graph->op_nodes()) {
fn_nodes.insert(n->name());
}
std::vector<Node*> node_map(fbody->graph->num_node_ids());
for (Node* n : fbody->graph->op_nodes()) {
NodeDef ndef = n->def();
@ -605,6 +609,8 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
const string prefix = strings::StrCat(caller->name(), "/");
TF_RETURN_IF_ERROR(AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &ndef,
options.uniquify_frame_names));
TF_RETURN_IF_ERROR(
MaybeAddPrefixToColocationConstraints(fn_nodes, prefix, &ndef));
Status added_node;
Node* clone = g->AddNode(ndef, &added_node);

View File

@ -795,6 +795,8 @@ bool IsValidControlInputName(StringPiece sp) {
}
}
const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
} // namespace
Status ValidateOpInput(const string& input_name, bool* is_control_input) {
@ -924,17 +926,27 @@ Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
attr.set_s(frame_name);
}
// Update colocation constraints.
constexpr char kClassAttr[] = "_class";
auto class_attr = node_def->mutable_attr()->find(kClassAttr);
if (class_attr != node_def->mutable_attr()->end()) {
AttrValue new_value;
new_value.mutable_list()->add_s(
strings::StrCat(prefix, class_attr->second.s()));
node_def->mutable_attr()->erase(kClassAttr);
node_def->mutable_attr()->insert({kClassAttr, new_value});
}
return Status::OK();
}
Status MaybeAddPrefixToColocationConstraints(
const std::unordered_set<string>& match, StringPiece prefix,
NodeDef* node_def) {
auto attr = node_def->mutable_attr()->find(kColocationAttrName);
if (attr == node_def->mutable_attr()->end()) {
return Status::OK();
}
auto constraints_list = attr->second.mutable_list();
auto constraints_size = constraints_list->s_size();
for (size_t i = 0; i < constraints_size; ++i) {
StringPiece original(constraints_list->s(i));
if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) {
if (match.find(string(original)) != match.end()) {
(*constraints_list->mutable_s(i)) =
strings::StrCat(kColocationGroupPrefix, prefix, original);
}
}
}
return Status::OK();
}

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
#include <string>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
@ -391,6 +392,13 @@ Status AttachDef(const Status& status, const NodeDef& node_def,
Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
NodeDef* node_def,
bool uniquify_frame_name = true);
// Appends the given prefix to the colocation group name if the name exists
// in `to_match`.
Status MaybeAddPrefixToColocationConstraints(
const std::unordered_set<string>& match, StringPiece prefix,
NodeDef* node_def);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_

View File

@ -615,6 +615,39 @@ TEST(AddPrefixAndSuffixToNode, Enter) {
EXPECT_EQ("prefix/test_frame/suffix", frame_name);
}
TEST(MaybeAddPrefixToColocationConstraints, Basic) {
NodeDef node_def;
node_def.set_name("Identity");
node_def.set_op("Identity");
AddNodeAttr(kColocationAttrName,
{strings::StrCat(kColocationGroupPrefix, "Node1"),
strings::StrCat(kColocationGroupPrefix, "Node2"),
strings::StrCat(kColocationGroupPrefix, "Node3")},
&node_def);
std::unordered_set<string> match;
match.insert("Node1");
match.insert("Node3");
TF_ASSERT_OK(MaybeAddPrefixToColocationConstraints(match, "fn/", &node_def));
std::vector<string> coloc_constraints;
TF_ASSERT_OK(GetNodeAttr(node_def, kColocationAttrName, &coloc_constraints));
EXPECT_EQ(
coloc_constraints,
std::vector<string>({"loc:@fn/Node1", "loc:@Node2", "loc:@fn/Node3"}));
}
TEST(MaybeAddPrefixToColocationConstraints, NoConstraints) {
NodeDef node_def;
node_def.set_name("Identity");
node_def.set_op("Identity");
std::unordered_set<string> match;
match.insert("Node1");
match.insert("Node3");
TF_ASSERT_OK(MaybeAddPrefixToColocationConstraints(match, "fn/", &node_def));
EXPECT_FALSE(HasNodeAttr(node_def, kColocationAttrName));
}
TEST(FormatNodeForErrorTest, Node) {
Graph g(OpRegistry::Global());
Node* node;