Add a helper function to re-assign colocation in a graph.
PiperOrigin-RevId: 191679495
This commit is contained in:
parent
8abde65d3c
commit
2194f66f0a
@ -508,6 +508,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler/utils:colocation",
|
||||
"//tensorflow/core/grappler/utils:topological_sort",
|
||||
],
|
||||
)
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
||||
#include "tensorflow/core/grappler/utils/colocation.h"
|
||||
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
@ -221,6 +222,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
|
||||
if (already_optimized) {
|
||||
TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
|
||||
ReassignColocation(optimized_graph);
|
||||
// Make sure that the optimizers preserved the graph version and library.
|
||||
DCHECK_GE(optimized_graph->library().function_size(),
|
||||
item.graph.library().function_size());
|
||||
|
@ -181,3 +181,28 @@ tf_cc_test(
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "colocation",
|
||||
srcs = ["colocation.cc"],
|
||||
hdrs = ["colocation.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "colocation_test",
|
||||
size = "small",
|
||||
srcs = ["colocation_test.cc"],
|
||||
deps = [
|
||||
":colocation",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
122
tensorflow/core/grappler/utils/colocation.cc
Normal file
122
tensorflow/core/grappler/utils/colocation.cc
Normal file
@ -0,0 +1,122 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/utils/colocation.h"
|
||||
|
||||
#include <cstring>
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
namespace {
|
||||
|
||||
// Find root node of the colocation group.
|
||||
// The map is mapping from one node name to its parent. node_name is the
|
||||
// starting node to search. By iteratively following the path from child to
|
||||
// parent, we can find the root node for the colocation group that node_name
|
||||
// belongs to.
|
||||
string GetColocationGroupRoot(std::unordered_map<string, string>* map,
|
||||
const string& node_name) {
|
||||
if (map->find(node_name) == map->end()) {
|
||||
// If node_name is not in the map, we create a new root node which points
|
||||
// to itself.
|
||||
map->insert({node_name, node_name});
|
||||
return node_name;
|
||||
}
|
||||
string cur = node_name;
|
||||
while ((*map)[cur] != cur) {
|
||||
// Backtracing the map until we reach the root node.
|
||||
cur = (*map)[cur];
|
||||
}
|
||||
return cur;
|
||||
}
|
||||
|
||||
// Merge two colocation groups into one.
|
||||
// left and right is the root node of two colocation groups respectively.
|
||||
void MergeColocationGroup(std::unordered_map<string, string>* map,
|
||||
const string& left, const string& right) {
|
||||
// Do nothing if left or right node is not in the map.
|
||||
if (map->find(left) == map->end() || map->find(right) == map->end()) {
|
||||
return;
|
||||
}
|
||||
if (left != right) {
|
||||
// Make the right node a child of the left node, which merges the two
|
||||
// groups.
|
||||
map->at(right) = left;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Use of disjoint set algorithm to build the colocation groups from the input
|
||||
// graph. The core data structure in use is a hash map from one node to its
|
||||
// parent node. Whenever we see two nodes colocate with each other, we merge
|
||||
// their colocation groups together. After we traverse all colocation pairs
|
||||
// in the graph, we will have several disjoint sets. Then we pick the root node
|
||||
// of each disjoint set as the representative node, and let all other nodes in
|
||||
// the group colocate with the representative node.
|
||||
void ReassignColocation(GraphDef* graph) {
|
||||
constexpr char kClassAttr[] = "_class";
|
||||
constexpr char kColocPrefix[] = "loc:@";
|
||||
|
||||
// A hashmap that maps from a node name to its parent node name.
|
||||
std::unordered_map<string, string> coloc_groups;
|
||||
NodeMap node_map(graph);
|
||||
for (const auto& node : graph->node()) {
|
||||
auto iter = node.attr().find(kClassAttr);
|
||||
if (iter != node.attr().end() && iter->second.has_list()) {
|
||||
for (const auto& str : iter->second.list().s()) {
|
||||
size_t pos = str.find(kColocPrefix);
|
||||
if (pos == 0) {
|
||||
// After we find a colocation, update the colocation groups.
|
||||
string colocate_node = str.substr(pos + strlen(kColocPrefix));
|
||||
MergeColocationGroup(
|
||||
&coloc_groups, GetColocationGroupRoot(&coloc_groups, node.name()),
|
||||
GetColocationGroupRoot(&coloc_groups, colocate_node));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We use the root node of each colocation groups as its representative
|
||||
// node. For each node in one group, colocate with the representative node
|
||||
// if the node is in the graph.
|
||||
for (const auto& pair : coloc_groups) {
|
||||
if (pair.first != pair.second) {
|
||||
// This is a child node.
|
||||
NodeDef* node = node_map.GetNode(pair.first);
|
||||
if (node) {
|
||||
// Colocate this node with the root node.
|
||||
AttrValue new_value;
|
||||
new_value.mutable_list()->add_s(
|
||||
kColocPrefix + GetColocationGroupRoot(&coloc_groups, pair.first));
|
||||
node->mutable_attr()->erase(kClassAttr);
|
||||
node->mutable_attr()->insert({kClassAttr, new_value});
|
||||
}
|
||||
} else {
|
||||
// This is a root node. Clear the _class attribute.
|
||||
NodeDef* node = node_map.GetNode(pair.first);
|
||||
if (node) { // root node should always exist in the graph as guaranteed
|
||||
// by order of merging. Just put check here to ensure safety.
|
||||
node->mutable_attr()->erase(kClassAttr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
39
tensorflow/core/grappler/utils/colocation.h
Normal file
39
tensorflow/core/grappler/utils/colocation.h
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_COLOCATION_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_UTILS_COLOCATION_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Evaluates the colocation relation in the graph and rewrites the new
|
||||
// colocation relation in the graph. We scan the graph nodes sequentially, and
|
||||
// builds a disjoint-sets of nodes (within each disjoint-set the nodes are
|
||||
// colocated with each other). We then select the root node of each set as a
|
||||
// representative node, and then colocate each node within the set (should also
|
||||
// exist in graph) with the representative node.
|
||||
// Note that there is current one situation this function can't handle:
|
||||
// Node A colocates with X, node B colocates with Y, X colocates with Y but
|
||||
// X, Y are removed from graph. In this case we can't know A colocates with B.
|
||||
void ReassignColocation(GraphDef* graph);
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_COLOCATION_H_
|
183
tensorflow/core/grappler/utils/colocation_test.cc
Normal file
183
tensorflow/core/grappler/utils/colocation_test.cc
Normal file
@ -0,0 +1,183 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/utils/colocation.h"
|
||||
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
class ColocationTest : public ::testing::Test {};
|
||||
|
||||
bool VerifyNodeHasColocation(const NodeDef& ndef, const string& coloc) {
|
||||
if (ndef.attr().empty()) {
|
||||
return false;
|
||||
}
|
||||
if (ndef.attr().find("_class") == ndef.attr().end()) {
|
||||
return false;
|
||||
}
|
||||
return ndef.attr().at("_class").list().s(0) == coloc;
|
||||
}
|
||||
|
||||
TEST(ColocationTest, ReassignColocation_SingleNode) {
|
||||
// Node A colocates with B, but node B is not in the graph.
|
||||
// A
|
||||
// |
|
||||
// |
|
||||
// [B]
|
||||
|
||||
NodeDef ndef;
|
||||
const Status status =
|
||||
NodeDefBuilder("A", "Const").Attr("_class", {"loc:@B"}).Finalize(&ndef);
|
||||
TF_EXPECT_OK(status);
|
||||
GraphDef gdef = test::function::GDef({ndef});
|
||||
|
||||
EXPECT_EQ(1, gdef.node_size());
|
||||
EXPECT_EQ(1, gdef.node(0).attr_size());
|
||||
|
||||
ReassignColocation(&gdef);
|
||||
|
||||
// Validates that node A's colocation info is cleared.
|
||||
EXPECT_EQ(1, gdef.node_size());
|
||||
EXPECT_EQ(0, gdef.node(0).attr_size());
|
||||
}
|
||||
|
||||
TEST(ColocationTest, ReassignColocation_MultiNode_SingleGroup) {
|
||||
// Node A, B, C colocate with X. D colocates with C. E colocates with D.
|
||||
// Node X is not in the graph.
|
||||
// A B C---D---E
|
||||
// | | |
|
||||
// | | |
|
||||
// +--[X]--+
|
||||
// After re-assign of colocation, A, B, C, D should colocate with E.
|
||||
// A B C D
|
||||
// | | | |
|
||||
// | | | |
|
||||
// +---+-E-+---+
|
||||
|
||||
NodeDef ndef_a, ndef_b, ndef_c, ndef_d, ndef_e;
|
||||
Status status =
|
||||
NodeDefBuilder("A", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_a);
|
||||
TF_EXPECT_OK(status);
|
||||
status =
|
||||
NodeDefBuilder("B", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_b);
|
||||
TF_EXPECT_OK(status);
|
||||
status =
|
||||
NodeDefBuilder("C", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_c);
|
||||
TF_EXPECT_OK(status);
|
||||
status =
|
||||
NodeDefBuilder("D", "Const").Attr("_class", {"loc:@C"}).Finalize(&ndef_d);
|
||||
TF_EXPECT_OK(status);
|
||||
status =
|
||||
NodeDefBuilder("E", "Const").Attr("_class", {"loc:@D"}).Finalize(&ndef_e);
|
||||
TF_EXPECT_OK(status);
|
||||
GraphDef gdef =
|
||||
test::function::GDef({ndef_a, ndef_b, ndef_c, ndef_d, ndef_e});
|
||||
|
||||
EXPECT_EQ(5, gdef.node_size());
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@X")); // A
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@X")); // B
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@X")); // C
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@C")); // D
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(4), "loc:@D")); // E
|
||||
|
||||
ReassignColocation(&gdef);
|
||||
|
||||
EXPECT_EQ(5, gdef.node_size());
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@E")); // A
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@E")); // B
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@E")); // C
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@E")); // D
|
||||
EXPECT_EQ(0, gdef.node(4).attr_size()); // E
|
||||
}
|
||||
|
||||
TEST(ColocationTest, ReassignColocation_MultiNode_MultiGroup) {
|
||||
// Before re-assign:
|
||||
// Node A, B, C colocate with X. D colocates with C. E colocates with D.
|
||||
// Node U, V colocates with W. Node X, W are not in the graph:
|
||||
// A B C---D---E
|
||||
// | | |
|
||||
// | | |
|
||||
// +--[X]--+
|
||||
//
|
||||
// U V
|
||||
// | |
|
||||
// | |
|
||||
// +--[W]--+
|
||||
//
|
||||
// After re-assign:
|
||||
// A, B, C, D should colocate with E. U should colocate with V.
|
||||
// A B C D
|
||||
// | | | |
|
||||
// | | | |
|
||||
// +---+-E-+---+
|
||||
//
|
||||
// U
|
||||
// |
|
||||
// |
|
||||
// V
|
||||
|
||||
NodeDef ndef_a, ndef_b, ndef_c, ndef_d, ndef_e, ndef_u, ndef_v;
|
||||
Status status =
|
||||
NodeDefBuilder("A", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_a);
|
||||
TF_EXPECT_OK(status);
|
||||
status =
|
||||
NodeDefBuilder("B", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_b);
|
||||
TF_EXPECT_OK(status);
|
||||
status =
|
||||
NodeDefBuilder("C", "Const").Attr("_class", {"loc:@X"}).Finalize(&ndef_c);
|
||||
TF_EXPECT_OK(status);
|
||||
status =
|
||||
NodeDefBuilder("D", "Const").Attr("_class", {"loc:@C"}).Finalize(&ndef_d);
|
||||
TF_EXPECT_OK(status);
|
||||
status =
|
||||
NodeDefBuilder("E", "Const").Attr("_class", {"loc:@D"}).Finalize(&ndef_e);
|
||||
TF_EXPECT_OK(status);
|
||||
status =
|
||||
NodeDefBuilder("U", "Const").Attr("_class", {"loc:@W"}).Finalize(&ndef_u);
|
||||
TF_EXPECT_OK(status);
|
||||
status =
|
||||
NodeDefBuilder("V", "Const").Attr("_class", {"loc:@W"}).Finalize(&ndef_v);
|
||||
TF_EXPECT_OK(status);
|
||||
GraphDef gdef = test::function::GDef(
|
||||
{ndef_a, ndef_b, ndef_c, ndef_d, ndef_e, ndef_u, ndef_v});
|
||||
|
||||
EXPECT_EQ(7, gdef.node_size());
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@X")); // A
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@X")); // B
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@X")); // C
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@C")); // D
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(4), "loc:@D")); // E
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(5), "loc:@W")); // U
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(6), "loc:@W")); // V
|
||||
|
||||
ReassignColocation(&gdef);
|
||||
|
||||
EXPECT_EQ(7, gdef.node_size());
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(0), "loc:@E")); // A
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(1), "loc:@E")); // B
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(2), "loc:@E")); // C
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(3), "loc:@E")); // D
|
||||
EXPECT_EQ(0, gdef.node(4).attr_size()); // E
|
||||
EXPECT_TRUE(VerifyNodeHasColocation(gdef.node(5), "loc:@V")); // U
|
||||
EXPECT_EQ(0, gdef.node(6).attr_size()); // V
|
||||
}
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
Loading…
x
Reference in New Issue
Block a user