Add a workaround to the cycle detection algorithm for certain unusual graphs

Our cycle detection graph has special logic to deal with frames and
NextIteration nodes and this logic does not do the right thing in certain cases.
The motivation for the current scheme is mentioned in CL 137756517 but I think
it goes wrong for graphs like:

digraph {
  Enter_0 [label="Enter('frame')"]
  Enter_1 [label="Enter('frame')"]
  Exit_0 [label="Exit"]
  Exit_1 [label="Exit"]

  SRC -> Enter_0
  Enter_0 -> Exit_0
  Exit_0 -> A
  A -> Enter_1
  Enter_1 -> Exit_1
}

for which we try to create the cycle detection graph:

digraph {
  SRC -> "FakeNode(frame)"
  "FakeNode(frame)" -> A
  A -> "FakeNode(frame)"
  "FakeNode(frame)" -> Exit
}

which, of course, is cyclic.

I'm not sure how frequent this is in practice, so for now I'll add a workaround
that avoids auto-clustering in this situation.  But we may want to eventually
implement a more principled fix, possibly based on some form of SCC detection.

PiperOrigin-RevId: 237154773
This commit is contained in:
Sanjoy Das 2019-03-06 17:43:27 -08:00 committed by TensorFlower Gardener
parent 10b1a5c578
commit 58e052bd77
7 changed files with 60 additions and 11 deletions

View File

@ -560,11 +560,13 @@ cc_library(
":resource_operation_safety_analysis", ":resource_operation_safety_analysis",
"//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:framework_bounds_check", "//tensorflow/core:framework_bounds_check",
"//tensorflow/core:graph", "//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
@ -692,6 +694,7 @@ tf_cc_test(
"//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
@ -715,6 +718,7 @@ cc_library(
":union_find", ":union_find",
":xla_cluster_util", ":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu_base", "//tensorflow/core:core_cpu_base",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:grappler_item",

View File

@ -768,7 +768,8 @@ Status DeadnessAnalysisImpl::GetInputPreds(
auto it = predicate_map_.find(InputEdgeToTensorId(in_edge)); auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
if (it == predicate_map_.end()) { if (it == predicate_map_.end()) {
GraphCycles graph_cycles; GraphCycles graph_cycles;
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph_, &graph_cycles)); TF_RETURN_IF_ERROR(
CreateCycleDetectionGraph(&graph_, &graph_cycles).status());
// If we didn't return with an error above then the graph is probably // If we didn't return with an error above then the graph is probably
// fine and we have a bug in deadness analysis. // fine and we have a bug in deadness analysis.

View File

@ -1098,7 +1098,11 @@ Status MarkForCompilationPass::RunImpl(
} }
GraphCycles cycles; GraphCycles cycles;
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles)); TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
CreateCycleDetectionGraph(graph, &cycles));
if (!cycle_detection_graph_ok) {
return Status::OK();
}
TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles)); graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles));

View File

@ -111,7 +111,8 @@ bool HasForwardedRefInput(const Node& node) {
return false; return false;
} }
Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { xla::StatusOr<bool> CreateCycleDetectionGraph(const Graph* graph,
GraphCycles* cycles) {
for (int i = 0; i < graph->num_node_ids(); ++i) { for (int i = 0; i < graph->num_node_ids(); ++i) {
// We rely on the node IDs in the cycle detection graph being consecutive // We rely on the node IDs in the cycle detection graph being consecutive
// integers starting from 0. // integers starting from 0.
@ -174,9 +175,11 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
} }
if (!cycles->InsertEdge(src, dst)) { if (!cycles->InsertEdge(src, dst)) {
return errors::Internal( // TODO(b/127521408): We can probably handle this situation with a more
"Cycle detected when adding ", src_type, "->", dst_type, // sophisticated SCC based algorithm, but for now we bail out.
" edge: ", DescribeCycle(cycles, *graph, src, dst)); VLOG(1) << "Cycle detected when adding " << src_type << "->" << dst_type
<< " edge: " << DescribeCycle(cycles, *graph, src, dst);
return false;
} }
// Drop the original edge. // Drop the original edge.
continue; continue;
@ -194,7 +197,8 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
} }
} }
return Status::OK();
return true;
} }
absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node) { absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node) {

View File

@ -20,8 +20,10 @@ limitations under the License.
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow { namespace tensorflow {
@ -53,7 +55,11 @@ bool HasForwardedRefInput(const Node& node);
// Creates a graph representation to enable cycle detection when clustering. // Creates a graph representation to enable cycle detection when clustering.
// This representation handles loops in graph by disconnecting each loop from // This representation handles loops in graph by disconnecting each loop from
// the enclosing graph. // the enclosing graph.
Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); //
// Returns true for success and false for valid graphs that we can't handle yet
// (b/127521408).
xla::StatusOr<bool> CreateCycleDetectionGraph(const Graph* graph,
GraphCycles* cycles);
// Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, // Returns the XLA cluster in which `node` is placed if it is in an XLA cluster,
// otherwise returns nullopt. // otherwise returns nullopt.

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/algorithm.h"
@ -44,7 +45,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) {
FixupSourceAndSinkEdges(root.graph()); FixupSourceAndSinkEdges(root.graph());
GraphCycles cycles; GraphCycles cycles;
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles)); TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
} }
@ -63,10 +64,33 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) {
FixupSourceAndSinkEdges(root.graph()); FixupSourceAndSinkEdges(root.graph());
GraphCycles cycles; GraphCycles cycles;
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles)); TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
} }
// Verifies that CreateCycleDetectionGraph reports "not ok" (a false value
// inside an OK StatusOr) rather than an error status for a graph in which the
// output of one Enter/Exit region feeds a second Enter/Exit region that uses
// the same frame name.
TEST(CreateCycleDetectionGraph, ReachingEnterExit) {
// TODO(b/127521408): We can lift this limitation with some work.
Scope root = Scope::NewRootScope().ExitOnError();
// First region: a -> enter_0 -> exit_0, using frame name "frame_0".
Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0));
Output enter_0 =
ops::internal::Enter(root.WithOpName("enter_0"), a, "frame_0");
Output exit_0 = ops::internal::Exit(root.WithOpName("exit_0"), enter_0);
// `add` consumes the first region's exit and feeds the second region.
Output add = ops::Add(root.WithOpName("add"), exit_0, exit_0);
// Second region deliberately reuses the frame name "frame_0", so `add` both
// follows and precedes an Enter/Exit pair with that frame name.
Output enter_1 =
ops::internal::Enter(root.WithOpName("enter_1"), add, "frame_0");
Output exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1);
FixupSourceAndSinkEdges(root.graph());
GraphCycles cycles;
// The call itself must succeed (OK status); the returned bool is what is
// under test.
TF_ASSERT_OK_AND_ASSIGN(bool ok,
CreateCycleDetectionGraph(root.graph(), &cycles));
// A false result means the graph is one the cycle detection graph builder
// cannot handle yet.
EXPECT_FALSE(ok);
}
void CheckPickDeviceResult(absl::string_view expected_result, void CheckPickDeviceResult(absl::string_view expected_result,
bool allow_mixing_unknown_and_cpu, bool allow_mixing_unknown_and_cpu,
absl::Span<const absl::string_view> inputs) { absl::Span<const absl::string_view> inputs) {

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_constructor.h"
@ -208,7 +209,12 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
} }
GraphCycles cycles; GraphCycles cycles;
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles)); TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
CreateCycleDetectionGraph(&graph, &cycles));
if (!cycle_detection_graph_ok) {
return Status::OK();
}
TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
&graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles)); &graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles));