Add GraphTopologyView to efficiently traverse node-to-node connections.

+ Remove SimpleGraphView from topological sorting.

PiperOrigin-RevId: 226932668
This commit is contained in:
Eugene Zhulenev 2018-12-26 10:43:22 -08:00 committed by TensorFlower Gardener
parent 87af907e8c
commit 9585116b80
9 changed files with 662 additions and 116 deletions

View File

@ -62,6 +62,36 @@ tf_cuda_library(
],
)
# Library target for GraphTopologyView (graph_topology_view.cc / .h).
# Publicly visible so that targets in other packages can depend on it.
cc_library(
name = "graph_topology_view",
srcs = ["graph_topology_view.cc"],
hdrs = ["graph_topology_view.h"],
visibility = ["//visibility:public"],
deps = [
":graph_view",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
# Unit tests for the graph_topology_view library (graph_topology_view_test.cc).
tf_cc_test(
name = "graph_topology_view_test",
srcs = ["graph_topology_view_test.cc"],
deps = [
":graph_topology_view",
":graph_view",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "graph_view",
srcs = ["graph_view.cc"],

View File

@ -425,9 +425,11 @@ NodeDef MakeConstNodeDefFromShape(InferenceContext* ic,
// information is refined.
class TopoQueue {
public:
explicit TopoQueue(const std::unordered_map<const NodeDef*, int>& topo_order)
: topo_order_(topo_order) {}
explicit TopoQueue(const std::vector<const NodeDef*>& topo_order)
: topo_order_(TopoOrder(topo_order)) {}
void push(const NodeDef* n) { queue_.emplace(n, topo_order_.at(n)); }
const NodeDef* pop() {
CHECK(!empty());
auto it = queue_.begin();
@ -448,7 +450,18 @@ class TopoQueue {
return lhs.second < rhs.second;
}
};
const std::unordered_map<const NodeDef*, int>& topo_order_;
// Converts a topologically sorted list of nodes into a lookup table that maps
// each node pointer to its position in `topo_order`. If a node pointer occurs
// more than once, the position of its first occurrence is kept.
const std::unordered_map<const NodeDef*, int> TopoOrder(
    const std::vector<const NodeDef*>& topo_order) const {
  std::unordered_map<const NodeDef*, int> node_to_position;
  node_to_position.reserve(topo_order.size());
  int position = 0;
  for (const NodeDef* node : topo_order) {
    node_to_position.emplace(node, position);
    ++position;
  }
  return node_to_position;
}
const std::unordered_map<const NodeDef*, int> topo_order_;
std::set<NodeAndId, OrderByIdAscending> queue_;
};
@ -1970,7 +1983,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
}
std::unordered_map<const NodeDef*, const NodeDef*> resource_handles;
std::vector<std::pair<const NodeDef*, const NodeDef*>> extra_deps;
std::vector<TopologicalDependency> extra_deps;
for (const auto& resource : resources) {
for (const NodeDef* src : resource.second.first) {
resource_handles[src] = resource.first;
@ -1982,8 +1995,8 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
}
}
std::unordered_map<const NodeDef*, int> topo_order;
Status s = ComputeTopologicalOrder(item_.graph, &topo_order, &extra_deps);
std::vector<const NodeDef*> topo_order;
Status s = ComputeTopologicalOrder(item_.graph, extra_deps, &topo_order);
if (!s.ok()) {
if (extra_deps.empty()) {
return s;
@ -1992,8 +2005,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
// order. This will make the shape inference less precise but since this
// isn't common it's not worth to figure out where to break the loop and
// do a proper relaxation.
TF_RETURN_IF_ERROR(
ComputeTopologicalOrder(item_.graph, &topo_order, nullptr));
TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item_.graph, &topo_order));
}
}

View File

@ -0,0 +1,163 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_topology_view.h"
#include <algorithm>
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
namespace tensorflow {
namespace grappler {
namespace {
// Sorts the container pointed to by `v` in ascending order and erases all
// repeated elements, leaving exactly one copy of every distinct value.
template <typename T>
inline void SortAndRemoveDuplicates(T* v) {
  auto& items = *v;
  std::sort(items.begin(), items.end());
  const auto unique_end = std::unique(items.begin(), items.end());
  items.erase(unique_end, items.end());
}
} // namespace
// Builds the node-to-node adjacency lists (fanins and fanouts) for `graph`.
//
// `ephemeral_edges` are additional edges that are not present in `graph` but
// must be reflected in the topology. Both end points of every ephemeral edge
// are looked up by node name in `graph`.
//
// Returns InvalidArgument if:
//   * the view has already been initialized;
//   * an ephemeral edge references a node name that is not in `graph`;
//   * a node input references a node name that is not in `graph`.
//
// On failure the view is rolled back to the uninitialized state, so that
// `is_initialized()` returns false and initialization can be attempted again.
Status GraphTopologyView::InitializeFromGraph(
    const GraphDef& graph,
    const absl::Span<const GraphView::Edge> ephemeral_edges) {
  if (graph_ != nullptr) {
    return errors::InvalidArgument("GraphTopologyView is already initialized.");
  }

  // Drops everything that was built so far. Called before every error return
  // to avoid exposing a half-initialized view.
  const auto reset = [this]() {
    graph_ = nullptr;
    num_nodes_ = 0;
    index_to_node_name_.clear();
    node_name_to_index_.clear();
    fanins_.clear();
    fanouts_.clear();
  };

  graph_ = &graph;
  num_nodes_ = graph.node_size();
  // Names are appended with emplace_back below, so only reserve capacity here
  // (resizing would leave `num_nodes_` empty names in front of the real ones).
  index_to_node_name_.reserve(num_nodes_);
  node_name_to_index_.rehash(num_nodes_);
  fanins_.resize(num_nodes_);
  fanouts_.resize(num_nodes_);

  // Build map from name to index and vice versa.
  for (int node_idx = 0; node_idx < num_nodes_; ++node_idx) {
    const NodeDef& node = graph.node(node_idx);
    node_name_to_index_.emplace(node.name(), node_idx);
    index_to_node_name_.emplace_back(node.name());
  }

  // 1. Add ephemeral edges to the adjacency lists.
  for (const GraphView::Edge& edge : ephemeral_edges) {
    const auto src = node_name_to_index_.find(edge.src.node->name());
    if (src == node_name_to_index_.end()) {
      reset();
      return errors::InvalidArgument("Non-existent src node: ",
                                     edge.src.node->name());
    }
    const auto dst = node_name_to_index_.find(edge.dst.node->name());
    if (dst == node_name_to_index_.end()) {
      reset();
      return errors::InvalidArgument("Non-existent dst node: ",
                                     edge.dst.node->name());
    }
    const int src_idx = src->second;
    const int dst_idx = dst->second;
    fanins_[dst_idx].push_back(src_idx);
    fanouts_[src_idx].push_back(dst_idx);
  }

  // 2. Add graph edges to the adjacency lists.
  for (int node_idx = 0; node_idx < num_nodes_; ++node_idx) {
    const NodeDef& node = graph.node(node_idx);
    fanins_[node_idx].reserve(node.input_size());

    for (const string& input : node.input()) {
      TensorId tensor = ParseTensorName(input);
      const auto it = node_name_to_index_.find(tensor.node());
      if (it == node_name_to_index_.end()) {
        reset();
        return errors::InvalidArgument("Non-existent input ", input,
                                       " for node ", node.name());
      }
      const int input_idx = it->second;
      fanins_[node_idx].push_back(input_idx);
      fanouts_[input_idx].push_back(node_idx);
    }

    // Dedup the input list while it's still hot in cache.
    SortAndRemoveDuplicates(&fanins_[node_idx]);
  }

  // Dedup outputs for all the graph nodes.
  for (int node_idx = 0; node_idx < num_nodes_; ++node_idx) {
    SortAndRemoveDuplicates(&fanouts_[node_idx]);
  }

  return Status::OK();
}
// Initializes the view from `graph` alone, i.e. without any ephemeral edges.
Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph) {
  const absl::Span<GraphView::Edge> no_ephemeral_edges;
  return InitializeFromGraph(graph, no_ephemeral_edges);
}
// Returns true if a node called `node_name` was registered when the view was
// initialized.
bool GraphTopologyView::HasNode(const absl::string_view node_name) const {
  DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
  return node_name_to_index_.find(node_name) != node_name_to_index_.end();
}
// Looks a node up by name. Returns a pointer into the underlying graph, or
// `nullptr` when no node with that name is known to the view.
const NodeDef* GraphTopologyView::GetNode(
    const absl::string_view node_name) const {
  DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
  const auto found = node_name_to_index_.find(node_name);
  if (found == node_name_to_index_.end()) {
    return nullptr;
  }
  return &graph_->node(found->second);
}
// Returns the node stored at position `node_idx` of the underlying graph.
// The index is only validated by a DCHECK: it must be in [0, num_nodes_).
const NodeDef* GraphTopologyView::GetNode(int node_idx) const {
  DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
  DCHECK(node_idx >= 0 && node_idx < num_nodes_) << "node_idx is out of range";
  const NodeDef& node = graph_->node(node_idx);
  return &node;
}
// Maps a node name to its index in the underlying graph. Returns an empty
// optional when the name is unknown.
//
// NOTE(review): the DCHECK below fires for an unknown name wherever DCHECKs
// are enabled, so the empty optional is only observable when they are
// compiled out -- confirm which of the two behaviors is intended.
const absl::optional<int> GraphTopologyView::GetNodeIndex(
    const absl::string_view node_name) const {
  DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
  const auto it = node_name_to_index_.find(node_name);
  DCHECK(it != node_name_to_index_.end()) << "Node doesn't exist in a graph";
  if (it == node_name_to_index_.end()) {
    return absl::nullopt;
  }
  return absl::make_optional(it->second);
}
// Convenience overload: resolves the index of `node` through its name.
const absl::optional<int> GraphTopologyView::GetNodeIndex(
    const NodeDef& node) const {
  const absl::string_view node_name = node.name();
  return GetNodeIndex(node_name);
}
// Returns the indices of the nodes feeding into node `node_idx`. An index
// outside of [0, num_nodes_) trips a DCHECK and otherwise yields a reference
// to a shared empty vector.
const absl::InlinedVector<int, 4>& GraphTopologyView::GetFanin(
    int node_idx) const {
  DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
  const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
  DCHECK(is_valid_node_idx) << "node_idx is out of range";
  if (!is_valid_node_idx) {
    return empty_fanin_;
  }
  return fanins_[node_idx];
}
// Returns the indices of the nodes consuming node `node_idx`. An index
// outside of [0, num_nodes_) trips a DCHECK and otherwise yields a reference
// to a shared empty vector.
const absl::InlinedVector<int, 2>& GraphTopologyView::GetFanout(
    int node_idx) const {
  DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
  const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
  DCHECK(is_valid_node_idx) << "node_idx is out of range";
  if (!is_valid_node_idx) {
    return empty_fanout_;
  }
  return fanouts_[node_idx];
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -0,0 +1,105 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_
#define TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/graph_view.h"
namespace tensorflow {
namespace grappler {
// GraphTopologyView is a helper class to simplify `node-to-node` connectivity
// traversals. Regular `GraphView` simplifies `tensor-to-tensor` traversals:
// connections between output tensors and inputs of consumer nodes. For the
// topology view we are focused on nodes connected to nodes, and it's irrelevant
// if this connection is formed by one or multiple individual tensors.
//
// Example:
// a = Placeholder(..)
// b = Placeholder(..)
// c = AddN([a, a, b])
//
// GraphView edges: [a:0 -> c:0, a:0 -> c:1, b:0 -> c:2]
// GraphTopologyView edges: [a -> c, b -> c]
//
// GraphView is used for exploring single node fanins and fanouts, and
// GraphTopologyView is focused on efficient full graph traversals (computing
// graph node properties from transitive fanouts, etc...).
// Read-only, index-based view of the node-to-node connectivity of a GraphDef.
// Nodes are identified by their position in the graph's node list.
class GraphTopologyView {
public:
GraphTopologyView() = default;
// Initialize graph topology view from the graph. It's possible to pass
// additional edges that do not exist in a graph, but must be respected when
// computing graph topology. Example: Tensorflow runtime allows concurrent
// execution of dequeue/enqueue ops from the same queue resource, but we might
// want to enforce ordering between them for the purpose of graph analysis.
//
// Returns an InvalidArgument error if the view is already initialized, or if
// an ephemeral edge or a node input refers to a node name that does not exist
// in the graph.
Status InitializeFromGraph(const GraphDef& graph,
absl::Span<const GraphView::Edge> ephemeral_edges);
// Same as above, without any ephemeral edges.
Status InitializeFromGraph(const GraphDef& graph);
// True once a graph has been attached to the view.
bool is_initialized() const { return graph_ != nullptr; }
// Number of nodes in the attached graph (0 before initialization).
int num_nodes() const { return num_nodes_; }
// The attached graph, or `nullptr` before initialization.
const GraphDef* graph() const { return graph_; }
// Returns true iff the node exists in the underlying graph.
bool HasNode(absl::string_view node_name) const;
// Finds a node by name or returns `nullptr` if it's not in the graph.
const NodeDef* GetNode(absl::string_view node_name) const;
// Returns a node corresponding to the given node index. The index must be in
// the [0, num_nodes_) range; this is only checked with a DCHECK.
const NodeDef* GetNode(int node_idx) const;
// Returns a node index for the given node name, if the name exists in the
// underlying graph. Otherwise returns empty optional.
// NOTE: the implementation also DCHECKs that the name exists, so a missing
// name is a fatal error in builds with DCHECKs enabled.
const absl::optional<int> GetNodeIndex(absl::string_view node_name) const;
// Returns a node index for the given node, if the node belongs to the
// underlying graph. Otherwise returns empty optional. The node is looked up
// by its name.
const absl::optional<int> GetNodeIndex(const NodeDef& node) const;
// Returns all the node indexes that are in the direct fanin of the given
// node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
// The returned indexes are sorted and contain no duplicates.
const absl::InlinedVector<int, 4>& GetFanin(int node_idx) const;
// Returns all the node indexes that are in the direct fanout of the given
// node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
// The returned indexes are sorted and contain no duplicates.
const absl::InlinedVector<int, 2>& GetFanout(int node_idx) const;
private:
// WARN: `graph_` must outlive this object and graph nodes must not be
// destructed, because node names are captured as absl::string_view and point
// into the strings owned by the graph nodes.
const GraphDef* graph_ = nullptr; // do not own
int num_nodes_ = 0;
std::vector<absl::string_view> index_to_node_name_;
absl::flat_hash_map<absl::string_view, int> node_name_to_index_;
std::vector<absl::InlinedVector<int, 4>> fanins_; // node_idx->input nodes
std::vector<absl::InlinedVector<int, 2>> fanouts_; // node_idx->output nodes
// We need a valid reference to return from GetFanin/GetFanout if the
// `node_idx` argument is outside of the [0, num_nodes_) range.
absl::InlinedVector<int, 4> empty_fanin_;
absl::InlinedVector<int, 2> empty_fanout_;
};
} // end namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_

View File

@ -0,0 +1,117 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_topology_view.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
// Test fixture with a helper for building small graphs from a compact
// description of the nodes.
class GraphTopologyViewTest : public ::testing::Test {
 protected:
  // A node description: {node name, names of the node inputs}.
  using NodeConfig = std::pair<string, std::vector<string>>;

  // Creates a graph with one NodeDef per entry of `nodes`, in the given order.
  // Only the name and the inputs of every node are populated.
  static GraphDef CreateGraph(const std::vector<NodeConfig>& nodes) {
    GraphDef graph;
    for (const NodeConfig& config : nodes) {
      NodeDef* node_def = graph.add_node();
      node_def->set_name(config.first);
      for (const string& input_name : config.second) {
        node_def->add_input(input_name);
      }
    }
    return graph;
  }
};
// Checks node lookup (by name and by index), node index lookup, and the
// fanin/fanout lists of a small acyclic graph.
TEST_F(GraphTopologyViewTest, SimpleGraph) {
const GraphDef graph = CreateGraph({
{"a", {}}, // idx: 0
{"b", {}}, // idx: 1
{"c", {"a", "b"}}, // idx: 2
{"d", {"a", "c"}}, // idx: 3
});
GraphTopologyView graph_view;
TF_CHECK_OK(graph_view.InitializeFromGraph(graph));
EXPECT_TRUE(graph_view.is_initialized());
// Lookup by name and lookup by index must resolve to the same NodeDef.
const NodeDef* a_by_name = graph_view.GetNode("a");
const NodeDef* a_by_idx = graph_view.GetNode(0);
ASSERT_TRUE(a_by_name);
ASSERT_TRUE(a_by_idx);
EXPECT_EQ(a_by_name, a_by_idx);
const NodeDef* b_by_name = graph_view.GetNode("b");
const NodeDef* b_by_idx = graph_view.GetNode(1);
ASSERT_TRUE(b_by_name);
ASSERT_TRUE(b_by_idx);
EXPECT_EQ(b_by_name, b_by_idx);
// Node index lookup, both from a NodeDef and from a node name.
const absl::optional<int> b_idx = graph_view.GetNodeIndex(*b_by_name);
ASSERT_TRUE(b_idx.has_value());
EXPECT_EQ(b_idx.value(), 1);
const absl::optional<int> c_idx = graph_view.GetNodeIndex("c");
ASSERT_TRUE(c_idx.has_value());
EXPECT_EQ(c_idx.value(), 2);
// Fanins are the indices of the input nodes.
using Fanin = absl::InlinedVector<int, 4>;
EXPECT_EQ(graph_view.GetFanin(0), Fanin());
EXPECT_EQ(graph_view.GetFanin(1), Fanin());
EXPECT_EQ(graph_view.GetFanin(2), Fanin({0, 1}));
EXPECT_EQ(graph_view.GetFanin(3), Fanin({0, 2}));
// Fanouts are the indices of the consumer nodes.
using Fanout = absl::InlinedVector<int, 2>;
EXPECT_EQ(graph_view.GetFanout(0), Fanout({2, 3}));
EXPECT_EQ(graph_view.GetFanout(1), Fanout({2}));
EXPECT_EQ(graph_view.GetFanout(2), Fanout({3}));
EXPECT_EQ(graph_view.GetFanout(3), Fanout());
}
// Checks that a graph containing a cycle can still be loaded into the view,
// and that the fanin/fanout lists of the nodes on the cycle reference each
// other.
TEST_F(GraphTopologyViewTest, GraphWithALoop) {
const GraphDef graph = CreateGraph({
{"a", {}}, // idx: 0
{"b", {}}, // idx: 1
{"c", {"a", "b", "d"}}, // idx: 2 <<<--- 'c' and 'd' have a loop
{"d", {"a", "c"}}, // idx: 3
});
GraphTopologyView graph_view;
TF_CHECK_OK(graph_view.InitializeFromGraph(graph));
EXPECT_TRUE(graph_view.is_initialized());
using Fanin = absl::InlinedVector<int, 4>;
EXPECT_EQ(graph_view.GetFanin(2), Fanin({0, 1, 3}));
EXPECT_EQ(graph_view.GetFanin(3), Fanin({0, 2}));
using Fanout = absl::InlinedVector<int, 2>;
EXPECT_EQ(graph_view.GetFanout(2), Fanout({3}));
EXPECT_EQ(graph_view.GetFanout(3), Fanout({2}));
}
} // namespace grappler
} // namespace tensorflow

View File

@ -48,8 +48,11 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_topology_view",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"@com_google_absl//absl/types:span",
],
)
@ -58,10 +61,11 @@ tf_cc_test(
srcs = ["topological_sort_test.cc"],
deps = [
":topological_sort",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
],
)

View File

@ -14,10 +14,15 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include <algorithm>
#include <deque>
#include <unordered_map>
#include "absl/types/span.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/graph_topology_view.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status.h"
@ -25,27 +30,46 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
namespace {
// Converts extra topological dependencies into graph view edges. Every
// dependency becomes one edge going from the control slot of `from` to the
// control slot of `to`. The output has the same order as the input.
std::vector<GraphView::Edge> MakeEphemeralEdges(
    const absl::Span<const TopologicalDependency> extra_dependencies) {
  std::vector<GraphView::Edge> edges;
  edges.reserve(extra_dependencies.size());
  for (const TopologicalDependency& dependency : extra_dependencies) {
    const GraphView::OutputPort source(dependency.from, Graph::kControlSlot);
    const GraphView::InputPort sink(dependency.to, Graph::kControlSlot);
    edges.emplace_back(source, sink);
  }
  return edges;
}
// Kahn's algorithm is implemented.
// For details, see https://en.wikipedia.org/wiki/Topological_sorting
Status ComputeTopologicalOrder(
const GraphDef& graph, std::vector<int>* ready_nodes,
const std::vector<std::pair<const NodeDef*, const NodeDef*>>*
extra_dependencies) {
SimpleGraphView graph_view;
TF_RETURN_IF_ERROR(graph_view.Initialize(graph, extra_dependencies));
const GraphDef& graph,
const absl::Span<const TopologicalDependency> extra_dependencies,
std::vector<int>* ready_nodes) {
GraphTopologyView graph_view;
TF_RETURN_IF_ERROR(graph_view.InitializeFromGraph(
graph, MakeEphemeralEdges(extra_dependencies)));
ready_nodes->reserve(graph_view.num_nodes());
// Keep track of how many inputs are ready for the given node.
std::vector<int> num_ready_inputs(graph.node_size(), 0);
// We'll push index of ready nodes to this output vector.
ready_nodes->reserve(graph.node_size());
int front = 0;
int back = 0;
std::vector<int> num_ready_inputs(graph_view.num_nodes(), 0);
for (int i = 0; i < graph_view.num_nodes(); i++) {
if (graph_view.inputs(i).empty()) {
for (int i = 0; i < graph.node_size(); i++) {
if (graph_view.GetFanin(i).empty()) {
ready_nodes->push_back(i);
back++;
}
if (IsMerge(graph.node(i))) {
for (int input : graph_view.inputs(i)) {
for (int input : graph_view.GetFanin(i)) {
if (IsNextIteration(graph.node(input))) {
num_ready_inputs[i]++;
}
@ -55,9 +79,9 @@ Status ComputeTopologicalOrder(
while (front != back) {
int ready_node = (*ready_nodes)[front];
for (int fanout : graph_view.outputs(ready_node)) {
for (int fanout : graph_view.GetFanout(ready_node)) {
++num_ready_inputs[fanout];
if (num_ready_inputs[fanout] == graph_view.inputs(fanout).size()) {
if (num_ready_inputs[fanout] == graph_view.GetFanin(fanout).size()) {
ready_nodes->push_back(fanout);
++back;
}
@ -72,23 +96,32 @@ Status ComputeTopologicalOrder(
return Status::OK();
}
} // namespace
Status ComputeTopologicalOrder(
const GraphDef& graph, std::unordered_map<const NodeDef*, int>* topo_order,
const std::vector<std::pair<const NodeDef*, const NodeDef*>>*
extra_dependencies) {
const GraphDef& graph,
const absl::Span<const TopologicalDependency> extra_dependencies,
std::vector<const NodeDef*>* topo_order) {
std::vector<int> ready_nodes;
TF_RETURN_IF_ERROR(
ComputeTopologicalOrder(graph, &ready_nodes, extra_dependencies));
topo_order->reserve(graph.node_size());
for (int i = 0; i < ready_nodes.size(); ++i) {
(*topo_order)[&graph.node(ready_nodes[i])] = i;
ComputeTopologicalOrder(graph, extra_dependencies, &ready_nodes));
topo_order->reserve(ready_nodes.size());
for (int ready_node_idx : ready_nodes) {
topo_order->emplace_back(&graph.node(ready_node_idx));
}
return Status::OK();
}
// Computes the topological order of the `graph` nodes without any extra
// dependencies, by forwarding to the overload that accepts them.
Status ComputeTopologicalOrder(const GraphDef& graph,
                               std::vector<const NodeDef*>* topo_order) {
  const absl::Span<const TopologicalDependency> no_extra_dependencies;
  return ComputeTopologicalOrder(graph, no_extra_dependencies, topo_order);
}
Status ReversedTopologicalSort(GraphDef* graph) {
std::vector<int> ready_nodes;
TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, &ready_nodes, nullptr));
TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, {}, &ready_nodes));
std::reverse(ready_nodes.begin(), ready_nodes.end());
PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
return Status::OK();
@ -96,7 +129,7 @@ Status ReversedTopologicalSort(GraphDef* graph) {
Status TopologicalSort(GraphDef* graph) {
std::vector<int> ready_nodes;
TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, &ready_nodes, nullptr));
TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, {}, &ready_nodes));
PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
return Status::OK();
}

View File

@ -16,22 +16,40 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_
#define TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_
#include "absl/types/span.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace grappler {
// Compute a topological ordering for the graph nodes.
Status ComputeTopologicalOrder(
const GraphDef& graph, std::unordered_map<const NodeDef*, int>* topo_order,
const std::vector<std::pair<const NodeDef*, const NodeDef*>>*
extra_dependencies);
// TODO(ezhulenev, b/121379902): We should be consistent with GraphTopologyView
// and use `GraphView::Edge` to pass extra dependencies.
// An extra ordering constraint between two nodes of the same graph, passed to
// ComputeTopologicalOrder in addition to the edges present in the graph.
struct TopologicalDependency {
  TopologicalDependency(const NodeDef* from_node, const NodeDef* to_node)
      : from(from_node), to(to_node) {}

  const NodeDef* from;  // source node of the dependency
  const NodeDef* to;    // destination node of the dependency
};
// Sort a graph in topological order.
// Computes a topological ordering for the graph nodes and outputs nodes in the
// topological order to the `topo_order` output argument.
//
// It's possible to pass additional edges that do not exist in a graph, but
// must be respected when computing graph topological order. Example: Tensorflow
// runtime allows concurrent execution of dequeue/enqueue ops from the same
// queue resource, but we might want to enforce ordering between them.
Status ComputeTopologicalOrder(
const GraphDef& graph,
absl::Span<const TopologicalDependency> extra_dependencies,
std::vector<const NodeDef*>* topo_order);
Status ComputeTopologicalOrder(const GraphDef& graph,
std::vector<const NodeDef*>* topo_order);
// Sorts a graph in topological order.
Status TopologicalSort(GraphDef* graph);
// Sort a graph in topological order and reverse it.
// Sorts a graph in topological order and reverses it.
Status ReversedTopologicalSort(GraphDef* graph);
} // namespace grappler

View File

@ -14,79 +14,94 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace grappler {
namespace {
class TopologicalSortTest : public ::testing::Test {
protected:
static NodeDef CreateNode(const string& name,
const std::vector<string>& inputs) {
return CreateNode(name, "", inputs);
}
static NodeDef CreateNode(const string& name, const string& op,
const std::vector<string>& inputs) {
NodeDef node;
node.set_name(name);
if (!op.empty()) {
node.set_op(op);
struct NodeConfig {
NodeConfig(string name, std::vector<string> inputs)
: name(std::move(name)), inputs(std::move(inputs)) {}
NodeConfig(string name, string op, std::vector<string> inputs)
: name(std::move(name)), op(std::move(op)), inputs(std::move(inputs)) {}
string name;
string op;
std::vector<string> inputs;
};
static GraphDef CreateGraph(const std::vector<NodeConfig>& nodes) {
GraphDef graph;
for (const NodeConfig& node : nodes) {
NodeDef node_def;
node_def.set_name(node.name);
node_def.set_op(node.op);
for (const string& input : node.inputs) {
node_def.add_input(input);
}
*graph.add_node() = std::move(node_def);
}
for (const string& input : inputs) {
node.add_input(input);
}
return node;
return graph;
}
};
TEST_F(TopologicalSortTest, NoLoop) {
GraphDef graph;
*graph.add_node() = CreateNode("2", {"5"});
*graph.add_node() = CreateNode("0", {"5", "4"});
*graph.add_node() = CreateNode("1", {"4", "3"});
*graph.add_node() = CreateNode("3", {"2"});
*graph.add_node() = CreateNode("5", {});
*graph.add_node() = CreateNode("4", {});
GraphDef graph = CreateGraph({
{"2", {"5"}}, //
{"0", {"5", "4"}}, //
{"1", {"4", "3"}}, //
{"3", {"2"}}, //
{"5", {}}, //
{"4", {}} //
});
std::unordered_map<const NodeDef*, int> topo_order;
TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order, nullptr));
std::vector<const NodeDef*> topo_order;
TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order));
const std::vector<string> order = {"5", "4", "2", "0", "3", "1"};
for (const auto& topo : topo_order) {
const string& node_name = topo.first->name();
const int topo_order = topo.second;
std::cout << "Node " << node_name << " at order " << topo_order
<< std::endl;
EXPECT_EQ(node_name, order[topo_order]);
ASSERT_EQ(topo_order.size(), order.size());
for (int i = 0; i < topo_order.size(); ++i) {
const NodeDef* node = topo_order[i];
EXPECT_EQ(node->name(), order[i]);
}
TF_EXPECT_OK(TopologicalSort(&graph));
for (int i = 0; i < order.size(); i++) {
for (int i = 0; i < topo_order.size(); i++) {
EXPECT_EQ(graph.node(i).name(), order[i]);
}
}
TEST_F(TopologicalSortTest, WithLoop) {
GraphDef graph;
// Create a loop
*graph.add_node() = CreateNode("2", "Merge", {"1", "5"});
*graph.add_node() = CreateNode("3", "Switch", {"2"});
*graph.add_node() = CreateNode("4", "Identity", {"3"});
*graph.add_node() = CreateNode("5", "NextIteration", {"4"});
*graph.add_node() = CreateNode("1", {});
GraphDef graph = CreateGraph({
// Graph with a loop.
{"2", "Merge", {"1", "5"}}, //
{"3", "Switch", {"2"}}, //
{"4", "Identity", {"3"}}, //
{"5", "NextIteration", {"4"}}, //
{"1", {}} //
});
std::unordered_map<const NodeDef*, int> topo_order;
TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order, nullptr));
std::vector<const NodeDef*> topo_order;
TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order));
const std::vector<string> order = {"1", "2", "3", "4", "5"};
for (const auto& topo : topo_order) {
const string& node_name = topo.first->name();
const int topo_order = topo.second;
EXPECT_EQ(node_name, order[topo_order]);
ASSERT_EQ(topo_order.size(), order.size());
for (int i = 0; i < topo_order.size(); ++i) {
const NodeDef* node = topo_order[i];
EXPECT_EQ(node->name(), order[i]);
}
TF_EXPECT_OK(TopologicalSort(&graph));
@ -96,12 +111,13 @@ TEST_F(TopologicalSortTest, WithLoop) {
}
TEST_F(TopologicalSortTest, WithIllegalLoop) {
GraphDef graph;
// A loop without Merge and NextIteration is illegal and the original node
// order and graph will be preserved.
*graph.add_node() = CreateNode("2", {"1", "3"});
*graph.add_node() = CreateNode("3", {"2"});
*graph.add_node() = CreateNode("1", {});
GraphDef graph = CreateGraph({
{"2", {"1", "3"}}, //
{"3", {"2"}}, //
{"1", {}} //
});
EXPECT_FALSE(TopologicalSort(&graph).ok());
std::vector<string> order = {"2", "3", "1"};
@ -111,9 +127,10 @@ TEST_F(TopologicalSortTest, WithIllegalLoop) {
}
TEST_F(TopologicalSortTest, DuplicatedInputs) {
GraphDef graph;
*graph.add_node() = CreateNode("2", {"1", "1"});
*graph.add_node() = CreateNode("1", {});
GraphDef graph = CreateGraph({
{"2", {"1", "1"}}, //
{"1", {}} //
});
TF_EXPECT_OK(TopologicalSort(&graph));
std::vector<string> order = {"1", "2"};
@ -123,12 +140,13 @@ TEST_F(TopologicalSortTest, DuplicatedInputs) {
}
TEST_F(TopologicalSortTest, Idempotent) {
GraphDef graph;
*graph.add_node() = CreateNode("1", {});
*graph.add_node() = CreateNode("2", {});
*graph.add_node() = CreateNode("3", {"1", "2"});
*graph.add_node() = CreateNode("4", {"1", "3"});
*graph.add_node() = CreateNode("5", {"2", "3"});
GraphDef graph = CreateGraph({
{"1", {}}, //
{"2", {}}, //
{"3", {"1", "2"}}, //
{"4", {"1", "3"}}, //
{"5", {"2", "3"}} //
});
TF_EXPECT_OK(TopologicalSort(&graph));
std::vector<string> order = {"1", "2", "3", "4", "5"};
@ -136,7 +154,7 @@ TEST_F(TopologicalSortTest, Idempotent) {
EXPECT_EQ(graph.node(i).name(), order[i]);
}
// Run topo sort again to verify that it is idenpotent.
// Run topo sort again to verify that it is idempotent.
TF_EXPECT_OK(TopologicalSort(&graph));
for (int i = 0; i < order.size(); i++) {
EXPECT_EQ(graph.node(i).name(), order[i]);
@ -144,35 +162,81 @@ TEST_F(TopologicalSortTest, Idempotent) {
}
TEST_F(TopologicalSortTest, ExtraDependencies) {
GraphDef graph;
*graph.add_node() = CreateNode("2", {"5"});
*graph.add_node() = CreateNode("0", {"5", "4"});
*graph.add_node() = CreateNode("1", {"4", "3"});
*graph.add_node() = CreateNode("3", {"2"});
*graph.add_node() = CreateNode("5", {});
*graph.add_node() = CreateNode("4", {});
GraphDef graph = CreateGraph({
{"2", {"5"}}, //
{"0", {"5", "4"}}, //
{"1", {"4", "3"}}, //
{"3", {"2"}}, //
{"5", {}}, //
{"4", {}} //
});
// Add an edge from 4 to 5.
std::vector<std::pair<const NodeDef*, const NodeDef*>> extra_dependencies;
extra_dependencies.emplace_back(&graph.node(5), &graph.node(4));
std::vector<TopologicalDependency> extra_dependencies;
extra_dependencies.push_back({&graph.node(5), &graph.node(4)});
std::unordered_map<const NodeDef*, int> topo_order;
TF_EXPECT_OK(
ComputeTopologicalOrder(graph, &topo_order, &extra_dependencies));
std::vector<const NodeDef*> topo_order;
TF_EXPECT_OK(ComputeTopologicalOrder(graph, extra_dependencies, &topo_order));
const std::vector<string> order = {"4", "5", "2", "0", "3", "1"};
for (const auto& topo : topo_order) {
const string& node_name = topo.first->name();
const int topo_order = topo.second;
EXPECT_EQ(node_name, order[topo_order]);
const std::vector<string> valid_order_1 = {"4", "5", "2", "0", "3", "1"};
const std::vector<string> valid_order_2 = {"4", "5", "0", "2", "3", "1"};
ASSERT_EQ(topo_order.size(), valid_order_1.size());
std::vector<string> computed_order(6, "");
for (int i = 0; i < topo_order.size(); ++i) {
const NodeDef* node = topo_order[i];
computed_order[i] = node->name();
}
EXPECT_TRUE(computed_order == valid_order_1 ||
computed_order == valid_order_2);
// Add an edge from 0 to 4. This will create a loop
extra_dependencies.emplace_back(&graph.node(1), &graph.node(5));
// Add an edge from `0` to `4`. This will create a loop.
extra_dependencies.push_back({&graph.node(1), &graph.node(5)});
EXPECT_FALSE(
ComputeTopologicalOrder(graph, &topo_order, &extra_dependencies).ok());
ComputeTopologicalOrder(graph, extra_dependencies, &topo_order).ok());
}
} // namespace
static void BM_ComputeTopologicalOrder(int iters, int size) {
testing::StopTiming();
random::PhiloxRandom philox(0x12345);
random::SimplePhilox rnd(&philox);
string prefix = "long_node_name_prefix_to_measure_string_copy_overhead";
GraphDef graph;
for (int i = 0; i < size; ++i) {
const string name = absl::StrCat(prefix, i);
const uint32 num_inputs = rnd.Uniform(std::min(i, 5));
NodeDef node;
node.set_name(name);
for (int n = 0; n < num_inputs; ++n) {
const uint32 input_node = rnd.Uniform(i);
node.add_input(absl::StrCat(prefix, input_node));
}
*graph.add_node() = std::move(node);
}
testing::StartTiming();
std::vector<const NodeDef*> topo_order;
for (int i = 0; i < iters; i++) {
topo_order.clear();
Status st = ComputeTopologicalOrder(graph, &topo_order);
CHECK(st.ok()) << "Failed to compute topological order";
}
testing::StopTiming();
}
BENCHMARK(BM_ComputeTopologicalOrder)
->Arg(10)
->Arg(100)
->Arg(1000)
->Arg(10000)
->Arg(25000)
->Arg(50000)
->Arg(100000);
} // namespace grappler
} // namespace tensorflow