Add a workaround to the cycle detection algorithm for certain unusual graphs

Our cycle detection graph has special logic to deal with frames and
NextIteration nodes and this logic does not do the right thing in certain cases.
The motivation for the current scheme is mentioned in CL 137756517 but I think
it goes wrong for graphs like:

digraph {
  Enter_0 [label="Enter('frame')"]
  Enter_1 [label="Enter('frame')"]
  Exit_0 [label="Exit"]
  Exit_1 [label="Exit"]

  SRC -> Enter_0
  Enter_0 -> Exit_0
  Exit_0 -> A
  A -> Enter_1
  Enter_1 -> Exit_1
}

for which we try to create the cycle detection graph:

digraph {
  SRC -> "FakeNode(frame)"
  "FakeNode(frame)" -> A
  A -> "FakeNode(frame)"
  "FakeNode(frame)" -> Exit
}

which, of course, is cyclic.

I'm not sure how frequent this is in practice, so for now I'll add a workaround
that avoids auto-clustering in this situation.  But we may want to eventually
implement a more principled fix, possibly based on some form of SCC detection.

PiperOrigin-RevId: 237154773
This commit is contained in:
Sanjoy Das 2019-03-06 17:43:27 -08:00 committed by TensorFlower Gardener
parent 10b1a5c578
commit 58e052bd77
7 changed files with 60 additions and 11 deletions

View File

@ -560,11 +560,13 @@ cc_library(
":resource_operation_safety_analysis",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_bounds_check",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
@ -692,6 +694,7 @@ tf_cc_test(
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
@ -715,6 +718,7 @@ cc_library(
":union_find",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",

View File

@ -768,7 +768,8 @@ Status DeadnessAnalysisImpl::GetInputPreds(
auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
if (it == predicate_map_.end()) {
GraphCycles graph_cycles;
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph_, &graph_cycles));
TF_RETURN_IF_ERROR(
CreateCycleDetectionGraph(&graph_, &graph_cycles).status());
// If we didn't return with an error above then the graph is probably
// fine and we have a bug in deadness analysis.

View File

@ -1098,7 +1098,11 @@ Status MarkForCompilationPass::RunImpl(
}
GraphCycles cycles;
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles));
TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
CreateCycleDetectionGraph(graph, &cycles));
if (!cycle_detection_graph_ok) {
return Status::OK();
}
TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles));

View File

@ -111,7 +111,8 @@ bool HasForwardedRefInput(const Node& node) {
return false;
}
Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
xla::StatusOr<bool> CreateCycleDetectionGraph(const Graph* graph,
GraphCycles* cycles) {
for (int i = 0; i < graph->num_node_ids(); ++i) {
// We rely on the node IDs in the cycle detection graph being consecutive
// integers starting from 0.
@ -174,9 +175,11 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
}
if (!cycles->InsertEdge(src, dst)) {
return errors::Internal(
"Cycle detected when adding ", src_type, "->", dst_type,
" edge: ", DescribeCycle(cycles, *graph, src, dst));
// TODO(b/127521408): We can probably handle this situation with a more
// sophisticated SCC based algorithm, but for now we bail out.
VLOG(1) << "Cycle detected when adding " << src_type << "->" << dst_type
<< " edge: " << DescribeCycle(cycles, *graph, src, dst);
return false;
}
// Drop the original edge.
continue;
@ -194,7 +197,8 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
}
}
return Status::OK();
return true;
}
absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node) {

View File

@ -20,8 +20,10 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
@ -53,7 +55,11 @@ bool HasForwardedRefInput(const Node& node);
// Creates a graph representation to enable cycle detection when clustering.
// This representation handles loops in graph by disconnecting each loop from
// the enclosing graph.
Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles);
//
// Returns true for success and false for valid graphs that we can't handle yet
// (b/127521408).
xla::StatusOr<bool> CreateCycleDetectionGraph(const Graph* graph,
GraphCycles* cycles);
// Returns the XLA cluster in which `node` is placed if it is in an XLA cluster,
// otherwise returns nullopt.

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/algorithm.h"
@ -44,7 +45,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) {
FixupSourceAndSinkEdges(root.graph());
GraphCycles cycles;
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles));
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
}
@ -63,10 +64,33 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) {
FixupSourceAndSinkEdges(root.graph());
GraphCycles cycles;
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles));
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
}
TEST(CreateCycleDetectionGraph, ReachingEnterExit) {
  // TODO(b/127521408): We can lift this limitation with some work.
  //
  // Builds the chain a -> enter_0 -> exit_0 -> add -> enter_1 -> exit_1 in
  // which both Enter nodes name the same frame ("frame_0"), so the second
  // Enter/Exit region is reachable from the Exit of the first one.
  Scope scope = Scope::NewRootScope().ExitOnError();

  // First Enter/Exit region, fed by a constant.
  Output source = ops::Const(scope.WithOpName("a"), Input::Initializer(0.0));
  Output first_enter =
      ops::internal::Enter(scope.WithOpName("enter_0"), source, "frame_0");
  Output first_exit =
      ops::internal::Exit(scope.WithOpName("exit_0"), first_enter);

  // Ordinary op sitting between the two regions.
  Output between =
      ops::Add(scope.WithOpName("add"), first_exit, first_exit);

  // Second Enter/Exit region, reusing the frame name of the first.
  Output second_enter =
      ops::internal::Enter(scope.WithOpName("enter_1"), between, "frame_0");
  Output second_exit =
      ops::internal::Exit(scope.WithOpName("exit_1"), second_enter);

  FixupSourceAndSinkEdges(scope.graph());

  // The call itself must not fail with an error status; it is expected to
  // report `false` because this graph shape is not handled yet.
  GraphCycles cycle_graph;
  TF_ASSERT_OK_AND_ASSIGN(
      bool graph_ok, CreateCycleDetectionGraph(scope.graph(), &cycle_graph));
  EXPECT_FALSE(graph_ok);
}
void CheckPickDeviceResult(absl::string_view expected_result,
bool allow_mixing_unknown_and_cpu,
absl::Span<const absl::string_view> inputs) {

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
@ -208,7 +209,12 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
}
GraphCycles cycles;
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles));
TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
CreateCycleDetectionGraph(&graph, &cycles));
if (!cycle_detection_graph_ok) {
return Status::OK();
}
TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
&graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles));