[Grappler] Add topological sort to new GraphView.

PiperOrigin-RevId: 249459895
This commit is contained in:
Andy Ly 2019-05-22 09:09:04 -07:00 committed by TensorFlower Gardener
parent 477447155b
commit d74bb6ad5f
5 changed files with 575 additions and 5 deletions

View File

@ -348,10 +348,12 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)

View File

@ -16,8 +16,11 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@ -1360,6 +1363,238 @@ void MutableGraphView::RemoveNodesInternal(
}
}
namespace {
constexpr int kTopologicalSortDone = -1;
const char kMutableGraphViewSortTopologicallyError[] =
"MutableGraphView::SortTopologically error: ";
// TraversalState is an enum representing the state of a node when it is being
// traversed via DFS.
enum TraversalState : uint8_t { NOT_VISITED, PENDING, PROCESSING, PROCESSED };
// RecursionStackState is an enum representing the recursion stack state
// when using DFS iteratively. `ENTER` is the state representing entering into
// a recursive call, while `EXIT` is the state representing exiting a
// recursive call.
enum RecursionStackState : bool { ENTER, EXIT };
// RecursionStackEntry is a helper struct representing an instance of a
// recursive call in the iterative DFS simulating a recursive ordering.
struct RecursionStackEntry {
RecursionStackEntry(int node_index, RecursionStackState recursion_state)
: node_index(node_index), recursion_state(recursion_state) {}
const int node_index;
const RecursionStackState recursion_state;
};
// Edge is a helper struct representing an edge in the graph.
struct Edge {
Edge(int from, int to) : from(from), to(to) {}
const int from;
const int to;
};
} // namespace
Status MutableGraphView::SortTopologically(
bool ignore_cycles,
absl::Span<const TopologicalDependency> extra_dependencies) {
if (!mutation_.updated_nodes_.empty() || !mutation_.new_nodes_.empty()) {
// Cannot sort when there is an active mutation due to indices possibly
// being changed or invalidated.
return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError,
"active mutation exists.");
}
const int num_nodes = nodes_.size();
// Group extra dependencies by `from` node.
absl::flat_hash_map<int, std::vector<int>> extra_dependencies_by_parent;
for (const auto& extra_dependency : extra_dependencies) {
if (extra_dependency.graph_view_ != this ||
extra_dependency.from_ == extra_dependency.to_ ||
extra_dependency.from_ < 0 || extra_dependency.from_ >= num_nodes ||
extra_dependency.to_ < 0 || extra_dependency.to_ >= num_nodes) {
return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError,
"invalid extra dependencies.");
}
extra_dependencies_by_parent[extra_dependency.from_].push_back(
extra_dependency.to_);
}
// Reversed colored post-order DFS traversal. This does not fail on cycles,
// but there are no guarantees on ordering within a cycle.
std::vector<TraversalState> traversal_state(num_nodes, NOT_VISITED);
int curr_pos = num_nodes - 1;
std::vector<int> order(num_nodes);
std::vector<Edge> edges_in_cycle;
auto push_onto_stack = [this](
const int curr_index, const int fanout_index,
std::vector<RecursionStackEntry>* recursion_stack,
std::vector<TraversalState>* traversal_state,
std::vector<Edge>* edges_in_cycle) {
auto& fanout_traversal_state = (*traversal_state)[fanout_index];
if (fanout_traversal_state == PROCESSING) {
// Ignore NextIteration -> Merge cycles.
if (!IsNextIteration(graph_->node(curr_index)) ||
!IsMerge(graph_->node(fanout_index))) {
// Cycle detected.
edges_in_cycle->push_back({curr_index, fanout_index});
}
} else if (fanout_traversal_state == NOT_VISITED) {
// Unvisited node, simply add to stack for future traversal.
fanout_traversal_state = PENDING;
recursion_stack->push_back({fanout_index, ENTER});
}
};
auto process_fanouts = [this, &extra_dependencies_by_parent,
&push_onto_stack](
const int curr_index,
std::vector<RecursionStackEntry>* recursion_stack,
std::vector<TraversalState>* traversal_state,
std::vector<Edge>* edges_in_cycle) {
const auto& node_view = nodes_[curr_index];
// Regular fanouts.
for (const auto& regular_fanouts_port_i : node_view.GetRegularFanouts()) {
for (const auto& regular_fanout : regular_fanouts_port_i) {
push_onto_stack(curr_index, regular_fanout.node_index_, recursion_stack,
traversal_state, edges_in_cycle);
}
}
// Controlled fanouts.
for (const auto& controlled_fanout : node_view.GetControlledFanouts()) {
push_onto_stack(curr_index, controlled_fanout.node_index_,
recursion_stack, traversal_state, edges_in_cycle);
}
// Extra dependencies.
auto it = extra_dependencies_by_parent.find(curr_index);
if (it != extra_dependencies_by_parent.end()) {
for (const auto& extra_fanout : it->second) {
push_onto_stack(curr_index, extra_fanout, recursion_stack,
traversal_state, edges_in_cycle);
}
}
};
auto reversed_postorder_dfs =
[&process_fanouts](const MutableNodeView& root_node_view,
std::vector<int>* order,
std::vector<TraversalState>* traversal_state,
int* curr_pos, std::vector<Edge>* edges_in_cycle) {
std::vector<RecursionStackEntry> recursion_stack;
// Add the root to stack to start the traversal.
const int root_index = root_node_view.node_index_;
auto& root_traversal_state = (*traversal_state)[root_index];
if (root_traversal_state == NOT_VISITED) {
root_traversal_state = PENDING;
recursion_stack.push_back({root_index, ENTER});
}
while (!recursion_stack.empty()) {
auto curr_pair = recursion_stack.back();
recursion_stack.pop_back();
const int curr_index = curr_pair.node_index;
auto& curr_traversal_state = (*traversal_state)[curr_index];
if (curr_traversal_state == PROCESSED) {
// Node already processed which can be ignored.
continue;
} else if (curr_pair.recursion_state == EXIT) {
// Node from recursion stack where all fanouts were visited.
// Instead of adding node index to a vector, simply set what its
// index would be, so there will not be a need for inversion later
// on. The value set is in decending order so the reversed
// post-order is returned.
(*order)[curr_index] = *curr_pos;
curr_traversal_state = PROCESSED;
--(*curr_pos);
} else {
// Process current node and fanouts.
curr_traversal_state = PROCESSING;
recursion_stack.push_back({curr_index, EXIT});
process_fanouts(curr_index, &recursion_stack, traversal_state,
edges_in_cycle);
}
}
};
// Determine sources to start DFS (nodes with no inputs) and unique fanout
// nodes.
for (const auto& node : nodes_) {
if (node.NumRegularFanins() + node.NumControllingFanins() == 0) {
reversed_postorder_dfs(node, &order, &traversal_state, &curr_pos,
&edges_in_cycle);
}
}
if (!ignore_cycles && !edges_in_cycle.empty()) {
std::vector<string> edges_formatted;
edges_formatted.reserve(edges_in_cycle.size());
for (const auto& edge : edges_in_cycle) {
edges_formatted.push_back(
absl::StrCat("'", graph_->node(edge.from).name(), "' -> '",
graph_->node(edge.to).name(), "'"));
}
const string edges_str =
absl::StrCat("{", absl::StrJoin(edges_formatted, ", "), "}");
return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError,
"detected edge(s) creating cycle(s) ",
edges_str, ".");
}
if (curr_pos != kTopologicalSortDone) {
// Not all nodes were processed.
if (!ignore_cycles) {
return errors::InvalidArgument(
kMutableGraphViewSortTopologicallyError,
"was not able to sort all nodes topologically.");
}
// Otherwise process all nodes regardless of cycles.
for (const auto& node : nodes_) {
reversed_postorder_dfs(node, &order, &traversal_state, &curr_pos,
&edges_in_cycle);
}
}
// Permute nodes by reversed post-order DFS.
std::vector<MutableNodeView> permuted_nodes(num_nodes);
for (int i = 0; i < num_nodes; ++i) {
permuted_nodes[order[i]] = std::move(nodes_[i]);
}
nodes_.swap(permuted_nodes);
// Fix up indices of MutableNodeViews.
for (MutableNodeView& node_view : nodes_) {
const int prev_node_index = node_view.node_index_;
if (prev_node_index != order[prev_node_index]) {
const string& node_name = graph_->node(prev_node_index).name();
node_view.node_index_ = order[prev_node_index];
node_index_by_name_.find(node_name)->second = node_view.node_index_;
}
for (MutableFanoutView& regular_fanin : node_view.regular_fanins_) {
regular_fanin.node_index_ = order[regular_fanin.node_index_];
}
for (MutableFanoutView& controlling_fanin : node_view.controlling_fanins_) {
controlling_fanin.node_index_ = order[controlling_fanin.node_index_];
}
for (std::vector<MutableFaninView>& regular_fanouts_port_i :
node_view.regular_fanouts_by_port_) {
for (MutableFaninView& regular_fanout : regular_fanouts_port_i) {
regular_fanout.node_index_ = order[regular_fanout.node_index_];
}
}
for (MutableFaninView& controlled_fanout : node_view.controlled_fanouts_) {
controlled_fanout.node_index_ = order[controlled_fanout.node_index_];
}
}
// Permute graph NodeDefs.
PermuteNodesInPlace(graph_, &order, /*invert_permutation=*/false);
return Status::OK();
}
inline Status MutableGraphView::ValidateInternal(
absl::flat_hash_map<absl::string_view, int>* node_names,
std::vector<RenamedOrOverwrittenNode>* renamed_nodes,
@ -1410,8 +1645,8 @@ Status MutableGraphView::ApplyMutationInternal() {
// Node name and associated fanouts.
absl::flat_hash_map<string, NodeViewFanouts> renamed_fanouts;
// Removed nodes where name was overwritten by a renamed node.
std::vector<bool> overwritten_name_removed_nodes;
overwritten_name_removed_nodes.resize(mutation_.updated_nodes_.size(), false);
std::vector<bool> overwritten_name_removed_nodes(
mutation_.updated_nodes_.size());
// Fix renaming of existing nodes by swapping fanouts and rehashing names.
// This will also overwrite removed or unmodified nodes.
FixRenamedNodes(&renamed_nodes, &renamed_fanouts,

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -78,9 +79,16 @@ class FanoutView : public internal::NodeIndexAndPortIndex<NodeView, GraphView> {
class NodeView : public internal::NodeViewInternal<FaninView, FanoutView,
GraphView, true> {
public:
using NodeViewInternal::NodeViewInternal;
explicit NodeView(GraphView* graph_view, int node_index)
: NodeViewInternal(graph_view, node_index) {}
NodeView() : NodeViewInternal() {}
~NodeView() override = default;
NodeView(NodeView&&) = default;
NodeView& operator=(NodeView&&) = default;
const NodeDef* node() const override;
// Checks if a fanin exists for the node.
@ -200,9 +208,16 @@ class MutableNodeView
: public internal::NodeViewInternal<MutableFaninView, MutableFanoutView,
MutableGraphView, false> {
public:
using NodeViewInternal::NodeViewInternal;
explicit MutableNodeView(MutableGraphView* graph_view, int node_index)
: NodeViewInternal(graph_view, node_index) {}
MutableNodeView() : NodeViewInternal() {}
~MutableNodeView() override = default;
MutableNodeView(MutableNodeView&&) = default;
MutableNodeView& operator=(MutableNodeView&&) = default;
NodeDef* node() const override;
// Checks if a fanin exists for the node.
@ -364,6 +379,34 @@ class MutableGraphView
// Returns a Mutation (builder) that can be used to modify MutableGraphView.
Mutation* GetMutationBuilder();
// Helper class representing an extra dependency for topological sorting.
class TopologicalDependency {
public:
TopologicalDependency(const MutableNodeView* from_node,
const MutableNodeView* to_node) {
if (from_node->graph_view_ == to_node->graph_view_) {
graph_view_ = from_node->graph_view_;
from_ = from_node->node_index_;
to_ = to_node->node_index_;
}
}
private:
MutableGraphView* graph_view_ = nullptr;
int from_ = internal::kMissingIndex;
int to_ = internal::kMissingIndex;
friend class MutableGraphView;
};
// Sorts graph topologically in-place. If `ignore_cycles` is set, a
// topological like sorting will be performed when there are cycles. Otherwise
// if a cycle is detected or if the graph cannot be sorted, an error will be
// returned.
Status SortTopologically(
bool ignore_cycles,
absl::Span<const TopologicalDependency> extra_dependencies);
private:
bool AddUniqueNodeInternal(NodeDef* node);

View File

@ -131,8 +131,15 @@ class NodeViewInternal {
: graph_view_(graph_view),
node_index_(node_index),
attrs_(AttrSlice(graph_view->graph()->node(node_index))) {}
NodeViewInternal()
: graph_view_(nullptr), node_index_(kMissingIndex), attrs_(AttrSlice()) {}
virtual ~NodeViewInternal() {}
NodeViewInternal(NodeViewInternal&&) = default;
NodeViewInternal& operator=(NodeViewInternal&&) = default;
bool operator==(const NodeViewInternal& other) const {
return node_index_ == other.node_index_ && graph_view_ == other.graph_view_;
}

View File

@ -790,7 +790,7 @@ TYPED_TEST(TypedNodeViewTest, HasAttr) {
EXPECT_FALSE(c_node->HasAttr("attr"));
}
class MutationTest : public GrapplerTest {
class CompareGraphTest : public GrapplerTest {
public:
void CompareGraphViewWithGraph(MutableGraphView* graph_view,
const GraphDef& expected_graph) {
@ -953,6 +953,8 @@ class MutationTest : public GrapplerTest {
}
};
class MutationTest : public CompareGraphTest {};
constexpr char kDeviceCPU0[] = "/device:CPU:0";
constexpr char kDeviceGPU0[] = "/device:GPU:0";
@ -1995,6 +1997,270 @@ TEST_F(MutationTest, EmptyMutationUpdateIndexPersisting) {
CompareGraphViewWithGraph(&graph_view, test_graph());
}
class TopologicalSortTest : public CompareGraphTest {
protected:
void CompareGraphOrder(const MutableGraphView& graph_view,
absl::Span<const string> node_names) {
const int num_nodes = graph_view.NumNodes();
ASSERT_EQ(num_nodes, node_names.size());
for (int i = 0; i < num_nodes; ++i) {
EXPECT_EQ(graph_view.GetNode(i)->GetName(), node_names[i]);
}
}
void CompareGraphNodePrecedences(
const MutableGraphView& graph_view,
absl::Span<const std::pair<string, string>> node_precedences) {
for (const auto& node_precedence : node_precedences) {
auto* parent_node = graph_view.GetNode(node_precedence.first);
ASSERT_NE(parent_node, nullptr);
auto* child_node = graph_view.GetNode(node_precedence.second);
ASSERT_NE(child_node, nullptr);
EXPECT_TRUE(parent_node->node_index() < child_node->node_index());
}
}
};
TEST_F(TopologicalSortTest, ActiveMutationSort) {
auto test_graph = []() {
return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0),
NDef("b", kIdentity, {"a"}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
/*funcs=*/{});
};
GraphDef graph = test_graph();
Status status;
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
Mutation* mutation = graph_view.GetMutationBuilder();
mutation->AddNode({}, &status);
TF_ASSERT_OK(status);
for (bool ignore_cycles : {false, true}) {
status = graph_view.SortTopologically(ignore_cycles, {});
EXPECT_FALSE(status.ok());
EXPECT_EQ(
status.error_message(),
"MutableGraphView::SortTopologically error: active mutation exists.");
CompareGraphViewWithGraph(&graph_view, test_graph());
CompareGraphOrder(graph_view, {"a", "b"});
}
}
TEST_F(TopologicalSortTest, BadExtraDependenciesSort) {
auto test_graph = []() {
return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0),
NDef("b", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
/*funcs=*/{});
};
GraphDef graph_1 = test_graph();
Status status;
MutableGraphView graph_view_1(&graph_1, &status);
TF_ASSERT_OK(status);
MutableNodeView* a_node_1 = graph_view_1.GetNode("a");
GraphDef graph_2 = test_graph();
MutableGraphView graph_view_2(&graph_2, &status);
TF_ASSERT_OK(status);
MutableNodeView* b_node_2 = graph_view_2.GetNode("b");
for (bool ignore_cycles : {false, true}) {
status =
graph_view_2.SortTopologically(ignore_cycles, {{a_node_1, b_node_2}});
EXPECT_FALSE(status.ok());
EXPECT_EQ(status.error_message(),
"MutableGraphView::SortTopologically error: invalid extra "
"dependencies.");
CompareGraphViewWithGraph(&graph_view_2, test_graph());
CompareGraphOrder(graph_view_2, {"a", "b"});
}
}
TEST_F(TopologicalSortTest, NoCyclesAllowed) {
auto test_graph = []() {
return GDef(
{NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0),
NDef("b", kIdentity, {"a", "c"}, {{"T", DT_FLOAT}}, kDeviceGPU1),
NDef("c", kIdentity, {"b"}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
/*funcs=*/{});
};
GraphDef graph = test_graph();
Status status;
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
status = graph_view.SortTopologically(/*ignore_cycles=*/false, {});
EXPECT_FALSE(status.ok());
EXPECT_EQ(status.error_message(),
"MutableGraphView::SortTopologically error: detected edge(s) "
"creating cycle(s) {'c' -> 'b'}.");
CompareGraphViewWithGraph(&graph_view, test_graph());
CompareGraphOrder(graph_view, {"a", "b", "c"});
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/true, {}));
CompareGraphViewWithGraph(&graph_view, test_graph());
CompareGraphNodePrecedences(graph_view, {{"a", "b"}, {"a", "c"}});
}
TEST_F(TopologicalSortTest, NoNodesWithZeroFanins) {
auto test_graph = []() {
return GDef({NDef("a", kIdentity, {"b"}, {{"T", DT_FLOAT}}, kDeviceGPU0),
NDef("b", kIdentity, {"a"}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
/*funcs=*/{});
};
GraphDef graph = test_graph();
Status status;
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
status = graph_view.SortTopologically(/*ignore_cycles=*/false, {});
EXPECT_FALSE(status.ok());
EXPECT_EQ(status.error_message(),
"MutableGraphView::SortTopologically error: was not able to sort "
"all nodes topologically.");
CompareGraphViewWithGraph(&graph_view, test_graph());
CompareGraphOrder(graph_view, {"a", "b"});
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/true, {}));
CompareGraphViewWithGraph(&graph_view, test_graph());
}
TEST_F(TopologicalSortTest, DidNotReachAllNodes) {
auto test_graph = []() {
return GDef({NDef("c", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU2),
NDef("a", kIdentity, {"b"}, {{"T", DT_FLOAT}}, kDeviceGPU0),
NDef("b", kIdentity, {"a"}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
/*funcs=*/{});
};
GraphDef graph = test_graph();
Status status;
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
status = graph_view.SortTopologically(/*ignore_cycles=*/false, {});
EXPECT_FALSE(status.ok());
EXPECT_EQ(status.error_message(),
"MutableGraphView::SortTopologically error: was not able to sort "
"all nodes topologically.");
CompareGraphViewWithGraph(&graph_view, test_graph());
CompareGraphOrder(graph_view, {"c", "a", "b"});
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/true, {}));
CompareGraphViewWithGraph(&graph_view, test_graph());
CompareGraphOrder(graph_view, {"a", "b", "c"});
}
TEST_F(TopologicalSortTest, NoLoopGraph) {
auto test_graph = []() {
return GDef({NDef("c", kIdentity, {"f"}), NDef("a", kIdentity, {"f", "e"}),
NDef("b", kIdentity, {"e", "d"}), NDef("d", kIdentity, {"c"}),
NDef("f", kIdentity, {}), NDef("e", kIdentity, {})},
/*funcs=*/{});
};
GraphDef graph = test_graph();
Status status;
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
CompareGraphViewWithGraph(&graph_view, test_graph());
CompareGraphNodePrecedences(
graph_view,
{{"f", "a"}, {"f", "c"}, {"e", "a"}, {"e", "b"}, {"c", "d"}, {"d", "b"}});
}
TEST_F(TopologicalSortTest, ValidLoopGraph) {
// NextIteration -> Merge loop.
auto test_graph = []() {
return GDef({NDef("b", "Merge", {"a", "e"}), NDef("c", "Switch", {"b"}),
NDef("d", kIdentity, {"c"}), NDef("e", "NextIteration", {"d"}),
NDef("a", "Const", {})},
/*funcs=*/{});
};
GraphDef graph = test_graph();
Status status;
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
CompareGraphViewWithGraph(&graph_view, test_graph());
CompareGraphOrder(graph_view, {"a", "b", "c", "d", "e"});
}
TEST_F(TopologicalSortTest, DuplicateFanins) {
auto test_graph = []() {
return GDef(
{NDef("b", kIdentity, {"a", "a", "^a"}), NDef("a", "Const", {})},
/*funcs=*/{});
};
GraphDef graph = test_graph();
Status status;
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
CompareGraphViewWithGraph(&graph_view, test_graph());
CompareGraphOrder(graph_view, {"a", "b"});
}
TEST_F(TopologicalSortTest, DiamondDependencyNotACycle) {
auto test_graph = []() {
return GDef({NDef("e", kIdentity, {"b", "c", "d"}),
NDef("b", kIdentity, {"a"}), NDef("a", "Const", {}),
NDef("d", kIdentity, {"a"}), NDef("c", kIdentity, {"a"})},
/*funcs=*/{});
};
GraphDef graph = test_graph();
Status status;
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
CompareGraphViewWithGraph(&graph_view, test_graph());
CompareGraphNodePrecedences(
graph_view,
{{"a", "b"}, {"a", "c"}, {"a", "d"}, {"b", "e"}, {"c", "e"}, {"d", "e"}});
}
TEST_F(TopologicalSortTest, ExtraDependencies) {
auto test_graph = []() {
return GDef({NDef("c", kIdentity, {"f"}), NDef("a", kIdentity, {"f", "e"}),
NDef("b", kIdentity, {"e", "d"}), NDef("d", kIdentity, {"c"}),
NDef("f", kIdentity, {}), NDef("e", kIdentity, {})},
/*funcs=*/{});
};
GraphDef graph = test_graph();
Status status;
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
auto* e_node = graph_view.GetNode("e");
ASSERT_NE(e_node, nullptr);
auto* f_node = graph_view.GetNode("f");
ASSERT_NE(f_node, nullptr);
TF_EXPECT_OK(
graph_view.SortTopologically(/*ignore_cycles=*/true, {{e_node, f_node}}));
CompareGraphViewWithGraph(&graph_view, test_graph());
CompareGraphNodePrecedences(graph_view, {{"f", "a"},
{"f", "c"},
{"e", "a"},
{"e", "b"},
{"c", "d"},
{"d", "b"},
{"e", "f"}});
}
#define RUN_NUM_NODE_NUM_EDGE_BENCHMARK(name) \
BENCHMARK(name) \
->ArgPair(10, 2) \
@ -2541,6 +2807,23 @@ RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControlledFanoutLast);
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControlledFanoutFirst);
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControlledFanoutLast);
static void BM_SortTopologically(int iters, int size) {
testing::StopTiming();
GraphDef graph = test::CreateRandomGraph(size);
Status status;
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
testing::StartTiming();
for (int i = 0; i < iters; i++) {
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
}
testing::StopTiming();
}
RUN_NUM_NODE_BENCHMARK(BM_SortTopologically);
} // namespace
} // namespace utils
} // namespace grappler