From 58e052bd777e29dcd1bd75bb83610c5fa56474b2 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Wed, 6 Mar 2019 17:43:27 -0800 Subject: [PATCH] Add a workaround to the cycle detection algorithm for certain unusual graphs Our cycle detection graph has special logic to deal with frames and NextIteration nodes and this logic does not do the right thing in certain cases. The motivation for the current scheme is mentioned in CL 137756517 but I think it goes wrong for graphs like: digraph { Enter_0 [label="Enter('frame')"] Enter_1 [label="Enter('frame')"] Exit_0 [label="Exit"] Exit_1 [label="Exit"] SRC -> Enter_0 Enter_0 -> Exit_0 Exit_0 -> A A -> Enter_1 Enter_1 -> Exit_1 } for which we try to create the cycle detection graph: digraph { SRC -> "FakeNode(frame)" "FakeNode(frame)" -> A A -> "FakeNode(frame)" "FakeNode(frame)" -> Exit } which, of course, is cyclic. I'm not sure how frequent this is in practice, so for now I'll add a workaround that avoids auto-clustering in this situation. But we may want to eventually implement a more principled fix, possibly based on some form of SCC detection. 
PiperOrigin-RevId: 237154773 --- tensorflow/compiler/jit/BUILD | 4 +++ tensorflow/compiler/jit/deadness_analysis.cc | 3 +- .../compiler/jit/mark_for_compilation_pass.cc | 6 +++- tensorflow/compiler/jit/xla_cluster_util.cc | 14 ++++++---- tensorflow/compiler/jit/xla_cluster_util.h | 8 +++++- .../compiler/jit/xla_cluster_util_test.cc | 28 +++++++++++++++++-- .../compiler/jit/xla_fusion_optimizer.cc | 8 +++++- 7 files changed, 60 insertions(+), 11 deletions(-) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index d0db9cc3f50..3f46f713673 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -560,11 +560,13 @@ cc_library( ":resource_operation_safety_analysis", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_bounds_check", "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", + "//tensorflow/stream_executor/lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -692,6 +694,7 @@ tf_cc_test( "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -715,6 +718,7 @@ cc_library( ":union_find", ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu_base", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 4397eea9af2..4856301cef4 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ 
b/tensorflow/compiler/jit/deadness_analysis.cc @@ -768,7 +768,8 @@ Status DeadnessAnalysisImpl::GetInputPreds( auto it = predicate_map_.find(InputEdgeToTensorId(in_edge)); if (it == predicate_map_.end()) { GraphCycles graph_cycles; - TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph_, &graph_cycles)); + TF_RETURN_IF_ERROR( + CreateCycleDetectionGraph(&graph_, &graph_cycles).status()); // If we didn't return with an error above then the graph is probably // fine and we have a bug in deadness analysis. diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index bbf5d77db14..c193386afcc 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1098,7 +1098,11 @@ Status MarkForCompilationPass::RunImpl( } GraphCycles cycles; - TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles)); + TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok, + CreateCycleDetectionGraph(graph, &cycles)); + if (!cycle_detection_graph_ok) { + return Status::OK(); + } TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles)); diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index eaa7015768c..cb8ac06207e 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -111,7 +111,8 @@ bool HasForwardedRefInput(const Node& node) { return false; } -Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { +xla::StatusOr CreateCycleDetectionGraph(const Graph* graph, + GraphCycles* cycles) { for (int i = 0; i < graph->num_node_ids(); ++i) { // We rely on the node IDs in the cycle detection graph being consecutive // integers starting from 0. 
@@ -174,9 +175,11 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { } if (!cycles->InsertEdge(src, dst)) { - return errors::Internal( - "Cycle detected when adding ", src_type, "->", dst_type, - " edge: ", DescribeCycle(cycles, *graph, src, dst)); + // TODO(b/127521408): We can probably handle this situation with a more + // sophisticated SCC based algorithm, but for now we bail out. + VLOG(1) << "Cycle detected when adding " << src_type << "->" << dst_type + << " edge: " << DescribeCycle(cycles, *graph, src, dst); + return false; } // Drop the original edge. continue; @@ -194,7 +197,8 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); } } - return Status::OK(); + + return true; } absl::optional GetXlaClusterForNode(const Node& node) { diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index ddca0aaeabb..af01e1d3023 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -20,8 +20,10 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { @@ -53,7 +55,11 @@ bool HasForwardedRefInput(const Node& node); // Creates a graph representation to enable cycle detection when clustering. // This representation handles loops in graph by disconnecting each loop from // the enclosing graph. -Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); +// +// Returns true for success and false for valid graphs that we can't handle yet +// (b/127521408). 
+xla::StatusOr CreateCycleDetectionGraph(const Graph* graph, + GraphCycles* cycles); // Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, // otherwise returns nullopt. diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc index 68fb4da134e..cbaac719f2e 100644 --- a/tensorflow/compiler/jit/xla_cluster_util_test.cc +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/graph/algorithm.h" @@ -44,7 +45,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) { FixupSourceAndSinkEdges(root.graph()); GraphCycles cycles; - TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles)); + TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status()); EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); } @@ -63,10 +64,33 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) { FixupSourceAndSinkEdges(root.graph()); GraphCycles cycles; - TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles)); + TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status()); EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); } +TEST(CreateCycleDetectionGraph, ReachingEnterExit) { + // TODO(b/127521408): We can lift this limitation with some work. 
+ Scope root = Scope::NewRootScope().ExitOnError(); + + Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0)); + Output enter_0 = + ops::internal::Enter(root.WithOpName("enter_0"), a, "frame_0"); + Output exit_0 = ops::internal::Exit(root.WithOpName("exit_0"), enter_0); + + Output add = ops::Add(root.WithOpName("add"), exit_0, exit_0); + + Output enter_1 = + ops::internal::Enter(root.WithOpName("enter_1"), add, "frame_0"); + Output exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1); + + FixupSourceAndSinkEdges(root.graph()); + + GraphCycles cycles; + TF_ASSERT_OK_AND_ASSIGN(bool ok, + CreateCycleDetectionGraph(root.graph(), &cycles)); + EXPECT_FALSE(ok); +} + void CheckPickDeviceResult(absl::string_view expected_result, bool allow_mixing_unknown_and_cpu, absl::Span inputs) { diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc index bc0db558d8d..a2a06f57698 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -208,7 +209,12 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, } GraphCycles cycles; - TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles)); + TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok, + CreateCycleDetectionGraph(&graph, &cycles)); + if (!cycle_detection_graph_ok) { + return Status::OK(); + } + TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( &graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles));