From d74bb6ad5fbb5e41a068db13d5c8578d3b1ffc15 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Wed, 22 May 2019 09:09:04 -0700 Subject: [PATCH] [Grappler] Add topological sort to new GraphView. PiperOrigin-RevId: 249459895 --- tensorflow/core/grappler/utils/BUILD | 2 + tensorflow/core/grappler/utils/graph_view.cc | 239 ++++++++++++++- tensorflow/core/grappler/utils/graph_view.h | 47 ++- .../core/grappler/utils/graph_view_internal.h | 7 + .../core/grappler/utils/graph_view_test.cc | 285 +++++++++++++++++- 5 files changed, 575 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index 8b8a3b4ef5d..e6a873ee474 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -348,10 +348,12 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/core/grappler/utils/graph_view.cc b/tensorflow/core/grappler/utils/graph_view.cc index 3d823d2daf9..07a40396d58 100644 --- a/tensorflow/core/grappler/utils/graph_view.cc +++ b/tensorflow/core/grappler/utils/graph_view.cc @@ -16,8 +16,11 @@ limitations under the License. #include "tensorflow/core/grappler/utils/graph_view.h" #include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -1360,6 +1363,238 @@ void MutableGraphView::RemoveNodesInternal( } } +namespace { +constexpr int kTopologicalSortDone = -1; + +const char kMutableGraphViewSortTopologicallyError[] = + "MutableGraphView::SortTopologically error: "; + +// TraversalState is an enum representing the state of a node when it is being +// traversed via DFS. +enum TraversalState : uint8_t { NOT_VISITED, PENDING, PROCESSING, PROCESSED }; + +// RecursionStackState is an enum representing the recursion stack state +// when using DFS iteratively. `ENTER` is the state representing entering into +// a recursive call, while `EXIT` is the state representing exiting a +// recursive call. +enum RecursionStackState : bool { ENTER, EXIT }; + +// RecursionStackEntry is a helper struct representing an instance of a +// recursive call in the iterative DFS simulating a recursive ordering. +struct RecursionStackEntry { + RecursionStackEntry(int node_index, RecursionStackState recursion_state) + : node_index(node_index), recursion_state(recursion_state) {} + + const int node_index; + const RecursionStackState recursion_state; +}; + +// Edge is a helper struct representing an edge in the graph. +struct Edge { + Edge(int from, int to) : from(from), to(to) {} + + const int from; + const int to; +}; +} // namespace + +Status MutableGraphView::SortTopologically( + bool ignore_cycles, + absl::Span extra_dependencies) { + if (!mutation_.updated_nodes_.empty() || !mutation_.new_nodes_.empty()) { + // Cannot sort when there is an active mutation due to indices possibly + // being changed or invalidated. + return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError, + "active mutation exists."); + } + + const int num_nodes = nodes_.size(); + + // Group extra dependencies by `from` node. + absl::flat_hash_map> extra_dependencies_by_parent; + for (const auto& extra_dependency : extra_dependencies) { + if (extra_dependency.graph_view_ != this || + extra_dependency.from_ == extra_dependency.to_ || + extra_dependency.from_ < 0 || extra_dependency.from_ >= num_nodes || + extra_dependency.to_ < 0 || extra_dependency.to_ >= num_nodes) { + return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError, + "invalid extra dependencies."); + } + extra_dependencies_by_parent[extra_dependency.from_].push_back( + extra_dependency.to_); + } + + // Reversed colored post-order DFS traversal. This does not fail on cycles, + // but there are no guarantees on ordering within a cycle. + std::vector traversal_state(num_nodes, NOT_VISITED); + int curr_pos = num_nodes - 1; + std::vector order(num_nodes); + std::vector edges_in_cycle; + + auto push_onto_stack = [this]( + const int curr_index, const int fanout_index, + std::vector* recursion_stack, + std::vector* traversal_state, + std::vector* edges_in_cycle) { + auto& fanout_traversal_state = (*traversal_state)[fanout_index]; + if (fanout_traversal_state == PROCESSING) { + // Ignore NextIteration -> Merge cycles. + if (!IsNextIteration(graph_->node(curr_index)) || + !IsMerge(graph_->node(fanout_index))) { + // Cycle detected. + edges_in_cycle->push_back({curr_index, fanout_index}); + } + } else if (fanout_traversal_state == NOT_VISITED) { + // Unvisited node, simply add to stack for future traversal. + fanout_traversal_state = PENDING; + recursion_stack->push_back({fanout_index, ENTER}); + } + }; + + auto process_fanouts = [this, &extra_dependencies_by_parent, + &push_onto_stack]( + const int curr_index, + std::vector* recursion_stack, + std::vector* traversal_state, + std::vector* edges_in_cycle) { + const auto& node_view = nodes_[curr_index]; + // Regular fanouts. + for (const auto& regular_fanouts_port_i : node_view.GetRegularFanouts()) { + for (const auto& regular_fanout : regular_fanouts_port_i) { + push_onto_stack(curr_index, regular_fanout.node_index_, recursion_stack, + traversal_state, edges_in_cycle); + } + } + // Controlled fanouts. + for (const auto& controlled_fanout : node_view.GetControlledFanouts()) { + push_onto_stack(curr_index, controlled_fanout.node_index_, + recursion_stack, traversal_state, edges_in_cycle); + } + // Extra dependencies. + auto it = extra_dependencies_by_parent.find(curr_index); + if (it != extra_dependencies_by_parent.end()) { + for (const auto& extra_fanout : it->second) { + push_onto_stack(curr_index, extra_fanout, recursion_stack, + traversal_state, edges_in_cycle); + } + } + }; + + auto reversed_postorder_dfs = + [&process_fanouts](const MutableNodeView& root_node_view, + std::vector* order, + std::vector* traversal_state, + int* curr_pos, std::vector* edges_in_cycle) { + std::vector recursion_stack; + // Add the root to stack to start the traversal. + const int root_index = root_node_view.node_index_; + auto& root_traversal_state = (*traversal_state)[root_index]; + if (root_traversal_state == NOT_VISITED) { + root_traversal_state = PENDING; + recursion_stack.push_back({root_index, ENTER}); + } + while (!recursion_stack.empty()) { + auto curr_pair = recursion_stack.back(); + recursion_stack.pop_back(); + const int curr_index = curr_pair.node_index; + auto& curr_traversal_state = (*traversal_state)[curr_index]; + if (curr_traversal_state == PROCESSED) { + // Node already processed which can be ignored. + continue; + } else if (curr_pair.recursion_state == EXIT) { + // Node from recursion stack where all fanouts were visited. + // Instead of adding node index to a vector, simply set what its + // index would be, so there will not be a need for inversion later + // on. The value set is in decending order so the reversed + // post-order is returned. + (*order)[curr_index] = *curr_pos; + curr_traversal_state = PROCESSED; + --(*curr_pos); + } else { + // Process current node and fanouts. + curr_traversal_state = PROCESSING; + recursion_stack.push_back({curr_index, EXIT}); + process_fanouts(curr_index, &recursion_stack, traversal_state, + edges_in_cycle); + } + } + }; + + // Determine sources to start DFS (nodes with no inputs) and unique fanout + // nodes. + for (const auto& node : nodes_) { + if (node.NumRegularFanins() + node.NumControllingFanins() == 0) { + reversed_postorder_dfs(node, &order, &traversal_state, &curr_pos, + &edges_in_cycle); + } + } + + if (!ignore_cycles && !edges_in_cycle.empty()) { + std::vector edges_formatted; + edges_formatted.reserve(edges_in_cycle.size()); + for (const auto& edge : edges_in_cycle) { + edges_formatted.push_back( + absl::StrCat("'", graph_->node(edge.from).name(), "' -> '", + graph_->node(edge.to).name(), "'")); + } + const string edges_str = + absl::StrCat("{", absl::StrJoin(edges_formatted, ", "), "}"); + return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError, + "detected edge(s) creating cycle(s) ", + edges_str, "."); + } + if (curr_pos != kTopologicalSortDone) { + // Not all nodes were processed. + if (!ignore_cycles) { + return errors::InvalidArgument( + kMutableGraphViewSortTopologicallyError, + "was not able to sort all nodes topologically."); + } + // Otherwise process all nodes regardless of cycles. + for (const auto& node : nodes_) { + reversed_postorder_dfs(node, &order, &traversal_state, &curr_pos, + &edges_in_cycle); + } + } + + // Permute nodes by reversed post-order DFS. + std::vector permuted_nodes(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + permuted_nodes[order[i]] = std::move(nodes_[i]); + } + nodes_.swap(permuted_nodes); + + // Fix up indices of MutableNodeViews. + for (MutableNodeView& node_view : nodes_) { + const int prev_node_index = node_view.node_index_; + if (prev_node_index != order[prev_node_index]) { + const string& node_name = graph_->node(prev_node_index).name(); + node_view.node_index_ = order[prev_node_index]; + node_index_by_name_.find(node_name)->second = node_view.node_index_; + } + for (MutableFanoutView& regular_fanin : node_view.regular_fanins_) { + regular_fanin.node_index_ = order[regular_fanin.node_index_]; + } + for (MutableFanoutView& controlling_fanin : node_view.controlling_fanins_) { + controlling_fanin.node_index_ = order[controlling_fanin.node_index_]; + } + for (std::vector& regular_fanouts_port_i : + node_view.regular_fanouts_by_port_) { + for (MutableFaninView& regular_fanout : regular_fanouts_port_i) { + regular_fanout.node_index_ = order[regular_fanout.node_index_]; + } + } + for (MutableFaninView& controlled_fanout : node_view.controlled_fanouts_) { + controlled_fanout.node_index_ = order[controlled_fanout.node_index_]; + } + } + + // Permute graph NodeDefs. + PermuteNodesInPlace(graph_, &order, /*invert_permutation=*/false); + + return Status::OK(); +} + inline Status MutableGraphView::ValidateInternal( absl::flat_hash_map* node_names, std::vector* renamed_nodes, @@ -1410,8 +1645,8 @@ Status MutableGraphView::ApplyMutationInternal() { // Node name and associated fanouts. absl::flat_hash_map renamed_fanouts; // Removed nodes where name was overwritten by a renamed node. - std::vector overwritten_name_removed_nodes; - overwritten_name_removed_nodes.resize(mutation_.updated_nodes_.size(), false); + std::vector overwritten_name_removed_nodes( + mutation_.updated_nodes_.size()); // Fix renaming of existing nodes by swapping fanouts and rehashing names. // This will also overwrite removed or unmodified nodes. FixRenamedNodes(&renamed_nodes, &renamed_fanouts, diff --git a/tensorflow/core/grappler/utils/graph_view.h b/tensorflow/core/grappler/utils/graph_view.h index 18f7c4ab560..9f61e811169 100644 --- a/tensorflow/core/grappler/utils/graph_view.h +++ b/tensorflow/core/grappler/utils/graph_view.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -78,9 +79,16 @@ class FanoutView : public internal::NodeIndexAndPortIndex { class NodeView : public internal::NodeViewInternal { public: - using NodeViewInternal::NodeViewInternal; + explicit NodeView(GraphView* graph_view, int node_index) + : NodeViewInternal(graph_view, node_index) {} + + NodeView() : NodeViewInternal() {} + ~NodeView() override = default; + NodeView(NodeView&&) = default; + NodeView& operator=(NodeView&&) = default; + const NodeDef* node() const override; // Checks if a fanin exists for the node. @@ -200,9 +208,16 @@ class MutableNodeView : public internal::NodeViewInternal { public: - using NodeViewInternal::NodeViewInternal; + explicit MutableNodeView(MutableGraphView* graph_view, int node_index) + : NodeViewInternal(graph_view, node_index) {} + + MutableNodeView() : NodeViewInternal() {} + ~MutableNodeView() override = default; + MutableNodeView(MutableNodeView&&) = default; + MutableNodeView& operator=(MutableNodeView&&) = default; + NodeDef* node() const override; // Checks if a fanin exists for the node. @@ -364,6 +379,34 @@ class MutableGraphView // Returns a Mutation (builder) that can be used to modify MutableGraphView. Mutation* GetMutationBuilder(); + // Helper class representing an extra dependency for topological sorting. + class TopologicalDependency { + public: + TopologicalDependency(const MutableNodeView* from_node, + const MutableNodeView* to_node) { + if (from_node->graph_view_ == to_node->graph_view_) { + graph_view_ = from_node->graph_view_; + from_ = from_node->node_index_; + to_ = to_node->node_index_; + } + } + + private: + MutableGraphView* graph_view_ = nullptr; + int from_ = internal::kMissingIndex; + int to_ = internal::kMissingIndex; + + friend class MutableGraphView; + }; + + // Sorts graph topologically in-place. If `ignore_cycles` is set, a + // topological like sorting will be performed when there are cycles. Otherwise + // if a cycle is detected or if the graph cannot be sorted, an error will be + // returned. + Status SortTopologically( + bool ignore_cycles, + absl::Span extra_dependencies); + private: bool AddUniqueNodeInternal(NodeDef* node); diff --git a/tensorflow/core/grappler/utils/graph_view_internal.h b/tensorflow/core/grappler/utils/graph_view_internal.h index 22e15917fb4..b1756a465fe 100644 --- a/tensorflow/core/grappler/utils/graph_view_internal.h +++ b/tensorflow/core/grappler/utils/graph_view_internal.h @@ -131,8 +131,15 @@ class NodeViewInternal { : graph_view_(graph_view), node_index_(node_index), attrs_(AttrSlice(graph_view->graph()->node(node_index))) {} + + NodeViewInternal() + : graph_view_(nullptr), node_index_(kMissingIndex), attrs_(AttrSlice()) {} + virtual ~NodeViewInternal() {} + NodeViewInternal(NodeViewInternal&&) = default; + NodeViewInternal& operator=(NodeViewInternal&&) = default; + bool operator==(const NodeViewInternal& other) const { return node_index_ == other.node_index_ && graph_view_ == other.graph_view_; } diff --git a/tensorflow/core/grappler/utils/graph_view_test.cc b/tensorflow/core/grappler/utils/graph_view_test.cc index c8952ceac92..ba2c9c31bb9 100644 --- a/tensorflow/core/grappler/utils/graph_view_test.cc +++ b/tensorflow/core/grappler/utils/graph_view_test.cc @@ -790,7 +790,7 @@ TYPED_TEST(TypedNodeViewTest, HasAttr) { EXPECT_FALSE(c_node->HasAttr("attr")); } -class MutationTest : public GrapplerTest { +class CompareGraphTest : public GrapplerTest { public: void CompareGraphViewWithGraph(MutableGraphView* graph_view, const GraphDef& expected_graph) { @@ -953,6 +953,8 @@ class MutationTest : public GrapplerTest { } }; +class MutationTest : public CompareGraphTest {}; + constexpr char kDeviceCPU0[] = "/device:CPU:0"; constexpr char kDeviceGPU0[] = "/device:GPU:0"; @@ -1995,6 +1997,270 @@ TEST_F(MutationTest, EmptyMutationUpdateIndexPersisting) { CompareGraphViewWithGraph(&graph_view, test_graph()); } +class TopologicalSortTest : public CompareGraphTest { + protected: + void CompareGraphOrder(const MutableGraphView& graph_view, + absl::Span node_names) { + const int num_nodes = graph_view.NumNodes(); + ASSERT_EQ(num_nodes, node_names.size()); + for (int i = 0; i < num_nodes; ++i) { + EXPECT_EQ(graph_view.GetNode(i)->GetName(), node_names[i]); + } + } + + void CompareGraphNodePrecedences( + const MutableGraphView& graph_view, + absl::Span> node_precedences) { + for (const auto& node_precedence : node_precedences) { + auto* parent_node = graph_view.GetNode(node_precedence.first); + ASSERT_NE(parent_node, nullptr); + auto* child_node = graph_view.GetNode(node_precedence.second); + ASSERT_NE(child_node, nullptr); + EXPECT_TRUE(parent_node->node_index() < child_node->node_index()); + } + } +}; + +TEST_F(TopologicalSortTest, ActiveMutationSort) { + auto test_graph = []() { + return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0), + NDef("b", kIdentity, {"a"}, {{"T", DT_FLOAT}}, kDeviceGPU1)}, + /*funcs=*/{}); + }; + + GraphDef graph = test_graph(); + Status status; + MutableGraphView graph_view(&graph, &status); + TF_ASSERT_OK(status); + + Mutation* mutation = graph_view.GetMutationBuilder(); + mutation->AddNode({}, &status); + TF_ASSERT_OK(status); + + for (bool ignore_cycles : {false, true}) { + status = graph_view.SortTopologically(ignore_cycles, {}); + EXPECT_FALSE(status.ok()); + EXPECT_EQ( + status.error_message(), + "MutableGraphView::SortTopologically error: active mutation exists."); + CompareGraphViewWithGraph(&graph_view, test_graph()); + CompareGraphOrder(graph_view, {"a", "b"}); + } +} + +TEST_F(TopologicalSortTest, BadExtraDependenciesSort) { + auto test_graph = []() { + return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0), + NDef("b", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU1)}, + /*funcs=*/{}); + }; + + GraphDef graph_1 = test_graph(); + Status status; + MutableGraphView graph_view_1(&graph_1, &status); + TF_ASSERT_OK(status); + MutableNodeView* a_node_1 = graph_view_1.GetNode("a"); + + GraphDef graph_2 = test_graph(); + MutableGraphView graph_view_2(&graph_2, &status); + TF_ASSERT_OK(status); + MutableNodeView* b_node_2 = graph_view_2.GetNode("b"); + + for (bool ignore_cycles : {false, true}) { + status = + graph_view_2.SortTopologically(ignore_cycles, {{a_node_1, b_node_2}}); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_message(), + "MutableGraphView::SortTopologically error: invalid extra " + "dependencies."); + CompareGraphViewWithGraph(&graph_view_2, test_graph()); + CompareGraphOrder(graph_view_2, {"a", "b"}); + } +} + +TEST_F(TopologicalSortTest, NoCyclesAllowed) { + auto test_graph = []() { + return GDef( + {NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0), + NDef("b", kIdentity, {"a", "c"}, {{"T", DT_FLOAT}}, kDeviceGPU1), + NDef("c", kIdentity, {"b"}, {{"T", DT_FLOAT}}, kDeviceGPU1)}, + /*funcs=*/{}); + }; + + GraphDef graph = test_graph(); + Status status; + MutableGraphView graph_view(&graph, &status); + TF_ASSERT_OK(status); + + status = graph_view.SortTopologically(/*ignore_cycles=*/false, {}); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_message(), + "MutableGraphView::SortTopologically error: detected edge(s) " + "creating cycle(s) {'c' -> 'b'}."); + CompareGraphViewWithGraph(&graph_view, test_graph()); + CompareGraphOrder(graph_view, {"a", "b", "c"}); + + TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/true, {})); + CompareGraphViewWithGraph(&graph_view, test_graph()); + CompareGraphNodePrecedences(graph_view, {{"a", "b"}, {"a", "c"}}); +} + +TEST_F(TopologicalSortTest, NoNodesWithZeroFanins) { + auto test_graph = []() { + return GDef({NDef("a", kIdentity, {"b"}, {{"T", DT_FLOAT}}, kDeviceGPU0), + NDef("b", kIdentity, {"a"}, {{"T", DT_FLOAT}}, kDeviceGPU1)}, + /*funcs=*/{}); + }; + + GraphDef graph = test_graph(); + Status status; + MutableGraphView graph_view(&graph, &status); + TF_ASSERT_OK(status); + + status = graph_view.SortTopologically(/*ignore_cycles=*/false, {}); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_message(), + "MutableGraphView::SortTopologically error: was not able to sort " + "all nodes topologically."); + CompareGraphViewWithGraph(&graph_view, test_graph()); + CompareGraphOrder(graph_view, {"a", "b"}); + + TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/true, {})); + CompareGraphViewWithGraph(&graph_view, test_graph()); +} + +TEST_F(TopologicalSortTest, DidNotReachAllNodes) { + auto test_graph = []() { + return GDef({NDef("c", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU2), + NDef("a", kIdentity, {"b"}, {{"T", DT_FLOAT}}, kDeviceGPU0), + NDef("b", kIdentity, {"a"}, {{"T", DT_FLOAT}}, kDeviceGPU1)}, + /*funcs=*/{}); + }; + + GraphDef graph = test_graph(); + Status status; + MutableGraphView graph_view(&graph, &status); + TF_ASSERT_OK(status); + + status = graph_view.SortTopologically(/*ignore_cycles=*/false, {}); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.error_message(), + "MutableGraphView::SortTopologically error: was not able to sort " + "all nodes topologically."); + CompareGraphViewWithGraph(&graph_view, test_graph()); + CompareGraphOrder(graph_view, {"c", "a", "b"}); + + TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/true, {})); + CompareGraphViewWithGraph(&graph_view, test_graph()); + CompareGraphOrder(graph_view, {"a", "b", "c"}); +} + +TEST_F(TopologicalSortTest, NoLoopGraph) { + auto test_graph = []() { + return GDef({NDef("c", kIdentity, {"f"}), NDef("a", kIdentity, {"f", "e"}), + NDef("b", kIdentity, {"e", "d"}), NDef("d", kIdentity, {"c"}), + NDef("f", kIdentity, {}), NDef("e", kIdentity, {})}, + /*funcs=*/{}); + }; + + GraphDef graph = test_graph(); + Status status; + MutableGraphView graph_view(&graph, &status); + TF_ASSERT_OK(status); + + TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {})); + CompareGraphViewWithGraph(&graph_view, test_graph()); + CompareGraphNodePrecedences( + graph_view, + {{"f", "a"}, {"f", "c"}, {"e", "a"}, {"e", "b"}, {"c", "d"}, {"d", "b"}}); +} + +TEST_F(TopologicalSortTest, ValidLoopGraph) { + // NextIteration -> Merge loop. + auto test_graph = []() { + return GDef({NDef("b", "Merge", {"a", "e"}), NDef("c", "Switch", {"b"}), + NDef("d", kIdentity, {"c"}), NDef("e", "NextIteration", {"d"}), + NDef("a", "Const", {})}, + /*funcs=*/{}); + }; + + GraphDef graph = test_graph(); + Status status; + MutableGraphView graph_view(&graph, &status); + TF_ASSERT_OK(status); + + TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {})); + CompareGraphViewWithGraph(&graph_view, test_graph()); + CompareGraphOrder(graph_view, {"a", "b", "c", "d", "e"}); +} + +TEST_F(TopologicalSortTest, DuplicateFanins) { + auto test_graph = []() { + return GDef( + {NDef("b", kIdentity, {"a", "a", "^a"}), NDef("a", "Const", {})}, + /*funcs=*/{}); + }; + + GraphDef graph = test_graph(); + Status status; + MutableGraphView graph_view(&graph, &status); + TF_ASSERT_OK(status); + + TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {})); + CompareGraphViewWithGraph(&graph_view, test_graph()); + CompareGraphOrder(graph_view, {"a", "b"}); +} + +TEST_F(TopologicalSortTest, DiamondDependencyNotACycle) { + auto test_graph = []() { + return GDef({NDef("e", kIdentity, {"b", "c", "d"}), + NDef("b", kIdentity, {"a"}), NDef("a", "Const", {}), + NDef("d", kIdentity, {"a"}), NDef("c", kIdentity, {"a"})}, + /*funcs=*/{}); + }; + + GraphDef graph = test_graph(); + Status status; + MutableGraphView graph_view(&graph, &status); + TF_ASSERT_OK(status); + + TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {})); + CompareGraphViewWithGraph(&graph_view, test_graph()); + CompareGraphNodePrecedences( + graph_view, + {{"a", "b"}, {"a", "c"}, {"a", "d"}, {"b", "e"}, {"c", "e"}, {"d", "e"}}); +} + +TEST_F(TopologicalSortTest, ExtraDependencies) { + auto test_graph = []() { + return GDef({NDef("c", kIdentity, {"f"}), NDef("a", kIdentity, {"f", "e"}), + NDef("b", kIdentity, {"e", "d"}), NDef("d", kIdentity, {"c"}), + NDef("f", kIdentity, {}), NDef("e", kIdentity, {})}, + /*funcs=*/{}); + }; + + GraphDef graph = test_graph(); + Status status; + MutableGraphView graph_view(&graph, &status); + TF_ASSERT_OK(status); + + auto* e_node = graph_view.GetNode("e"); + ASSERT_NE(e_node, nullptr); + auto* f_node = graph_view.GetNode("f"); + ASSERT_NE(f_node, nullptr); + + TF_EXPECT_OK( + graph_view.SortTopologically(/*ignore_cycles=*/true, {{e_node, f_node}})); + CompareGraphViewWithGraph(&graph_view, test_graph()); + CompareGraphNodePrecedences(graph_view, {{"f", "a"}, + {"f", "c"}, + {"e", "a"}, + {"e", "b"}, + {"c", "d"}, + {"d", "b"}, + {"e", "f"}}); +} + #define RUN_NUM_NODE_NUM_EDGE_BENCHMARK(name) \ BENCHMARK(name) \ ->ArgPair(10, 2) \ @@ -2541,6 +2807,23 @@ RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControlledFanoutLast); RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControlledFanoutFirst); RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControlledFanoutLast); +static void BM_SortTopologically(int iters, int size) { + testing::StopTiming(); + + GraphDef graph = test::CreateRandomGraph(size); + Status status; + MutableGraphView graph_view(&graph, &status); + TF_ASSERT_OK(status); + + testing::StartTiming(); + for (int i = 0; i < iters; i++) { + TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {})); + } + testing::StopTiming(); +} + +RUN_NUM_NODE_BENCHMARK(BM_SortTopologically); + } // namespace } // namespace utils } // namespace grappler