During contraction of an edge from "a" to "b" there is a choice which node should be used to represent the union of the nodes. Using the node with the largest degree minimizes the number of operation that should be performed.
PiperOrigin-RevId: 301277365 Change-Id: I21b6dba8e72627d66628f4096006208c2f0b8c2b
This commit is contained in:
parent
6923207925
commit
41f5a034c8
@ -18,6 +18,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
@ -368,14 +368,20 @@ bool GraphCycles::CanContractEdge(int32 a, int32 b) {
|
||||
return !reachable;
|
||||
}
|
||||
|
||||
bool GraphCycles::ContractEdge(int32 a, int32 b) {
|
||||
absl::optional<int32> GraphCycles::ContractEdge(int32 a, int32 b) {
|
||||
CHECK(HasEdge(a, b));
|
||||
RemoveEdge(a, b);
|
||||
|
||||
if (IsReachableNonConst(a, b)) {
|
||||
// Restore the graph to its original state.
|
||||
InsertEdge(a, b);
|
||||
return false;
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
if (rep_->nodes_[b]->in.Size() + rep_->nodes_[b]->out.Size() >
|
||||
rep_->nodes_[a]->in.Size() + rep_->nodes_[a]->out.Size()) {
|
||||
// Swap "a" and "b" to minimize copying.
|
||||
std::swap(a, b);
|
||||
}
|
||||
|
||||
Node* nb = rep_->nodes_[b];
|
||||
@ -399,7 +405,8 @@ bool GraphCycles::ContractEdge(int32 a, int32 b) {
|
||||
InsertEdge(y, a);
|
||||
}
|
||||
|
||||
return true;
|
||||
// Note, if the swap happened it might be what originally was called "b".
|
||||
return a;
|
||||
}
|
||||
|
||||
absl::Span<const int32> GraphCycles::Successors(int32 node) const {
|
||||
|
@ -40,6 +40,7 @@ limitations under the License.
|
||||
// FindPath() is linear in the size of the graph.
|
||||
// The current implementation uses O(|V|+|E|) space.
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -80,11 +81,11 @@ class GraphCycles {
|
||||
// Return whether there is an edge directly from source_node to dest_node.
|
||||
bool HasEdge(int32 source_node, int32 dest_node) const;
|
||||
|
||||
// Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. 'b' is
|
||||
// removed from the graph, and edges to/from 'b' are replaced with edges
|
||||
// to/from 'a'. If contracting the edge would create a cycle, does nothing
|
||||
// and returns false.
|
||||
bool ContractEdge(int32 a, int32 b);
|
||||
// Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. One of
|
||||
// the nodes is removed from the graph, and edges to/from it are added to
|
||||
// the remaining one, which is returned. If contracting the edge would create
|
||||
// a cycle, does nothing and return no value.
|
||||
absl::optional<int32> ContractEdge(int32 a, int32 b);
|
||||
|
||||
// Return true if can contract edge, otherwise return false.
|
||||
bool CanContractEdge(int32 a, int32 b);
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
|
||||
#include <optional>
|
||||
#include <random>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
@ -479,19 +480,21 @@ TEST_F(GraphCyclesTest, ContractEdge) {
|
||||
ASSERT_TRUE(AddEdge(2, 4));
|
||||
ASSERT_TRUE(AddEdge(3, 4));
|
||||
|
||||
EXPECT_FALSE(g_.ContractEdge(1, 3));
|
||||
EXPECT_FALSE(g_.ContractEdge(1, 3).has_value());
|
||||
CHECK(g_.CheckInvariants());
|
||||
EXPECT_TRUE(g_.HasEdge(1, 3));
|
||||
|
||||
EXPECT_TRUE(g_.ContractEdge(1, 2));
|
||||
// Node (2) has more edges.
|
||||
EXPECT_EQ(g_.ContractEdge(1, 2).value(), 2);
|
||||
CHECK(g_.CheckInvariants());
|
||||
EXPECT_TRUE(g_.HasEdge(1, 3));
|
||||
EXPECT_TRUE(g_.HasEdge(1, 4));
|
||||
EXPECT_TRUE(g_.HasEdge(2, 3));
|
||||
EXPECT_TRUE(g_.HasEdge(2, 4));
|
||||
EXPECT_TRUE(g_.HasEdge(3, 4));
|
||||
|
||||
EXPECT_TRUE(g_.ContractEdge(1, 3));
|
||||
// Node (2) has more edges.
|
||||
EXPECT_EQ(g_.ContractEdge(2, 3).value(), 2);
|
||||
CHECK(g_.CheckInvariants());
|
||||
EXPECT_TRUE(g_.HasEdge(1, 4));
|
||||
EXPECT_TRUE(g_.HasEdge(2, 4));
|
||||
}
|
||||
|
||||
TEST_F(GraphCyclesTest, CanContractEdge) {
|
||||
@ -527,3 +530,26 @@ static void BM_StressTest(int iters, int num_nodes) {
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_StressTest)->Range(2048, 1048576);
|
||||
|
||||
static void BM_ContractEdge(int iters, int num_nodes) {
|
||||
while (iters-- > 0) {
|
||||
tensorflow::testing::StopTiming();
|
||||
tensorflow::GraphCycles g;
|
||||
std::vector<int32> nodes;
|
||||
nodes.reserve(num_nodes);
|
||||
for (int i = 0; i < num_nodes; i++) {
|
||||
nodes.push_back(g.NewNode());
|
||||
}
|
||||
// All edges point toward the last one.
|
||||
for (int i = 0; i < num_nodes - 1; ++i) {
|
||||
g.InsertEdge(nodes[i], nodes[num_nodes - 1]);
|
||||
}
|
||||
|
||||
tensorflow::testing::StartTiming();
|
||||
int node = num_nodes - 1;
|
||||
for (int i = 0; i < num_nodes - 1; ++i) {
|
||||
node = g.ContractEdge(nodes[i], node).value();
|
||||
}
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_ContractEdge)->Arg(1000)->Arg(10000);
|
||||
|
@ -161,6 +161,11 @@ class MarkForCompilationPassImpl {
|
||||
// The ID of the cluster as represented in `cycles_graph_`.
|
||||
int cycles_graph_node_id() const { return cycles_graph_node_id_; }
|
||||
|
||||
// Sets the ID of the cluster as represented in `cycles_graph_`.
|
||||
void set_cycles_graph_node_id(int cycles_graph_node_id) {
|
||||
cycles_graph_node_id_ = cycles_graph_node_id;
|
||||
}
|
||||
|
||||
// The size of the cluster excluding constant and identity nodes.
|
||||
int effective_cluster_size() const { return effective_cluster_size_; }
|
||||
|
||||
@ -381,14 +386,16 @@ class MarkForCompilationPassImpl {
|
||||
// R, B} cluster.
|
||||
string DescribePotentialCycle(int from, int to);
|
||||
|
||||
// Merge the clusters `cluster_from` and `cluster_to`. After this step the
|
||||
// larger combined cluster is represented by `cluster_from`'s ID in
|
||||
// `cycles_graph_`.
|
||||
// Merge the clusters `cluster_from` and `cluster_to`. After this step the
|
||||
// larger combined cluster is represented by `cluster_from`, but can have
|
||||
// `cycles_graph_`'s ID of either `cluster_from` or `cluster_to` depending on
|
||||
// which way will require less operations.
|
||||
bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
|
||||
int from = cluster_from->cycles_graph_node_id();
|
||||
int to = cluster_to->cycles_graph_node_id();
|
||||
|
||||
if (!cycles_graph_.ContractEdge(from, to)) {
|
||||
auto optional_merged_node = cycles_graph_.ContractEdge(from, to);
|
||||
if (!optional_merged_node.has_value()) {
|
||||
VLOG(3) << "Could not contract " << cluster_from->DebugString(*graph_)
|
||||
<< " -> " << cluster_to->DebugString(*graph_)
|
||||
<< " because contracting the edge would create a cycle via "
|
||||
@ -398,6 +405,8 @@ class MarkForCompilationPassImpl {
|
||||
|
||||
// Merge the clusters.
|
||||
cluster_from->Merge(cluster_to);
|
||||
// Update `cycle_graph_`'s ID.
|
||||
cluster_from->set_cycles_graph_node_id(optional_merged_node.value());
|
||||
|
||||
// Merge the UnionFind<Cluster*>.
|
||||
cluster_for_node_[from].Merge(&cluster_for_node_[to]);
|
||||
|
@ -50,7 +50,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) {
|
||||
|
||||
GraphCycles cycles;
|
||||
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
|
||||
EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
|
||||
EXPECT_FALSE(cycles.CanContractEdge(a.node()->id(), b.node()->id()));
|
||||
}
|
||||
|
||||
TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) {
|
||||
@ -69,7 +69,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) {
|
||||
|
||||
GraphCycles cycles;
|
||||
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
|
||||
EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
|
||||
EXPECT_FALSE(cycles.CanContractEdge(a.node()->id(), b.node()->id()));
|
||||
}
|
||||
|
||||
TEST(CreateCycleDetectionGraph, ReachingEnterExit) {
|
||||
|
Loading…
Reference in New Issue
Block a user