[Grappler] Add topological sort to new GraphView.
PiperOrigin-RevId: 249459895
This commit is contained in:
parent
477447155b
commit
d74bb6ad5f
@ -348,10 +348,12 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -16,8 +16,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/utils/graph_view.h"
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
@ -1360,6 +1363,238 @@ void MutableGraphView::RemoveNodesInternal(
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
constexpr int kTopologicalSortDone = -1;
|
||||
|
||||
const char kMutableGraphViewSortTopologicallyError[] =
|
||||
"MutableGraphView::SortTopologically error: ";
|
||||
|
||||
// TraversalState is an enum representing the state of a node when it is being
|
||||
// traversed via DFS.
|
||||
enum TraversalState : uint8_t { NOT_VISITED, PENDING, PROCESSING, PROCESSED };
|
||||
|
||||
// RecursionStackState is an enum representing the recursion stack state
|
||||
// when using DFS iteratively. `ENTER` is the state representing entering into
|
||||
// a recursive call, while `EXIT` is the state representing exiting a
|
||||
// recursive call.
|
||||
enum RecursionStackState : bool { ENTER, EXIT };
|
||||
|
||||
// RecursionStackEntry is a helper struct representing an instance of a
|
||||
// recursive call in the iterative DFS simulating a recursive ordering.
|
||||
struct RecursionStackEntry {
|
||||
RecursionStackEntry(int node_index, RecursionStackState recursion_state)
|
||||
: node_index(node_index), recursion_state(recursion_state) {}
|
||||
|
||||
const int node_index;
|
||||
const RecursionStackState recursion_state;
|
||||
};
|
||||
|
||||
// Edge is a helper struct representing an edge in the graph.
|
||||
struct Edge {
|
||||
Edge(int from, int to) : from(from), to(to) {}
|
||||
|
||||
const int from;
|
||||
const int to;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
Status MutableGraphView::SortTopologically(
|
||||
bool ignore_cycles,
|
||||
absl::Span<const TopologicalDependency> extra_dependencies) {
|
||||
if (!mutation_.updated_nodes_.empty() || !mutation_.new_nodes_.empty()) {
|
||||
// Cannot sort when there is an active mutation due to indices possibly
|
||||
// being changed or invalidated.
|
||||
return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError,
|
||||
"active mutation exists.");
|
||||
}
|
||||
|
||||
const int num_nodes = nodes_.size();
|
||||
|
||||
// Group extra dependencies by `from` node.
|
||||
absl::flat_hash_map<int, std::vector<int>> extra_dependencies_by_parent;
|
||||
for (const auto& extra_dependency : extra_dependencies) {
|
||||
if (extra_dependency.graph_view_ != this ||
|
||||
extra_dependency.from_ == extra_dependency.to_ ||
|
||||
extra_dependency.from_ < 0 || extra_dependency.from_ >= num_nodes ||
|
||||
extra_dependency.to_ < 0 || extra_dependency.to_ >= num_nodes) {
|
||||
return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError,
|
||||
"invalid extra dependencies.");
|
||||
}
|
||||
extra_dependencies_by_parent[extra_dependency.from_].push_back(
|
||||
extra_dependency.to_);
|
||||
}
|
||||
|
||||
// Reversed colored post-order DFS traversal. This does not fail on cycles,
|
||||
// but there are no guarantees on ordering within a cycle.
|
||||
std::vector<TraversalState> traversal_state(num_nodes, NOT_VISITED);
|
||||
int curr_pos = num_nodes - 1;
|
||||
std::vector<int> order(num_nodes);
|
||||
std::vector<Edge> edges_in_cycle;
|
||||
|
||||
auto push_onto_stack = [this](
|
||||
const int curr_index, const int fanout_index,
|
||||
std::vector<RecursionStackEntry>* recursion_stack,
|
||||
std::vector<TraversalState>* traversal_state,
|
||||
std::vector<Edge>* edges_in_cycle) {
|
||||
auto& fanout_traversal_state = (*traversal_state)[fanout_index];
|
||||
if (fanout_traversal_state == PROCESSING) {
|
||||
// Ignore NextIteration -> Merge cycles.
|
||||
if (!IsNextIteration(graph_->node(curr_index)) ||
|
||||
!IsMerge(graph_->node(fanout_index))) {
|
||||
// Cycle detected.
|
||||
edges_in_cycle->push_back({curr_index, fanout_index});
|
||||
}
|
||||
} else if (fanout_traversal_state == NOT_VISITED) {
|
||||
// Unvisited node, simply add to stack for future traversal.
|
||||
fanout_traversal_state = PENDING;
|
||||
recursion_stack->push_back({fanout_index, ENTER});
|
||||
}
|
||||
};
|
||||
|
||||
auto process_fanouts = [this, &extra_dependencies_by_parent,
|
||||
&push_onto_stack](
|
||||
const int curr_index,
|
||||
std::vector<RecursionStackEntry>* recursion_stack,
|
||||
std::vector<TraversalState>* traversal_state,
|
||||
std::vector<Edge>* edges_in_cycle) {
|
||||
const auto& node_view = nodes_[curr_index];
|
||||
// Regular fanouts.
|
||||
for (const auto& regular_fanouts_port_i : node_view.GetRegularFanouts()) {
|
||||
for (const auto& regular_fanout : regular_fanouts_port_i) {
|
||||
push_onto_stack(curr_index, regular_fanout.node_index_, recursion_stack,
|
||||
traversal_state, edges_in_cycle);
|
||||
}
|
||||
}
|
||||
// Controlled fanouts.
|
||||
for (const auto& controlled_fanout : node_view.GetControlledFanouts()) {
|
||||
push_onto_stack(curr_index, controlled_fanout.node_index_,
|
||||
recursion_stack, traversal_state, edges_in_cycle);
|
||||
}
|
||||
// Extra dependencies.
|
||||
auto it = extra_dependencies_by_parent.find(curr_index);
|
||||
if (it != extra_dependencies_by_parent.end()) {
|
||||
for (const auto& extra_fanout : it->second) {
|
||||
push_onto_stack(curr_index, extra_fanout, recursion_stack,
|
||||
traversal_state, edges_in_cycle);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto reversed_postorder_dfs =
|
||||
[&process_fanouts](const MutableNodeView& root_node_view,
|
||||
std::vector<int>* order,
|
||||
std::vector<TraversalState>* traversal_state,
|
||||
int* curr_pos, std::vector<Edge>* edges_in_cycle) {
|
||||
std::vector<RecursionStackEntry> recursion_stack;
|
||||
// Add the root to stack to start the traversal.
|
||||
const int root_index = root_node_view.node_index_;
|
||||
auto& root_traversal_state = (*traversal_state)[root_index];
|
||||
if (root_traversal_state == NOT_VISITED) {
|
||||
root_traversal_state = PENDING;
|
||||
recursion_stack.push_back({root_index, ENTER});
|
||||
}
|
||||
while (!recursion_stack.empty()) {
|
||||
auto curr_pair = recursion_stack.back();
|
||||
recursion_stack.pop_back();
|
||||
const int curr_index = curr_pair.node_index;
|
||||
auto& curr_traversal_state = (*traversal_state)[curr_index];
|
||||
if (curr_traversal_state == PROCESSED) {
|
||||
// Node already processed which can be ignored.
|
||||
continue;
|
||||
} else if (curr_pair.recursion_state == EXIT) {
|
||||
// Node from recursion stack where all fanouts were visited.
|
||||
// Instead of adding node index to a vector, simply set what its
|
||||
// index would be, so there will not be a need for inversion later
|
||||
// on. The value set is in decending order so the reversed
|
||||
// post-order is returned.
|
||||
(*order)[curr_index] = *curr_pos;
|
||||
curr_traversal_state = PROCESSED;
|
||||
--(*curr_pos);
|
||||
} else {
|
||||
// Process current node and fanouts.
|
||||
curr_traversal_state = PROCESSING;
|
||||
recursion_stack.push_back({curr_index, EXIT});
|
||||
process_fanouts(curr_index, &recursion_stack, traversal_state,
|
||||
edges_in_cycle);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Determine sources to start DFS (nodes with no inputs) and unique fanout
|
||||
// nodes.
|
||||
for (const auto& node : nodes_) {
|
||||
if (node.NumRegularFanins() + node.NumControllingFanins() == 0) {
|
||||
reversed_postorder_dfs(node, &order, &traversal_state, &curr_pos,
|
||||
&edges_in_cycle);
|
||||
}
|
||||
}
|
||||
|
||||
if (!ignore_cycles && !edges_in_cycle.empty()) {
|
||||
std::vector<string> edges_formatted;
|
||||
edges_formatted.reserve(edges_in_cycle.size());
|
||||
for (const auto& edge : edges_in_cycle) {
|
||||
edges_formatted.push_back(
|
||||
absl::StrCat("'", graph_->node(edge.from).name(), "' -> '",
|
||||
graph_->node(edge.to).name(), "'"));
|
||||
}
|
||||
const string edges_str =
|
||||
absl::StrCat("{", absl::StrJoin(edges_formatted, ", "), "}");
|
||||
return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError,
|
||||
"detected edge(s) creating cycle(s) ",
|
||||
edges_str, ".");
|
||||
}
|
||||
if (curr_pos != kTopologicalSortDone) {
|
||||
// Not all nodes were processed.
|
||||
if (!ignore_cycles) {
|
||||
return errors::InvalidArgument(
|
||||
kMutableGraphViewSortTopologicallyError,
|
||||
"was not able to sort all nodes topologically.");
|
||||
}
|
||||
// Otherwise process all nodes regardless of cycles.
|
||||
for (const auto& node : nodes_) {
|
||||
reversed_postorder_dfs(node, &order, &traversal_state, &curr_pos,
|
||||
&edges_in_cycle);
|
||||
}
|
||||
}
|
||||
|
||||
// Permute nodes by reversed post-order DFS.
|
||||
std::vector<MutableNodeView> permuted_nodes(num_nodes);
|
||||
for (int i = 0; i < num_nodes; ++i) {
|
||||
permuted_nodes[order[i]] = std::move(nodes_[i]);
|
||||
}
|
||||
nodes_.swap(permuted_nodes);
|
||||
|
||||
// Fix up indices of MutableNodeViews.
|
||||
for (MutableNodeView& node_view : nodes_) {
|
||||
const int prev_node_index = node_view.node_index_;
|
||||
if (prev_node_index != order[prev_node_index]) {
|
||||
const string& node_name = graph_->node(prev_node_index).name();
|
||||
node_view.node_index_ = order[prev_node_index];
|
||||
node_index_by_name_.find(node_name)->second = node_view.node_index_;
|
||||
}
|
||||
for (MutableFanoutView& regular_fanin : node_view.regular_fanins_) {
|
||||
regular_fanin.node_index_ = order[regular_fanin.node_index_];
|
||||
}
|
||||
for (MutableFanoutView& controlling_fanin : node_view.controlling_fanins_) {
|
||||
controlling_fanin.node_index_ = order[controlling_fanin.node_index_];
|
||||
}
|
||||
for (std::vector<MutableFaninView>& regular_fanouts_port_i :
|
||||
node_view.regular_fanouts_by_port_) {
|
||||
for (MutableFaninView& regular_fanout : regular_fanouts_port_i) {
|
||||
regular_fanout.node_index_ = order[regular_fanout.node_index_];
|
||||
}
|
||||
}
|
||||
for (MutableFaninView& controlled_fanout : node_view.controlled_fanouts_) {
|
||||
controlled_fanout.node_index_ = order[controlled_fanout.node_index_];
|
||||
}
|
||||
}
|
||||
|
||||
// Permute graph NodeDefs.
|
||||
PermuteNodesInPlace(graph_, &order, /*invert_permutation=*/false);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
inline Status MutableGraphView::ValidateInternal(
|
||||
absl::flat_hash_map<absl::string_view, int>* node_names,
|
||||
std::vector<RenamedOrOverwrittenNode>* renamed_nodes,
|
||||
@ -1410,8 +1645,8 @@ Status MutableGraphView::ApplyMutationInternal() {
|
||||
// Node name and associated fanouts.
|
||||
absl::flat_hash_map<string, NodeViewFanouts> renamed_fanouts;
|
||||
// Removed nodes where name was overwritten by a renamed node.
|
||||
std::vector<bool> overwritten_name_removed_nodes;
|
||||
overwritten_name_removed_nodes.resize(mutation_.updated_nodes_.size(), false);
|
||||
std::vector<bool> overwritten_name_removed_nodes(
|
||||
mutation_.updated_nodes_.size());
|
||||
// Fix renaming of existing nodes by swapping fanouts and rehashing names.
|
||||
// This will also overwrite removed or unmodified nodes.
|
||||
FixRenamedNodes(&renamed_nodes, &renamed_fanouts,
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -78,9 +79,16 @@ class FanoutView : public internal::NodeIndexAndPortIndex<NodeView, GraphView> {
|
||||
class NodeView : public internal::NodeViewInternal<FaninView, FanoutView,
|
||||
GraphView, true> {
|
||||
public:
|
||||
using NodeViewInternal::NodeViewInternal;
|
||||
explicit NodeView(GraphView* graph_view, int node_index)
|
||||
: NodeViewInternal(graph_view, node_index) {}
|
||||
|
||||
NodeView() : NodeViewInternal() {}
|
||||
|
||||
~NodeView() override = default;
|
||||
|
||||
NodeView(NodeView&&) = default;
|
||||
NodeView& operator=(NodeView&&) = default;
|
||||
|
||||
const NodeDef* node() const override;
|
||||
|
||||
// Checks if a fanin exists for the node.
|
||||
@ -200,9 +208,16 @@ class MutableNodeView
|
||||
: public internal::NodeViewInternal<MutableFaninView, MutableFanoutView,
|
||||
MutableGraphView, false> {
|
||||
public:
|
||||
using NodeViewInternal::NodeViewInternal;
|
||||
explicit MutableNodeView(MutableGraphView* graph_view, int node_index)
|
||||
: NodeViewInternal(graph_view, node_index) {}
|
||||
|
||||
MutableNodeView() : NodeViewInternal() {}
|
||||
|
||||
~MutableNodeView() override = default;
|
||||
|
||||
MutableNodeView(MutableNodeView&&) = default;
|
||||
MutableNodeView& operator=(MutableNodeView&&) = default;
|
||||
|
||||
NodeDef* node() const override;
|
||||
|
||||
// Checks if a fanin exists for the node.
|
||||
@ -364,6 +379,34 @@ class MutableGraphView
|
||||
// Returns a Mutation (builder) that can be used to modify MutableGraphView.
|
||||
Mutation* GetMutationBuilder();
|
||||
|
||||
// Helper class representing an extra dependency for topological sorting.
|
||||
class TopologicalDependency {
|
||||
public:
|
||||
TopologicalDependency(const MutableNodeView* from_node,
|
||||
const MutableNodeView* to_node) {
|
||||
if (from_node->graph_view_ == to_node->graph_view_) {
|
||||
graph_view_ = from_node->graph_view_;
|
||||
from_ = from_node->node_index_;
|
||||
to_ = to_node->node_index_;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
MutableGraphView* graph_view_ = nullptr;
|
||||
int from_ = internal::kMissingIndex;
|
||||
int to_ = internal::kMissingIndex;
|
||||
|
||||
friend class MutableGraphView;
|
||||
};
|
||||
|
||||
// Sorts graph topologically in-place. If `ignore_cycles` is set, a
|
||||
// topological like sorting will be performed when there are cycles. Otherwise
|
||||
// if a cycle is detected or if the graph cannot be sorted, an error will be
|
||||
// returned.
|
||||
Status SortTopologically(
|
||||
bool ignore_cycles,
|
||||
absl::Span<const TopologicalDependency> extra_dependencies);
|
||||
|
||||
private:
|
||||
bool AddUniqueNodeInternal(NodeDef* node);
|
||||
|
||||
|
@ -131,8 +131,15 @@ class NodeViewInternal {
|
||||
: graph_view_(graph_view),
|
||||
node_index_(node_index),
|
||||
attrs_(AttrSlice(graph_view->graph()->node(node_index))) {}
|
||||
|
||||
NodeViewInternal()
|
||||
: graph_view_(nullptr), node_index_(kMissingIndex), attrs_(AttrSlice()) {}
|
||||
|
||||
virtual ~NodeViewInternal() {}
|
||||
|
||||
NodeViewInternal(NodeViewInternal&&) = default;
|
||||
NodeViewInternal& operator=(NodeViewInternal&&) = default;
|
||||
|
||||
bool operator==(const NodeViewInternal& other) const {
|
||||
return node_index_ == other.node_index_ && graph_view_ == other.graph_view_;
|
||||
}
|
||||
|
@ -790,7 +790,7 @@ TYPED_TEST(TypedNodeViewTest, HasAttr) {
|
||||
EXPECT_FALSE(c_node->HasAttr("attr"));
|
||||
}
|
||||
|
||||
class MutationTest : public GrapplerTest {
|
||||
class CompareGraphTest : public GrapplerTest {
|
||||
public:
|
||||
void CompareGraphViewWithGraph(MutableGraphView* graph_view,
|
||||
const GraphDef& expected_graph) {
|
||||
@ -953,6 +953,8 @@ class MutationTest : public GrapplerTest {
|
||||
}
|
||||
};
|
||||
|
||||
class MutationTest : public CompareGraphTest {};
|
||||
|
||||
constexpr char kDeviceCPU0[] = "/device:CPU:0";
|
||||
constexpr char kDeviceGPU0[] = "/device:GPU:0";
|
||||
|
||||
@ -1995,6 +1997,270 @@ TEST_F(MutationTest, EmptyMutationUpdateIndexPersisting) {
|
||||
CompareGraphViewWithGraph(&graph_view, test_graph());
|
||||
}
|
||||
|
||||
class TopologicalSortTest : public CompareGraphTest {
|
||||
protected:
|
||||
void CompareGraphOrder(const MutableGraphView& graph_view,
|
||||
absl::Span<const string> node_names) {
|
||||
const int num_nodes = graph_view.NumNodes();
|
||||
ASSERT_EQ(num_nodes, node_names.size());
|
||||
for (int i = 0; i < num_nodes; ++i) {
|
||||
EXPECT_EQ(graph_view.GetNode(i)->GetName(), node_names[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void CompareGraphNodePrecedences(
|
||||
const MutableGraphView& graph_view,
|
||||
absl::Span<const std::pair<string, string>> node_precedences) {
|
||||
for (const auto& node_precedence : node_precedences) {
|
||||
auto* parent_node = graph_view.GetNode(node_precedence.first);
|
||||
ASSERT_NE(parent_node, nullptr);
|
||||
auto* child_node = graph_view.GetNode(node_precedence.second);
|
||||
ASSERT_NE(child_node, nullptr);
|
||||
EXPECT_TRUE(parent_node->node_index() < child_node->node_index());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TopologicalSortTest, ActiveMutationSort) {
|
||||
auto test_graph = []() {
|
||||
return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0),
|
||||
NDef("b", kIdentity, {"a"}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
|
||||
/*funcs=*/{});
|
||||
};
|
||||
|
||||
GraphDef graph = test_graph();
|
||||
Status status;
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
|
||||
Mutation* mutation = graph_view.GetMutationBuilder();
|
||||
mutation->AddNode({}, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
|
||||
for (bool ignore_cycles : {false, true}) {
|
||||
status = graph_view.SortTopologically(ignore_cycles, {});
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_EQ(
|
||||
status.error_message(),
|
||||
"MutableGraphView::SortTopologically error: active mutation exists.");
|
||||
CompareGraphViewWithGraph(&graph_view, test_graph());
|
||||
CompareGraphOrder(graph_view, {"a", "b"});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TopologicalSortTest, BadExtraDependenciesSort) {
|
||||
auto test_graph = []() {
|
||||
return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0),
|
||||
NDef("b", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
|
||||
/*funcs=*/{});
|
||||
};
|
||||
|
||||
GraphDef graph_1 = test_graph();
|
||||
Status status;
|
||||
MutableGraphView graph_view_1(&graph_1, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
MutableNodeView* a_node_1 = graph_view_1.GetNode("a");
|
||||
|
||||
GraphDef graph_2 = test_graph();
|
||||
MutableGraphView graph_view_2(&graph_2, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
MutableNodeView* b_node_2 = graph_view_2.GetNode("b");
|
||||
|
||||
for (bool ignore_cycles : {false, true}) {
|
||||
status =
|
||||
graph_view_2.SortTopologically(ignore_cycles, {{a_node_1, b_node_2}});
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_EQ(status.error_message(),
|
||||
"MutableGraphView::SortTopologically error: invalid extra "
|
||||
"dependencies.");
|
||||
CompareGraphViewWithGraph(&graph_view_2, test_graph());
|
||||
CompareGraphOrder(graph_view_2, {"a", "b"});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TopologicalSortTest, NoCyclesAllowed) {
|
||||
auto test_graph = []() {
|
||||
return GDef(
|
||||
{NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0),
|
||||
NDef("b", kIdentity, {"a", "c"}, {{"T", DT_FLOAT}}, kDeviceGPU1),
|
||||
NDef("c", kIdentity, {"b"}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
|
||||
/*funcs=*/{});
|
||||
};
|
||||
|
||||
GraphDef graph = test_graph();
|
||||
Status status;
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
|
||||
status = graph_view.SortTopologically(/*ignore_cycles=*/false, {});
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_EQ(status.error_message(),
|
||||
"MutableGraphView::SortTopologically error: detected edge(s) "
|
||||
"creating cycle(s) {'c' -> 'b'}.");
|
||||
CompareGraphViewWithGraph(&graph_view, test_graph());
|
||||
CompareGraphOrder(graph_view, {"a", "b", "c"});
|
||||
|
||||
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/true, {}));
|
||||
CompareGraphViewWithGraph(&graph_view, test_graph());
|
||||
CompareGraphNodePrecedences(graph_view, {{"a", "b"}, {"a", "c"}});
|
||||
}
|
||||
|
||||
TEST_F(TopologicalSortTest, NoNodesWithZeroFanins) {
|
||||
auto test_graph = []() {
|
||||
return GDef({NDef("a", kIdentity, {"b"}, {{"T", DT_FLOAT}}, kDeviceGPU0),
|
||||
NDef("b", kIdentity, {"a"}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
|
||||
/*funcs=*/{});
|
||||
};
|
||||
|
||||
GraphDef graph = test_graph();
|
||||
Status status;
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
|
||||
status = graph_view.SortTopologically(/*ignore_cycles=*/false, {});
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_EQ(status.error_message(),
|
||||
"MutableGraphView::SortTopologically error: was not able to sort "
|
||||
"all nodes topologically.");
|
||||
CompareGraphViewWithGraph(&graph_view, test_graph());
|
||||
CompareGraphOrder(graph_view, {"a", "b"});
|
||||
|
||||
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/true, {}));
|
||||
CompareGraphViewWithGraph(&graph_view, test_graph());
|
||||
}
|
||||
|
||||
TEST_F(TopologicalSortTest, DidNotReachAllNodes) {
|
||||
auto test_graph = []() {
|
||||
return GDef({NDef("c", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU2),
|
||||
NDef("a", kIdentity, {"b"}, {{"T", DT_FLOAT}}, kDeviceGPU0),
|
||||
NDef("b", kIdentity, {"a"}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
|
||||
/*funcs=*/{});
|
||||
};
|
||||
|
||||
GraphDef graph = test_graph();
|
||||
Status status;
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
|
||||
status = graph_view.SortTopologically(/*ignore_cycles=*/false, {});
|
||||
EXPECT_FALSE(status.ok());
|
||||
EXPECT_EQ(status.error_message(),
|
||||
"MutableGraphView::SortTopologically error: was not able to sort "
|
||||
"all nodes topologically.");
|
||||
CompareGraphViewWithGraph(&graph_view, test_graph());
|
||||
CompareGraphOrder(graph_view, {"c", "a", "b"});
|
||||
|
||||
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/true, {}));
|
||||
CompareGraphViewWithGraph(&graph_view, test_graph());
|
||||
CompareGraphOrder(graph_view, {"a", "b", "c"});
|
||||
}
|
||||
|
||||
TEST_F(TopologicalSortTest, NoLoopGraph) {
|
||||
auto test_graph = []() {
|
||||
return GDef({NDef("c", kIdentity, {"f"}), NDef("a", kIdentity, {"f", "e"}),
|
||||
NDef("b", kIdentity, {"e", "d"}), NDef("d", kIdentity, {"c"}),
|
||||
NDef("f", kIdentity, {}), NDef("e", kIdentity, {})},
|
||||
/*funcs=*/{});
|
||||
};
|
||||
|
||||
GraphDef graph = test_graph();
|
||||
Status status;
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
|
||||
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
|
||||
CompareGraphViewWithGraph(&graph_view, test_graph());
|
||||
CompareGraphNodePrecedences(
|
||||
graph_view,
|
||||
{{"f", "a"}, {"f", "c"}, {"e", "a"}, {"e", "b"}, {"c", "d"}, {"d", "b"}});
|
||||
}
|
||||
|
||||
TEST_F(TopologicalSortTest, ValidLoopGraph) {
|
||||
// NextIteration -> Merge loop.
|
||||
auto test_graph = []() {
|
||||
return GDef({NDef("b", "Merge", {"a", "e"}), NDef("c", "Switch", {"b"}),
|
||||
NDef("d", kIdentity, {"c"}), NDef("e", "NextIteration", {"d"}),
|
||||
NDef("a", "Const", {})},
|
||||
/*funcs=*/{});
|
||||
};
|
||||
|
||||
GraphDef graph = test_graph();
|
||||
Status status;
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
|
||||
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
|
||||
CompareGraphViewWithGraph(&graph_view, test_graph());
|
||||
CompareGraphOrder(graph_view, {"a", "b", "c", "d", "e"});
|
||||
}
|
||||
|
||||
TEST_F(TopologicalSortTest, DuplicateFanins) {
|
||||
auto test_graph = []() {
|
||||
return GDef(
|
||||
{NDef("b", kIdentity, {"a", "a", "^a"}), NDef("a", "Const", {})},
|
||||
/*funcs=*/{});
|
||||
};
|
||||
|
||||
GraphDef graph = test_graph();
|
||||
Status status;
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
|
||||
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
|
||||
CompareGraphViewWithGraph(&graph_view, test_graph());
|
||||
CompareGraphOrder(graph_view, {"a", "b"});
|
||||
}
|
||||
|
||||
TEST_F(TopologicalSortTest, DiamondDependencyNotACycle) {
|
||||
auto test_graph = []() {
|
||||
return GDef({NDef("e", kIdentity, {"b", "c", "d"}),
|
||||
NDef("b", kIdentity, {"a"}), NDef("a", "Const", {}),
|
||||
NDef("d", kIdentity, {"a"}), NDef("c", kIdentity, {"a"})},
|
||||
/*funcs=*/{});
|
||||
};
|
||||
|
||||
GraphDef graph = test_graph();
|
||||
Status status;
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
|
||||
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
|
||||
CompareGraphViewWithGraph(&graph_view, test_graph());
|
||||
CompareGraphNodePrecedences(
|
||||
graph_view,
|
||||
{{"a", "b"}, {"a", "c"}, {"a", "d"}, {"b", "e"}, {"c", "e"}, {"d", "e"}});
|
||||
}
|
||||
|
||||
TEST_F(TopologicalSortTest, ExtraDependencies) {
|
||||
auto test_graph = []() {
|
||||
return GDef({NDef("c", kIdentity, {"f"}), NDef("a", kIdentity, {"f", "e"}),
|
||||
NDef("b", kIdentity, {"e", "d"}), NDef("d", kIdentity, {"c"}),
|
||||
NDef("f", kIdentity, {}), NDef("e", kIdentity, {})},
|
||||
/*funcs=*/{});
|
||||
};
|
||||
|
||||
GraphDef graph = test_graph();
|
||||
Status status;
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
|
||||
auto* e_node = graph_view.GetNode("e");
|
||||
ASSERT_NE(e_node, nullptr);
|
||||
auto* f_node = graph_view.GetNode("f");
|
||||
ASSERT_NE(f_node, nullptr);
|
||||
|
||||
TF_EXPECT_OK(
|
||||
graph_view.SortTopologically(/*ignore_cycles=*/true, {{e_node, f_node}}));
|
||||
CompareGraphViewWithGraph(&graph_view, test_graph());
|
||||
CompareGraphNodePrecedences(graph_view, {{"f", "a"},
|
||||
{"f", "c"},
|
||||
{"e", "a"},
|
||||
{"e", "b"},
|
||||
{"c", "d"},
|
||||
{"d", "b"},
|
||||
{"e", "f"}});
|
||||
}
|
||||
|
||||
#define RUN_NUM_NODE_NUM_EDGE_BENCHMARK(name) \
|
||||
BENCHMARK(name) \
|
||||
->ArgPair(10, 2) \
|
||||
@ -2541,6 +2807,23 @@ RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControlledFanoutLast);
|
||||
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControlledFanoutFirst);
|
||||
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControlledFanoutLast);
|
||||
|
||||
static void BM_SortTopologically(int iters, int size) {
|
||||
testing::StopTiming();
|
||||
|
||||
GraphDef graph = test::CreateRandomGraph(size);
|
||||
Status status;
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; i++) {
|
||||
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
||||
RUN_NUM_NODE_BENCHMARK(BM_SortTopologically);
|
||||
|
||||
} // namespace
|
||||
} // namespace utils
|
||||
} // namespace grappler
|
||||
|
Loading…
x
Reference in New Issue
Block a user