diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index 6de12192ba8..4d0f02f4d6d 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -62,6 +62,36 @@ tf_cuda_library( ], ) +cc_library( + name = "graph_topology_view", + srcs = ["graph_topology_view.cc"], + hdrs = ["graph_topology_view.h"], + visibility = ["//visibility:public"], + deps = [ + ":graph_view", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "graph_topology_view_test", + srcs = ["graph_topology_view_test.cc"], + deps = [ + ":graph_topology_view", + ":graph_view", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "graph_view", srcs = ["graph_view.cc"], diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index d6999798964..cfbd340f08f 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -425,9 +425,11 @@ NodeDef MakeConstNodeDefFromShape(InferenceContext* ic, // information is refined. 
class TopoQueue { public: - explicit TopoQueue(const std::unordered_map& topo_order) - : topo_order_(topo_order) {} + explicit TopoQueue(const std::vector& topo_order) + : topo_order_(TopoOrder(topo_order)) {} + void push(const NodeDef* n) { queue_.emplace(n, topo_order_.at(n)); } + const NodeDef* pop() { CHECK(!empty()); auto it = queue_.begin(); @@ -448,7 +450,18 @@ class TopoQueue { return lhs.second < rhs.second; } }; - const std::unordered_map& topo_order_; + + const std::unordered_map TopoOrder( + const std::vector& topo_order) const { + std::unordered_map map; + map.reserve(topo_order.size()); + for (int i = 0; i < topo_order.size(); ++i) { + map.emplace(topo_order[i], i); + } + return map; + } + + const std::unordered_map topo_order_; std::set queue_; }; @@ -1970,7 +1983,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds, } std::unordered_map resource_handles; - std::vector> extra_deps; + std::vector extra_deps; for (const auto& resource : resources) { for (const NodeDef* src : resource.second.first) { resource_handles[src] = resource.first; @@ -1982,8 +1995,8 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds, } } - std::unordered_map topo_order; - Status s = ComputeTopologicalOrder(item_.graph, &topo_order, &extra_deps); + std::vector topo_order; + Status s = ComputeTopologicalOrder(item_.graph, extra_deps, &topo_order); if (!s.ok()) { if (extra_deps.empty()) { return s; @@ -1992,8 +2005,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds, // order. This will make the shape inference less precise but since this // isn't common it's not worth to figure out where to break the loop and // do a proper relaxation. 
- TF_RETURN_IF_ERROR( - ComputeTopologicalOrder(item_.graph, &topo_order, nullptr)); + TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item_.graph, &topo_order)); } } diff --git a/tensorflow/core/grappler/graph_topology_view.cc b/tensorflow/core/grappler/graph_topology_view.cc new file mode 100644 index 00000000000..38ccfbaeb88 --- /dev/null +++ b/tensorflow/core/grappler/graph_topology_view.cc @@ -0,0 +1,163 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/graph_topology_view.h" + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace tensorflow { +namespace grappler { + +namespace { + +template +inline void SortAndRemoveDuplicates(T* v) { + std::sort(v->begin(), v->end()); + v->erase(std::unique(v->begin(), v->end()), v->end()); +} + +} // namespace + +Status GraphTopologyView::InitializeFromGraph( + const GraphDef& graph, + const absl::Span ephemeral_edges) { + if (graph_ != nullptr) { + return errors::InvalidArgument("GraphTopologyView is already initialized."); + } + + graph_ = &graph; + num_nodes_ = graph.node_size(); + index_to_node_name_.reserve(num_nodes_); + node_name_to_index_.rehash(num_nodes_); + fanins_.resize(num_nodes_); + fanouts_.resize(num_nodes_); + + // Build name<->index maps; emplace_back appends, hence reserve() above. + for (int node_idx = 0; node_idx < num_nodes_; ++node_idx) { + const NodeDef& node = graph.node(node_idx); + node_name_to_index_.emplace(node.name(), node_idx); + index_to_node_name_.emplace_back(node.name()); + } + + // 1. Add ephemeral edges to the adjacency lists. 
+ for (const GraphView::Edge& edge : ephemeral_edges) { + const auto src = node_name_to_index_.find(edge.src.node->name()); + if (src == node_name_to_index_.end()) { + return errors::InvalidArgument("Non-existent src node: ", + edge.src.node->name()); + } + const auto dst = node_name_to_index_.find(edge.dst.node->name()); + if (dst == node_name_to_index_.end()) { + return errors::InvalidArgument("Non-existent dst node: ", + edge.dst.node->name()); + } + const int src_idx = src->second; + const int dst_idx = dst->second; + fanins_[dst_idx].push_back(src_idx); + fanouts_[src_idx].push_back(dst_idx); + } + + // 2. Add graph edges to the adjacency lists. + for (int node_idx = 0; node_idx < num_nodes_; ++node_idx) { + const NodeDef& node = graph.node(node_idx); + fanins_[node_idx].reserve(node.input_size()); + + for (const string& input : node.input()) { + TensorId tensor = ParseTensorName(input); + const auto it = node_name_to_index_.find(tensor.node()); + if (it == node_name_to_index_.end()) { + return errors::InvalidArgument("Non-existent input ", input, + " for node ", node.name()); + } + const int input_idx = it->second; + fanins_[node_idx].push_back(input_idx); + fanouts_[input_idx].push_back(node_idx); + } + + // Dedup the input list while it's still hot in cache. + SortAndRemoveDuplicates(&fanins_[node_idx]); + } + + // Dedup outputs for all the graph nodes. 
+ for (int node_idx = 0; node_idx < num_nodes_; ++node_idx) { + SortAndRemoveDuplicates(&fanouts_[node_idx]); + } + + return Status::OK(); +} + +Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph) { + return InitializeFromGraph(graph, absl::Span()); +} + +bool GraphTopologyView::HasNode(const absl::string_view node_name) const { + DCHECK(is_initialized()) << "GraphTopologyView is not initialized"; + const auto it = node_name_to_index_.find(node_name); + return it != node_name_to_index_.end(); +} + +const NodeDef* GraphTopologyView::GetNode( + const absl::string_view node_name) const { + DCHECK(is_initialized()) << "GraphTopologyView is not initialized"; + const auto it = node_name_to_index_.find(node_name); + return it == node_name_to_index_.end() ? nullptr : &graph_->node(it->second); +} + +const NodeDef* GraphTopologyView::GetNode(int node_idx) const { + DCHECK(is_initialized()) << "GraphTopologyView is not initialized"; + DCHECK(node_idx >= 0 && node_idx < num_nodes_) << "node_idx is out of range"; + return &graph_->node(node_idx); +} + +const absl::optional GraphTopologyView::GetNodeIndex( + const absl::string_view node_name) const { + DCHECK(is_initialized()) << "GraphTopologyView is not initialized"; + const auto it = node_name_to_index_.find(node_name); + DCHECK(it != node_name_to_index_.end()) << "Node doesn't exist in a graph"; + return it == node_name_to_index_.end() ? absl::nullopt + : absl::make_optional(it->second); +} + +const absl::optional GraphTopologyView::GetNodeIndex( + const NodeDef& node) const { + return GetNodeIndex(node.name()); +} + +const absl::InlinedVector& GraphTopologyView::GetFanin( + int node_idx) const { + DCHECK(is_initialized()) << "GraphTopologyView is not initialized"; + const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_; + DCHECK(is_valid_node_idx) << "node_idx is out of range"; + return is_valid_node_idx ? 
fanins_[node_idx] : empty_fanin_; +} + +const absl::InlinedVector& GraphTopologyView::GetFanout( + int node_idx) const { + DCHECK(is_initialized()) << "GraphTopologyView is not initialized"; + const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_; + DCHECK(is_valid_node_idx) << "node_idx is out of range"; + return is_valid_node_idx ? fanouts_[node_idx] : empty_fanout_; +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/graph_topology_view.h b/tensorflow/core/grappler/graph_topology_view.h new file mode 100644 index 00000000000..1c222df4b60 --- /dev/null +++ b/tensorflow/core/grappler/graph_topology_view.h @@ -0,0 +1,105 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/grappler/graph_view.h" + +namespace tensorflow { +namespace grappler { + +// GraphTopologyView is a helper class to simplify `node-to-node` connectivity +// traversals. 
Regular `GraphView` simplifies `tensor-to-tensor` traversals: +// connections between output tensors and inputs of consumer nodes. For the +// topology view we are focused on nodes connected to nodes, and it's irrelevant +// if this connection is formed by one or multiple individual tensors. +// +// Example: +// a = Placeholder(..) +// b = Placeholder(..) +// c = AddN([a, a, b]) +// +// GraphView edges: [a:0 -> c:0, a:0 -> c:1, b:0 -> c:2] +// GraphTopologyView edges: [a -> c, b -> c] +// +// GraphView is used for exploring single node fanins and fanouts, and +// GraphTopologyView is focused on efficient full graph traversals (computing +// graph node properties from transitive fanouts, etc...). +class GraphTopologyView { + public: + GraphTopologyView() = default; + + // Initialize graph topology view from the graph. It's possible to pass + // additional edges that do not exist in a graph, but must be respected when + // computing graph topology. Example: Tensorflow runtime allows concurrent + // execution of dequeue/enqueue ops from the same queue resource, but we might + // want to enforce ordering between them for the purpose of graph analysis. + Status InitializeFromGraph(const GraphDef& graph, + absl::Span ephemeral_edges); + Status InitializeFromGraph(const GraphDef& graph); + + bool is_initialized() const { return graph_ != nullptr; } + int num_nodes() const { return num_nodes_; } + const GraphDef* graph() const { return graph_; } + + // Returns true iff the node exists in the underlying graph. + bool HasNode(absl::string_view node_name) const; + + // Finds a node by name or returns `nullptr` if it's not in the graph. + const NodeDef* GetNode(absl::string_view node_name) const; + // Returns a node corresponding to the given node index. + const NodeDef* GetNode(int node_idx) const; + + // Returns a node index for the given node name, if the name exists in the + // underlying graph. Otherwise DCHECK-fails, then returns empty optional. 
+ const absl::optional GetNodeIndex(absl::string_view node_name) const; + // Returns a node index for the given node (looked up by its name), if it is + // in the underlying graph. Otherwise returns empty optional. + const absl::optional GetNodeIndex(const NodeDef& node) const; + + // Returns all the node indexes that are in the direct fanin of the given + // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector. + const absl::InlinedVector& GetFanin(int node_idx) const; + // Returns all the node indexes that are in the direct fanout of the given + // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector. + const absl::InlinedVector& GetFanout(int node_idx) const; + + private: + // WARN: `graph_` must outlive this object and graph nodes must not be + // destructed, because node names are captured with absl::string_view. + const GraphDef* graph_ = nullptr; // not owned + int num_nodes_ = 0; + std::vector index_to_node_name_; + absl::flat_hash_map node_name_to_index_; + std::vector> fanins_; // node_idx->input nodes + std::vector> fanouts_; // node_idx->output nodes + + // We need a valid reference to return from GetFanin/GetFanout if the + // `node_idx` argument is outside of the [0, num_nodes_) range. + absl::InlinedVector empty_fanin_; + absl::InlinedVector empty_fanout_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_ diff --git a/tensorflow/core/grappler/graph_topology_view_test.cc b/tensorflow/core/grappler/graph_topology_view_test.cc new file mode 100644 index 00000000000..36d3a2017cc --- /dev/null +++ b/tensorflow/core/grappler/graph_topology_view_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/graph_topology_view.h" + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { + +class GraphTopologyViewTest : public ::testing::Test { + protected: + using NodeConfig = std::pair>; + + static GraphDef CreateGraph(const std::vector& nodes) { + GraphDef graph; + + for (const NodeConfig& node : nodes) { + const auto& node_name = node.first; + const auto& node_inputs = node.second; + + NodeDef node_def; + node_def.set_name(node_name); + for (const string& input : node_inputs) { + node_def.add_input(input); + } + + *graph.add_node() = std::move(node_def); + } + + return graph; + } +}; + +TEST_F(GraphTopologyViewTest, SimpleGraph) { + const GraphDef graph = CreateGraph({ + {"a", {}}, // idx: 0 + {"b", {}}, // idx: 1 + {"c", {"a", "b"}}, // idx: 2 + {"d", {"a", "c"}}, // idx: 3 + }); + + GraphTopologyView graph_view; + TF_CHECK_OK(graph_view.InitializeFromGraph(graph)); + + EXPECT_TRUE(graph_view.is_initialized()); + + const NodeDef* a_by_name = graph_view.GetNode("a"); + const NodeDef* a_by_idx = graph_view.GetNode(0); + ASSERT_TRUE(a_by_name); + ASSERT_TRUE(a_by_idx); + EXPECT_EQ(a_by_name, a_by_idx); + + const NodeDef* b_by_name = graph_view.GetNode("b"); + const NodeDef* b_by_idx = graph_view.GetNode(1); + ASSERT_TRUE(b_by_name); + ASSERT_TRUE(b_by_idx); + EXPECT_EQ(b_by_name, b_by_idx); + + const absl::optional b_idx 
= graph_view.GetNodeIndex(*b_by_name); + ASSERT_TRUE(b_idx.has_value()); + EXPECT_EQ(b_idx.value(), 1); + + const absl::optional c_idx = graph_view.GetNodeIndex("c"); + ASSERT_TRUE(c_idx.has_value()); + EXPECT_EQ(c_idx.value(), 2); + + using Fanin = absl::InlinedVector; + EXPECT_EQ(graph_view.GetFanin(0), Fanin()); + EXPECT_EQ(graph_view.GetFanin(1), Fanin()); + EXPECT_EQ(graph_view.GetFanin(2), Fanin({0, 1})); + EXPECT_EQ(graph_view.GetFanin(3), Fanin({0, 2})); + + using Fanout = absl::InlinedVector; + EXPECT_EQ(graph_view.GetFanout(0), Fanout({2, 3})); + EXPECT_EQ(graph_view.GetFanout(1), Fanout({2})); + EXPECT_EQ(graph_view.GetFanout(2), Fanout({3})); + EXPECT_EQ(graph_view.GetFanout(3), Fanout()); +} + +TEST_F(GraphTopologyViewTest, GraphWithALoop) { + const GraphDef graph = CreateGraph({ + {"a", {}}, // idx: 0 + {"b", {}}, // idx: 1 + {"c", {"a", "b", "d"}}, // idx: 2 <<<--- 'c' and 'd' have a loop + {"d", {"a", "c"}}, // idx: 3 + }); + + GraphTopologyView graph_view; + TF_CHECK_OK(graph_view.InitializeFromGraph(graph)); + EXPECT_TRUE(graph_view.is_initialized()); + + using Fanin = absl::InlinedVector; + EXPECT_EQ(graph_view.GetFanin(2), Fanin({0, 1, 3})); + EXPECT_EQ(graph_view.GetFanin(3), Fanin({0, 2})); + + using Fanout = absl::InlinedVector; + EXPECT_EQ(graph_view.GetFanout(2), Fanout({3})); + EXPECT_EQ(graph_view.GetFanout(3), Fanout({2})); +} + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index c0f19d3828a..b6e70361f33 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -48,8 +48,11 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:graph_topology_view", + "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", + "@com_google_absl//absl/types:span", ], ) @@ -58,10 +61,11 
@@ tf_cc_test( srcs = ["topological_sort_test.cc"], deps = [ ":topological_sort", - "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc index 63ca92c69e1..a6d0f5037bb 100644 --- a/tensorflow/core/grappler/utils/topological_sort.cc +++ b/tensorflow/core/grappler/utils/topological_sort.cc @@ -14,10 +14,15 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/utils/topological_sort.h" + #include #include #include + +#include "absl/types/span.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/graph_topology_view.h" +#include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/status.h" @@ -25,27 +30,46 @@ limitations under the License. namespace tensorflow { namespace grappler { +namespace { + +std::vector MakeEphemeralEdges( + const absl::Span extra_dependencies) { + std::vector ephemeral_edges; + ephemeral_edges.reserve(extra_dependencies.size()); + for (const auto& dep : extra_dependencies) { + ephemeral_edges.emplace_back( + GraphView::OutputPort(dep.from, Graph::kControlSlot), + GraphView::InputPort(dep.to, Graph::kControlSlot)); + } + return ephemeral_edges; +} + // Kahn's algorithm is implemented. 
// For details, see https://en.wikipedia.org/wiki/Topological_sorting Status ComputeTopologicalOrder( - const GraphDef& graph, std::vector* ready_nodes, - const std::vector>* - extra_dependencies) { - SimpleGraphView graph_view; - TF_RETURN_IF_ERROR(graph_view.Initialize(graph, extra_dependencies)); + const GraphDef& graph, + const absl::Span extra_dependencies, + std::vector* ready_nodes) { + GraphTopologyView graph_view; + TF_RETURN_IF_ERROR(graph_view.InitializeFromGraph( + graph, MakeEphemeralEdges(extra_dependencies))); - ready_nodes->reserve(graph_view.num_nodes()); + // Keep track of how many inputs are ready for the given node. + std::vector num_ready_inputs(graph.node_size(), 0); + + // We'll push index of ready nodes to this output vector. + ready_nodes->reserve(graph.node_size()); int front = 0; int back = 0; - std::vector num_ready_inputs(graph_view.num_nodes(), 0); - for (int i = 0; i < graph_view.num_nodes(); i++) { - if (graph_view.inputs(i).empty()) { + + for (int i = 0; i < graph.node_size(); i++) { + if (graph_view.GetFanin(i).empty()) { ready_nodes->push_back(i); back++; } if (IsMerge(graph.node(i))) { - for (int input : graph_view.inputs(i)) { + for (int input : graph_view.GetFanin(i)) { if (IsNextIteration(graph.node(input))) { num_ready_inputs[i]++; } @@ -55,9 +79,9 @@ Status ComputeTopologicalOrder( while (front != back) { int ready_node = (*ready_nodes)[front]; - for (int fanout : graph_view.outputs(ready_node)) { + for (int fanout : graph_view.GetFanout(ready_node)) { ++num_ready_inputs[fanout]; - if (num_ready_inputs[fanout] == graph_view.inputs(fanout).size()) { + if (num_ready_inputs[fanout] == graph_view.GetFanin(fanout).size()) { ready_nodes->push_back(fanout); ++back; } @@ -72,23 +96,32 @@ Status ComputeTopologicalOrder( return Status::OK(); } +} // namespace + Status ComputeTopologicalOrder( - const GraphDef& graph, std::unordered_map* topo_order, - const std::vector>* - extra_dependencies) { + const GraphDef& graph, + const 
absl::Span extra_dependencies, + std::vector* topo_order) { std::vector ready_nodes; TF_RETURN_IF_ERROR( - ComputeTopologicalOrder(graph, &ready_nodes, extra_dependencies)); - topo_order->reserve(graph.node_size()); - for (int i = 0; i < ready_nodes.size(); ++i) { - (*topo_order)[&graph.node(ready_nodes[i])] = i; + ComputeTopologicalOrder(graph, extra_dependencies, &ready_nodes)); + + topo_order->reserve(ready_nodes.size()); + for (int ready_node_idx : ready_nodes) { + topo_order->emplace_back(&graph.node(ready_node_idx)); } + return Status::OK(); } +Status ComputeTopologicalOrder(const GraphDef& graph, + std::vector* topo_order) { + return ComputeTopologicalOrder(graph, {}, topo_order); +} + Status ReversedTopologicalSort(GraphDef* graph) { std::vector ready_nodes; - TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, &ready_nodes, nullptr)); + TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, {}, &ready_nodes)); std::reverse(ready_nodes.begin(), ready_nodes.end()); PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true); return Status::OK(); @@ -96,7 +129,7 @@ Status ReversedTopologicalSort(GraphDef* graph) { Status TopologicalSort(GraphDef* graph) { std::vector ready_nodes; - TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, &ready_nodes, nullptr)); + TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, {}, &ready_nodes)); PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true); return Status::OK(); } diff --git a/tensorflow/core/grappler/utils/topological_sort.h b/tensorflow/core/grappler/utils/topological_sort.h index b8cf897a321..dd4208dfff3 100644 --- a/tensorflow/core/grappler/utils/topological_sort.h +++ b/tensorflow/core/grappler/utils/topological_sort.h @@ -16,22 +16,40 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_ #define TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_ +#include "absl/types/span.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { namespace grappler { -// Compute a topological ordering for the graph nodes. -Status ComputeTopologicalOrder( - const GraphDef& graph, std::unordered_map* topo_order, - const std::vector>* - extra_dependencies); +// TODO(ezhulenev, b/121379902): We should be consistent with GraphTopologyView +// and use `GraphView::Edge` to pass extra dependencies. +struct TopologicalDependency { + TopologicalDependency(const NodeDef* from, const NodeDef* to) + : from(from), to(to) {} + const NodeDef* from; + const NodeDef* to; +}; -// Sort a graph in topological order. +// Computes a topological ordering for the graph nodes and appends nodes in the +// topological order to the `topo_order` output argument. +// +// It's possible to pass additional edges that do not exist in a graph, but +// must be respected when computing graph topological order. Example: Tensorflow +// runtime allows concurrent execution of dequeue/enqueue ops from the same +// queue resource, but we might want to enforce ordering between them. +Status ComputeTopologicalOrder( + const GraphDef& graph, + absl::Span extra_dependencies, + std::vector* topo_order); +Status ComputeTopologicalOrder(const GraphDef& graph, + std::vector* topo_order); + +// Sorts a graph in topological order. Status TopologicalSort(GraphDef* graph); -// Sort a graph in topological order and reverse it. +// Sorts a graph in topological order and reverses it. 
Status ReversedTopologicalSort(GraphDef* graph); } // namespace grappler diff --git a/tensorflow/core/grappler/utils/topological_sort_test.cc b/tensorflow/core/grappler/utils/topological_sort_test.cc index 48b7eb50bd9..3868183c62d 100644 --- a/tensorflow/core/grappler/utils/topological_sort_test.cc +++ b/tensorflow/core/grappler/utils/topological_sort_test.cc @@ -14,79 +14,94 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/utils/topological_sort.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace tensorflow { namespace grappler { -namespace { class TopologicalSortTest : public ::testing::Test { protected: - static NodeDef CreateNode(const string& name, - const std::vector& inputs) { - return CreateNode(name, "", inputs); - } - static NodeDef CreateNode(const string& name, const string& op, - const std::vector& inputs) { - NodeDef node; - node.set_name(name); - if (!op.empty()) { - node.set_op(op); + struct NodeConfig { + NodeConfig(string name, std::vector inputs) + : name(std::move(name)), inputs(std::move(inputs)) {} + NodeConfig(string name, string op, std::vector inputs) + : name(std::move(name)), op(std::move(op)), inputs(std::move(inputs)) {} + + string name; + string op; + std::vector inputs; + }; + + static GraphDef CreateGraph(const std::vector& nodes) { + GraphDef graph; + + for (const NodeConfig& node : nodes) { + NodeDef node_def; + node_def.set_name(node.name); + node_def.set_op(node.op); + for (const string& input : node.inputs) { + node_def.add_input(input); + } + *graph.add_node() = 
std::move(node_def); } - for (const string& input : inputs) { - node.add_input(input); - } - return node; + + return graph; } }; TEST_F(TopologicalSortTest, NoLoop) { - GraphDef graph; - *graph.add_node() = CreateNode("2", {"5"}); - *graph.add_node() = CreateNode("0", {"5", "4"}); - *graph.add_node() = CreateNode("1", {"4", "3"}); - *graph.add_node() = CreateNode("3", {"2"}); - *graph.add_node() = CreateNode("5", {}); - *graph.add_node() = CreateNode("4", {}); + GraphDef graph = CreateGraph({ + {"2", {"5"}}, // + {"0", {"5", "4"}}, // + {"1", {"4", "3"}}, // + {"3", {"2"}}, // + {"5", {}}, // + {"4", {}} // + }); - std::unordered_map topo_order; - TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order, nullptr)); + std::vector topo_order; + TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order)); const std::vector order = {"5", "4", "2", "0", "3", "1"}; - for (const auto& topo : topo_order) { - const string& node_name = topo.first->name(); - const int topo_order = topo.second; - std::cout << "Node " << node_name << " at order " << topo_order - << std::endl; - EXPECT_EQ(node_name, order[topo_order]); + + ASSERT_EQ(topo_order.size(), order.size()); + for (int i = 0; i < topo_order.size(); ++i) { + const NodeDef* node = topo_order[i]; + EXPECT_EQ(node->name(), order[i]); } TF_EXPECT_OK(TopologicalSort(&graph)); - for (int i = 0; i < order.size(); i++) { + for (int i = 0; i < topo_order.size(); i++) { EXPECT_EQ(graph.node(i).name(), order[i]); } } TEST_F(TopologicalSortTest, WithLoop) { - GraphDef graph; - // Create a loop - *graph.add_node() = CreateNode("2", "Merge", {"1", "5"}); - *graph.add_node() = CreateNode("3", "Switch", {"2"}); - *graph.add_node() = CreateNode("4", "Identity", {"3"}); - *graph.add_node() = CreateNode("5", "NextIteration", {"4"}); - *graph.add_node() = CreateNode("1", {}); + GraphDef graph = CreateGraph({ + // Graph with a loop. 
+ {"2", "Merge", {"1", "5"}}, // + {"3", "Switch", {"2"}}, // + {"4", "Identity", {"3"}}, // + {"5", "NextIteration", {"4"}}, // + {"1", {}} // + }); - std::unordered_map topo_order; - TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order, nullptr)); + std::vector topo_order; + TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order)); const std::vector order = {"1", "2", "3", "4", "5"}; - for (const auto& topo : topo_order) { - const string& node_name = topo.first->name(); - const int topo_order = topo.second; - EXPECT_EQ(node_name, order[topo_order]); + + ASSERT_EQ(topo_order.size(), order.size()); + for (int i = 0; i < topo_order.size(); ++i) { + const NodeDef* node = topo_order[i]; + EXPECT_EQ(node->name(), order[i]); } TF_EXPECT_OK(TopologicalSort(&graph)); @@ -96,12 +111,13 @@ TEST_F(TopologicalSortTest, WithLoop) { } TEST_F(TopologicalSortTest, WithIllegalLoop) { - GraphDef graph; // A loop without Merge and NextIteration is illegal and the original node // order and graph will be preserved. 
- *graph.add_node() = CreateNode("2", {"1", "3"}); - *graph.add_node() = CreateNode("3", {"2"}); - *graph.add_node() = CreateNode("1", {}); + GraphDef graph = CreateGraph({ + {"2", {"1", "3"}}, // + {"3", {"2"}}, // + {"1", {}} // + }); EXPECT_FALSE(TopologicalSort(&graph).ok()); std::vector order = {"2", "3", "1"}; @@ -111,9 +127,10 @@ TEST_F(TopologicalSortTest, WithIllegalLoop) { } TEST_F(TopologicalSortTest, DuplicatedInputs) { - GraphDef graph; - *graph.add_node() = CreateNode("2", {"1", "1"}); - *graph.add_node() = CreateNode("1", {}); + GraphDef graph = CreateGraph({ + {"2", {"1", "1"}}, // + {"1", {}} // + }); TF_EXPECT_OK(TopologicalSort(&graph)); std::vector order = {"1", "2"}; @@ -123,12 +140,13 @@ TEST_F(TopologicalSortTest, DuplicatedInputs) { } TEST_F(TopologicalSortTest, Idempotent) { - GraphDef graph; - *graph.add_node() = CreateNode("1", {}); - *graph.add_node() = CreateNode("2", {}); - *graph.add_node() = CreateNode("3", {"1", "2"}); - *graph.add_node() = CreateNode("4", {"1", "3"}); - *graph.add_node() = CreateNode("5", {"2", "3"}); + GraphDef graph = CreateGraph({ + {"1", {}}, // + {"2", {}}, // + {"3", {"1", "2"}}, // + {"4", {"1", "3"}}, // + {"5", {"2", "3"}} // + }); TF_EXPECT_OK(TopologicalSort(&graph)); std::vector order = {"1", "2", "3", "4", "5"}; @@ -136,7 +154,7 @@ TEST_F(TopologicalSortTest, Idempotent) { EXPECT_EQ(graph.node(i).name(), order[i]); } - // Run topo sort again to verify that it is idenpotent. + // Run topo sort again to verify that it is idempotent. 
TF_EXPECT_OK(TopologicalSort(&graph)); for (int i = 0; i < order.size(); i++) { EXPECT_EQ(graph.node(i).name(), order[i]); @@ -144,35 +162,81 @@ TEST_F(TopologicalSortTest, Idempotent) { } TEST_F(TopologicalSortTest, ExtraDependencies) { - GraphDef graph; - *graph.add_node() = CreateNode("2", {"5"}); - *graph.add_node() = CreateNode("0", {"5", "4"}); - *graph.add_node() = CreateNode("1", {"4", "3"}); - *graph.add_node() = CreateNode("3", {"2"}); - *graph.add_node() = CreateNode("5", {}); - *graph.add_node() = CreateNode("4", {}); + GraphDef graph = CreateGraph({ + {"2", {"5"}}, // + {"0", {"5", "4"}}, // + {"1", {"4", "3"}}, // + {"3", {"2"}}, // + {"5", {}}, // + {"4", {}} // + }); // Add an edge from 4 to 5. - std::vector> extra_dependencies; - extra_dependencies.emplace_back(&graph.node(5), &graph.node(4)); + std::vector extra_dependencies; + extra_dependencies.push_back({&graph.node(5), &graph.node(4)}); - std::unordered_map topo_order; - TF_EXPECT_OK( - ComputeTopologicalOrder(graph, &topo_order, &extra_dependencies)); + std::vector topo_order; + TF_EXPECT_OK(ComputeTopologicalOrder(graph, extra_dependencies, &topo_order)); - const std::vector order = {"4", "5", "2", "0", "3", "1"}; - for (const auto& topo : topo_order) { - const string& node_name = topo.first->name(); - const int topo_order = topo.second; - EXPECT_EQ(node_name, order[topo_order]); + const std::vector valid_order_1 = {"4", "5", "2", "0", "3", "1"}; + const std::vector valid_order_2 = {"4", "5", "0", "2", "3", "1"}; + + ASSERT_EQ(topo_order.size(), valid_order_1.size()); + + std::vector computed_order(6, ""); + for (int i = 0; i < topo_order.size(); ++i) { + const NodeDef* node = topo_order[i]; + computed_order[i] = node->name(); } + EXPECT_TRUE(computed_order == valid_order_1 || + computed_order == valid_order_2); - // Add an edge from 0 to 4. This will create a loop - extra_dependencies.emplace_back(&graph.node(1), &graph.node(5)); + // Add an edge from `0` to `4`. This will create a loop. 
+ extra_dependencies.push_back({&graph.node(1), &graph.node(5)}); EXPECT_FALSE( - ComputeTopologicalOrder(graph, &topo_order, &extra_dependencies).ok()); + ComputeTopologicalOrder(graph, extra_dependencies, &topo_order).ok()); } -} // namespace +static void BM_ComputeTopologicalOrder(int iters, int size) { + testing::StopTiming(); + + random::PhiloxRandom philox(0x12345); + random::SimplePhilox rnd(&philox); + + string prefix = "long_node_name_prefix_to_measure_string_copy_overhead"; + + GraphDef graph; + for (int i = 0; i < size; ++i) { + const string name = absl::StrCat(prefix, i); + const uint32 num_inputs = rnd.Uniform(std::min(i, 5)); + + NodeDef node; + node.set_name(name); + for (int n = 0; n < num_inputs; ++n) { + const uint32 input_node = rnd.Uniform(i); + node.add_input(absl::StrCat(prefix, input_node)); + } + + *graph.add_node() = std::move(node); + } + + testing::StartTiming(); + std::vector topo_order; + for (int i = 0; i < iters; i++) { + topo_order.clear(); + Status st = ComputeTopologicalOrder(graph, &topo_order); + CHECK(st.ok()) << "Failed to compute topological order"; + } + testing::StopTiming(); +} +BENCHMARK(BM_ComputeTopologicalOrder) + ->Arg(10) + ->Arg(100) + ->Arg(1000) + ->Arg(10000) + ->Arg(25000) + ->Arg(50000) + ->Arg(100000); + } // namespace grappler } // namespace tensorflow