During contraction of an edge from "a" to "b" there is a choice which node should be used to represent the union of the nodes. Using the node with the largest degree minimizes the number of operation that should be performed.

PiperOrigin-RevId: 301277365
Change-Id: I21b6dba8e72627d66628f4096006208c2f0b8c2b
This commit is contained in:
A. Unique TensorFlower 2020-03-16 18:06:46 -07:00 committed by TensorFlower Gardener
parent 6923207925
commit 41f5a034c8
6 changed files with 64 additions and 20 deletions

View File

@ -18,6 +18,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)

View File

@ -368,14 +368,20 @@ bool GraphCycles::CanContractEdge(int32 a, int32 b) {
return !reachable;
}
bool GraphCycles::ContractEdge(int32 a, int32 b) {
absl::optional<int32> GraphCycles::ContractEdge(int32 a, int32 b) {
CHECK(HasEdge(a, b));
RemoveEdge(a, b);
if (IsReachableNonConst(a, b)) {
// Restore the graph to its original state.
InsertEdge(a, b);
return false;
return absl::nullopt;
}
if (rep_->nodes_[b]->in.Size() + rep_->nodes_[b]->out.Size() >
rep_->nodes_[a]->in.Size() + rep_->nodes_[a]->out.Size()) {
// Swap "a" and "b" to minimize copying.
std::swap(a, b);
}
Node* nb = rep_->nodes_[b];
@ -399,7 +405,8 @@ bool GraphCycles::ContractEdge(int32 a, int32 b) {
InsertEdge(y, a);
}
return true;
// Note, if the swap happened it might be what originally was called "b".
return a;
}
absl::Span<const int32> GraphCycles::Successors(int32 node) const {

View File

@ -40,6 +40,7 @@ limitations under the License.
// FindPath() is linear in the size of the graph.
// The current implementation uses O(|V|+|E|) space.
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@ -80,11 +81,11 @@ class GraphCycles {
// Return whether there is an edge directly from source_node to dest_node.
bool HasEdge(int32 source_node, int32 dest_node) const;
// Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. 'b' is
// removed from the graph, and edges to/from 'b' are replaced with edges
// to/from 'a'. If contracting the edge would create a cycle, does nothing
// and returns false.
bool ContractEdge(int32 a, int32 b);
// Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. One of
// the nodes is removed from the graph, and edges to/from it are added to
// the remaining one, which is returned. If contracting the edge would create
// a cycle, does nothing and return no value.
absl::optional<int32> ContractEdge(int32 a, int32 b);
// Return true if can contract edge, otherwise return false.
bool CanContractEdge(int32 a, int32 b);

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include <optional>
#include <random>
#include <unordered_set>
#include <vector>
@ -479,19 +480,21 @@ TEST_F(GraphCyclesTest, ContractEdge) {
ASSERT_TRUE(AddEdge(2, 4));
ASSERT_TRUE(AddEdge(3, 4));
EXPECT_FALSE(g_.ContractEdge(1, 3));
EXPECT_FALSE(g_.ContractEdge(1, 3).has_value());
CHECK(g_.CheckInvariants());
EXPECT_TRUE(g_.HasEdge(1, 3));
EXPECT_TRUE(g_.ContractEdge(1, 2));
// Node (2) has more edges.
EXPECT_EQ(g_.ContractEdge(1, 2).value(), 2);
CHECK(g_.CheckInvariants());
EXPECT_TRUE(g_.HasEdge(1, 3));
EXPECT_TRUE(g_.HasEdge(1, 4));
EXPECT_TRUE(g_.HasEdge(2, 3));
EXPECT_TRUE(g_.HasEdge(2, 4));
EXPECT_TRUE(g_.HasEdge(3, 4));
EXPECT_TRUE(g_.ContractEdge(1, 3));
// Node (2) has more edges.
EXPECT_EQ(g_.ContractEdge(2, 3).value(), 2);
CHECK(g_.CheckInvariants());
EXPECT_TRUE(g_.HasEdge(1, 4));
EXPECT_TRUE(g_.HasEdge(2, 4));
}
TEST_F(GraphCyclesTest, CanContractEdge) {
@ -527,3 +530,26 @@ static void BM_StressTest(int iters, int num_nodes) {
}
}
BENCHMARK(BM_StressTest)->Range(2048, 1048576);
static void BM_ContractEdge(int iters, int num_nodes) {
while (iters-- > 0) {
tensorflow::testing::StopTiming();
tensorflow::GraphCycles g;
std::vector<int32> nodes;
nodes.reserve(num_nodes);
for (int i = 0; i < num_nodes; i++) {
nodes.push_back(g.NewNode());
}
// All edges point toward the last one.
for (int i = 0; i < num_nodes - 1; ++i) {
g.InsertEdge(nodes[i], nodes[num_nodes - 1]);
}
tensorflow::testing::StartTiming();
int node = num_nodes - 1;
for (int i = 0; i < num_nodes - 1; ++i) {
node = g.ContractEdge(nodes[i], node).value();
}
}
}
BENCHMARK(BM_ContractEdge)->Arg(1000)->Arg(10000);

View File

@ -161,6 +161,11 @@ class MarkForCompilationPassImpl {
// The ID of the cluster as represented in `cycles_graph_`.
int cycles_graph_node_id() const { return cycles_graph_node_id_; }
// Sets the ID of the cluster as represented in `cycles_graph_`.
void set_cycles_graph_node_id(int cycles_graph_node_id) {
cycles_graph_node_id_ = cycles_graph_node_id;
}
// The size of the cluster excluding constant and identity nodes.
int effective_cluster_size() const { return effective_cluster_size_; }
@ -381,14 +386,16 @@ class MarkForCompilationPassImpl {
// R, B} cluster.
string DescribePotentialCycle(int from, int to);
// Merge the clusters `cluster_from` and `cluster_to`. After this step the
// larger combined cluster is represented by `cluster_from`'s ID in
// `cycles_graph_`.
// Merge the clusters `cluster_from` and `cluster_to`. After this step the
// larger combined cluster is represented by `cluster_from`, but can have
// `cycles_graph_`'s ID of either `cluster_from` or `cluster_to` depending on
// which way will require less operations.
bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
int from = cluster_from->cycles_graph_node_id();
int to = cluster_to->cycles_graph_node_id();
if (!cycles_graph_.ContractEdge(from, to)) {
auto optional_merged_node = cycles_graph_.ContractEdge(from, to);
if (!optional_merged_node.has_value()) {
VLOG(3) << "Could not contract " << cluster_from->DebugString(*graph_)
<< " -> " << cluster_to->DebugString(*graph_)
<< " because contracting the edge would create a cycle via "
@ -398,6 +405,8 @@ class MarkForCompilationPassImpl {
// Merge the clusters.
cluster_from->Merge(cluster_to);
// Update `cycle_graph_`'s ID.
cluster_from->set_cycles_graph_node_id(optional_merged_node.value());
// Merge the UnionFind<Cluster*>.
cluster_for_node_[from].Merge(&cluster_for_node_[to]);

View File

@ -50,7 +50,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) {
GraphCycles cycles;
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
EXPECT_FALSE(cycles.CanContractEdge(a.node()->id(), b.node()->id()));
}
TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) {
@ -69,7 +69,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) {
GraphCycles cycles;
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
EXPECT_FALSE(cycles.CanContractEdge(a.node()->id(), b.node()->id()));
}
TEST(CreateCycleDetectionGraph, ReachingEnterExit) {