diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index d0db9cc3f50..3f46f713673 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -560,11 +560,13 @@ cc_library( ":resource_operation_safety_analysis", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_bounds_check", "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", + "//tensorflow/stream_executor/lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -692,6 +694,7 @@ tf_cc_test( "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -715,6 +718,7 @@ cc_library( ":union_find", ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu_base", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 4397eea9af2..4856301cef4 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -768,7 +768,8 @@ Status DeadnessAnalysisImpl::GetInputPreds( auto it = predicate_map_.find(InputEdgeToTensorId(in_edge)); if (it == predicate_map_.end()) { GraphCycles graph_cycles; - TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph_, &graph_cycles)); + TF_RETURN_IF_ERROR( + CreateCycleDetectionGraph(&graph_, &graph_cycles).status()); // If we didn't return with an error above then the graph is probably // fine and we have a bug in deadness analysis. diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index bbf5d77db14..c193386afcc 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1098,7 +1098,11 @@ Status MarkForCompilationPass::RunImpl( } GraphCycles cycles; - TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles)); + TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok, + CreateCycleDetectionGraph(graph, &cycles)); + if (!cycle_detection_graph_ok) { + return Status::OK(); + } TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles)); diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index eaa7015768c..cb8ac06207e 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -111,7 +111,8 @@ bool HasForwardedRefInput(const Node& node) { return false; } -Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { +xla::StatusOr CreateCycleDetectionGraph(const Graph* graph, + GraphCycles* cycles) { for (int i = 0; i < graph->num_node_ids(); ++i) { // We rely on the node IDs in the cycle detection graph being consecutive // integers starting from 0. @@ -174,9 +175,11 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { } if (!cycles->InsertEdge(src, dst)) { - return errors::Internal( - "Cycle detected when adding ", src_type, "->", dst_type, - " edge: ", DescribeCycle(cycles, *graph, src, dst)); + // TODO(b/127521408): We can probably handle this situation with a more + // sophisticated SCC based algorithm, but for now we bail out. + VLOG(1) << "Cycle detected when adding " << src_type << "->" << dst_type + << " edge: " << DescribeCycle(cycles, *graph, src, dst); + return false; } // Drop the original edge. continue; @@ -194,7 +197,8 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); } } - return Status::OK(); + + return true; } absl::optional GetXlaClusterForNode(const Node& node) { diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index ddca0aaeabb..af01e1d3023 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -20,8 +20,10 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { @@ -53,7 +55,11 @@ bool HasForwardedRefInput(const Node& node); // Creates a graph representation to enable cycle detection when clustering. // This representation handles loops in graph by disconnecting each loop from // the enclosing graph. -Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); +// +// Returns true for success and false for valid graphs that we can't handle yet +// (b/127521408). +xla::StatusOr CreateCycleDetectionGraph(const Graph* graph, + GraphCycles* cycles); // Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, // otherwise returns nullopt. diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc index 68fb4da134e..cbaac719f2e 100644 --- a/tensorflow/compiler/jit/xla_cluster_util_test.cc +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/graph/algorithm.h" @@ -44,7 +45,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) { FixupSourceAndSinkEdges(root.graph()); GraphCycles cycles; - TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles)); + TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status()); EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); } @@ -63,10 +64,33 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) { FixupSourceAndSinkEdges(root.graph()); GraphCycles cycles; - TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles)); + TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status()); EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); } +TEST(CreateCycleDetectionGraph, ReachingEnterExit) { + // TODO(b/127521408): We can lift this limitation with some work. + Scope root = Scope::NewRootScope().ExitOnError(); + + Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0)); + Output enter_0 = + ops::internal::Enter(root.WithOpName("enter_0"), a, "frame_0"); + Output exit_0 = ops::internal::Exit(root.WithOpName("exit_0"), enter_0); + + Output add = ops::Add(root.WithOpName("add"), exit_0, exit_0); + + Output enter_1 = + ops::internal::Enter(root.WithOpName("enter_1"), add, "frame_0"); + Output exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1); + + FixupSourceAndSinkEdges(root.graph()); + + GraphCycles cycles; + TF_ASSERT_OK_AND_ASSIGN(bool ok, + CreateCycleDetectionGraph(root.graph(), &cycles)); + EXPECT_FALSE(ok); +} + void CheckPickDeviceResult(absl::string_view expected_result, bool allow_mixing_unknown_and_cpu, absl::Span inputs) { diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc index bc0db558d8d..a2a06f57698 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -208,7 +209,12 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, } GraphCycles cycles; - TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles)); + TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok, + CreateCycleDetectionGraph(&graph, &cycles)); + if (!cycle_detection_graph_ok) { + return Status::OK(); + } + TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( &graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles));