Add a helper function to re-assign colocation in a graph.

PiperOrigin-RevId: 191679495
This commit is contained in:
A. Unique TensorFlower 2018-04-04 18:20:36 -07:00 committed by TensorFlower Gardener
parent 8abde65d3c
commit 2194f66f0a
6 changed files with 373 additions and 1 deletions

View File

@ -508,6 +508,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/utils:colocation",
"//tensorflow/core/grappler/utils:topological_sort",
],
)

View File

@ -1,4 +1,4 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils/colocation.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status.h"
@ -221,6 +222,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (already_optimized) {
TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
ReassignColocation(optimized_graph);
// Make sure that the optimizers preserved the graph version and library.
DCHECK_GE(optimized_graph->library().function_size(),
item.graph.library().function_size());

View File

@ -181,3 +181,28 @@ tf_cc_test(
"//tensorflow/core:testlib",
],
)
# Library target for the colocation helper (colocation.h / colocation.cc).
# Publicly visible so that targets outside this package can depend on it.
cc_library(
name = "colocation",
srcs = ["colocation.cc"],
hdrs = ["colocation.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:utils",
],
)
# Unit test target for the ":colocation" library (colocation_test.cc).
tf_cc_test(
name = "colocation_test",
size = "small",
srcs = ["colocation_test.cc"],
deps = [
":colocation",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)

View File

@ -0,0 +1,122 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/colocation.h"
#include <cstring>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
namespace {
// Returns the root (representative) node name of the colocation group that
// node_name belongs to.
//
// `map` is the parent map of a disjoint-set forest: it maps a node name to the
// name of its parent, and a root maps to itself. If node_name is not yet in the
// map, it is inserted as the root of a new single-node group.
//
// After the root has been found, every node visited on the way is re-pointed
// directly at the root (path compression). This does not change which root any
// node resolves to, but it keeps later lookups short. Without it, many nodes
// colocated with the same node form a linear chain and building the groups
// becomes quadratic in the size of the group.
string GetColocationGroupRoot(std::unordered_map<string, string>* map,
                              const string& node_name) {
  auto it = map->find(node_name);
  if (it == map->end()) {
    // Unknown node: it becomes the root of its own group.
    map->insert({node_name, node_name});
    return node_name;
  }

  // Pass 1: follow the parent links until we reach a node that is its own
  // parent.
  string root = node_name;
  while (it->second != root) {
    auto parent_it = map->find(it->second);
    if (parent_it == map->end()) {
      // A parent that is not a key cannot occur when the map is only modified
      // through the helpers in this file. Stop here instead of inserting a
      // bogus entry.
      break;
    }
    root = it->second;
    it = parent_it;
  }

  // Pass 2: path compression. Point every node on the visited path at the
  // root. Only mapped values are modified, so iterators into the map stay
  // valid.
  string cur = node_name;
  while (cur != root) {
    auto cur_it = map->find(cur);
    const string next = cur_it->second;  // Copy before overwriting.
    cur_it->second = root;
    cur = next;
  }
  return root;
}
// Merges two colocation groups into one.
// `left` and `right` are the root nodes of the two groups. After the call the
// right root is a child of the left root, so both groups share the left root.
// The call is a no-op when either name is missing from the map or when both
// names refer to the same root.
void MergeColocationGroup(std::unordered_map<string, string>* map,
                          const string& left, const string& right) {
  const bool left_known = map->count(left) > 0;
  auto right_entry = map->find(right);
  if (!left_known || right_entry == map->end()) {
    // At least one of the roots is unknown: nothing to merge.
    return;
  }
  if (left == right) {
    // Already the same group.
    return;
  }
  // Re-parent the right root under the left root.
  right_entry->second = left;
}
} // namespace
// Rebuilds the colocation constraints of `graph` in place.
//
// A disjoint-set structure (a hash map from node name to parent node name) is
// built from every "loc:@<name>" entry found in the "_class" attribute of the
// graph nodes: whenever a node colocates with another name, the two groups are
// merged. Once all pairs have been processed, the root of each set is used as
// the representative node of its group:
//   * every non-root member that exists in the graph gets a single
//     "loc:@<root>" colocation entry;
//   * the root itself gets no colocation entry.
// Names that are referenced by a colocation entry but are not nodes of the
// graph take part in the grouping but are otherwise ignored.
//
// Entries of "_class" that do not start with "loc:@" are not colocation
// constraints; they are kept on the node (after the colocation entry, if any).
// The "_class" attribute is removed only when no entry remains.
void ReassignColocation(GraphDef* graph) {
  constexpr char kClassAttr[] = "_class";
  constexpr char kColocPrefix[] = "loc:@";
  // Length of the prefix without the trailing NUL.
  constexpr size_t kColocPrefixLen = sizeof(kColocPrefix) - 1;

  // A hashmap that maps from a node name to its parent node name.
  std::unordered_map<string, string> coloc_groups;
  NodeMap node_map(graph);

  // Phase 1: build the colocation groups.
  for (const auto& node : graph->node()) {
    const auto iter = node.attr().find(kClassAttr);
    if (iter == node.attr().end() || !iter->second.has_list()) {
      continue;
    }
    for (const auto& str : iter->second.list().s()) {
      if (str.compare(0, kColocPrefixLen, kColocPrefix) != 0) {
        continue;  // Not a colocation entry.
      }
      const string colocate_node = str.substr(kColocPrefixLen);
      // Both lookups may insert into the map, so sequence them explicitly
      // instead of relying on the unspecified argument evaluation order.
      const string node_root =
          GetColocationGroupRoot(&coloc_groups, node.name());
      const string colocate_root =
          GetColocationGroupRoot(&coloc_groups, colocate_node);
      // The current node's root is always the left (surviving) root, which is
      // how roots end up being nodes that exist in the graph.
      MergeColocationGroup(&coloc_groups, node_root, colocate_root);
    }
  }

  // Phase 2: rewrite the "_class" attribute of every group member that is
  // present in the graph.
  for (const auto& pair : coloc_groups) {
    NodeDef* node = node_map.GetNode(pair.first);
    if (node == nullptr) {
      // The name was only referenced by a colocation entry; there is no node
      // to update.
      continue;
    }

    AttrValue new_value;
    if (pair.first != pair.second) {
      // Child node: colocate it with the root of its group.
      new_value.mutable_list()->add_s(
          kColocPrefix + GetColocationGroupRoot(&coloc_groups, pair.first));
    }
    // Carry over the entries that are not colocation constraints. This must
    // happen before the old attribute is erased.
    const auto old_attr = node->attr().find(kClassAttr);
    if (old_attr != node->attr().end() && old_attr->second.has_list()) {
      for (const auto& str : old_attr->second.list().s()) {
        if (str.compare(0, kColocPrefixLen, kColocPrefix) != 0) {
          new_value.mutable_list()->add_s(str);
        }
      }
    }

    node->mutable_attr()->erase(kClassAttr);
    if (new_value.list().s_size() > 0) {
      node->mutable_attr()->insert({kClassAttr, new_value});
    }
  }
}
} // namespace grappler
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_COLOCATION_H_
#define TENSORFLOW_CORE_GRAPPLER_UTILS_COLOCATION_H_
#include <unordered_map>
#include "tensorflow/core/framework/graph.pb.h"
namespace tensorflow {
namespace grappler {
// Evaluates the colocation relations in the graph and rewrites them in place.
// The graph nodes are scanned sequentially to build disjoint sets of nodes
// (within each set, all nodes are colocated with each other). The root node of
// each set is then selected as its representative node, and every other node of
// the set that exists in the graph is colocated with the representative node.
//
// Note that there is currently one situation this function cannot handle:
// node A colocates with X, node B colocates with Y, and X colocates with Y, but
// X and Y have been removed from the graph. In this case the relation between
// X and Y is no longer recorded anywhere, so we cannot know that A colocates
// with B.
void ReassignColocation(GraphDef* graph);
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_COLOCATION_H_

View File

@ -0,0 +1,183 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/colocation.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
// Empty test fixture type for the colocation tests.
// NOTE(review): the tests in this file are declared with TEST rather than
// TEST_F, so they do not appear to use this fixture - confirm whether it is
// still needed.
class ColocationTest : public ::testing::Test {};
// Returns true iff `ndef` has a "_class" attribute whose first list entry is
// equal to `coloc` (e.g. "loc:@A").
// Returns false when the attribute is missing or when its list is empty, so
// that a malformed node makes the expectation fail instead of reading past the
// end of the list.
bool VerifyNodeHasColocation(const NodeDef& ndef, const string& coloc) {
  const auto it = ndef.attr().find("_class");
  if (it == ndef.attr().end()) {
    return false;
  }
  const auto& class_list = it->second.list();
  return class_list.s_size() > 0 && class_list.s(0) == coloc;
}
TEST(ColocationTest, ReassignColocation_SingleNode) {
  // The graph holds a single node A that colocates with B; B itself is not
  // part of the graph.
  //   A
  //   |
  //   |
  //  [B]
  NodeDef node_a;
  TF_EXPECT_OK(
      NodeDefBuilder("A", "Const").Attr("_class", {"loc:@B"}).Finalize(&node_a));

  GraphDef graph = test::function::GDef({node_a});
  // Before: one node carrying exactly one attribute (the colocation).
  EXPECT_EQ(1, graph.node_size());
  EXPECT_EQ(1, graph.node(0).attr_size());

  ReassignColocation(&graph);

  // After: the node is still there, but its colocation info has been cleared.
  EXPECT_EQ(1, graph.node_size());
  EXPECT_EQ(0, graph.node(0).attr_size());
}
TEST(ColocationTest, ReassignColocation_MultiNode_SingleGroup) {
  // Input: A, B and C colocate with X, D colocates with C and E colocates
  // with D. X is not a node of the graph.
  //   A   B   C---D---E
  //   |   |   |
  //   |   |   |
  //   +--[X]--+
  // Expected result: A, B, C and D all colocate with E.
  //   A   B   C   D
  //   |   |   |   |
  //   |   |   |   |
  //   +---+-E-+---+
  NodeDef a, b, c, d, e;
  TF_EXPECT_OK(
      NodeDefBuilder("A", "Const").Attr("_class", {"loc:@X"}).Finalize(&a));
  TF_EXPECT_OK(
      NodeDefBuilder("B", "Const").Attr("_class", {"loc:@X"}).Finalize(&b));
  TF_EXPECT_OK(
      NodeDefBuilder("C", "Const").Attr("_class", {"loc:@X"}).Finalize(&c));
  TF_EXPECT_OK(
      NodeDefBuilder("D", "Const").Attr("_class", {"loc:@C"}).Finalize(&d));
  TF_EXPECT_OK(
      NodeDefBuilder("E", "Const").Attr("_class", {"loc:@D"}).Finalize(&e));

  GraphDef graph = test::function::GDef({a, b, c, d, e});
  EXPECT_EQ(5, graph.node_size());

  // Colocation of nodes A..E (in graph order) before the rewrite.
  const char* const before[] = {"loc:@X", "loc:@X", "loc:@X", "loc:@C",
                                "loc:@D"};
  for (int i = 0; i < 5; ++i) {
    EXPECT_TRUE(VerifyNodeHasColocation(graph.node(i), before[i]))
        << graph.node(i).name();
  }

  ReassignColocation(&graph);

  EXPECT_EQ(5, graph.node_size());
  // A, B, C and D now point at the representative node E.
  for (int i = 0; i < 4; ++i) {
    EXPECT_TRUE(VerifyNodeHasColocation(graph.node(i), "loc:@E"))
        << graph.node(i).name();
  }
  // E is the representative node and carries no attribute any more.
  EXPECT_EQ(0, graph.node(4).attr_size());
}
TEST(ColocationTest, ReassignColocation_MultiNode_MultiGroup) {
  // Input:
  // A, B and C colocate with X, D colocates with C and E colocates with D.
  // U and V colocate with W. Neither X nor W is a node of the graph:
  //   A   B   C---D---E
  //   |   |   |
  //   |   |   |
  //   +--[X]--+
  //
  //   U       V
  //   |       |
  //   |       |
  //   +--[W]--+
  //
  // Expected result:
  // A, B, C and D colocate with E, and U colocates with V.
  //   A   B   C   D
  //   |   |   |   |
  //   |   |   |   |
  //   +---+-E-+---+
  //
  //   U
  //   |
  //   |
  //   V
  NodeDef a, b, c, d, e, u, v;
  TF_EXPECT_OK(
      NodeDefBuilder("A", "Const").Attr("_class", {"loc:@X"}).Finalize(&a));
  TF_EXPECT_OK(
      NodeDefBuilder("B", "Const").Attr("_class", {"loc:@X"}).Finalize(&b));
  TF_EXPECT_OK(
      NodeDefBuilder("C", "Const").Attr("_class", {"loc:@X"}).Finalize(&c));
  TF_EXPECT_OK(
      NodeDefBuilder("D", "Const").Attr("_class", {"loc:@C"}).Finalize(&d));
  TF_EXPECT_OK(
      NodeDefBuilder("E", "Const").Attr("_class", {"loc:@D"}).Finalize(&e));
  TF_EXPECT_OK(
      NodeDefBuilder("U", "Const").Attr("_class", {"loc:@W"}).Finalize(&u));
  TF_EXPECT_OK(
      NodeDefBuilder("V", "Const").Attr("_class", {"loc:@W"}).Finalize(&v));

  GraphDef graph = test::function::GDef({a, b, c, d, e, u, v});
  EXPECT_EQ(7, graph.node_size());

  // Colocation of nodes A, B, C, D, E, U, V (in graph order) before the
  // rewrite.
  const char* const before[] = {"loc:@X", "loc:@X", "loc:@X", "loc:@C",
                                "loc:@D", "loc:@W", "loc:@W"};
  for (int i = 0; i < 7; ++i) {
    EXPECT_TRUE(VerifyNodeHasColocation(graph.node(i), before[i]))
        << graph.node(i).name();
  }

  ReassignColocation(&graph);

  EXPECT_EQ(7, graph.node_size());
  // First group: A, B, C and D point at the representative node E.
  for (int i = 0; i < 4; ++i) {
    EXPECT_TRUE(VerifyNodeHasColocation(graph.node(i), "loc:@E"))
        << graph.node(i).name();
  }
  EXPECT_EQ(0, graph.node(4).attr_size());  // E
  // Second group: U points at the representative node V.
  EXPECT_TRUE(VerifyNodeHasColocation(graph.node(5), "loc:@V"));  // U
  EXPECT_EQ(0, graph.node(6).attr_size());                        // V
}
} // namespace grappler
} // namespace tensorflow