Add a workaround to the cycle detection algorithm for certain unusual graphs
Our cycle detection graph has special logic to deal with frames and NextIteration nodes and this logic does not do the right thing in certain cases. The motivation for the current scheme is mentioned in CL 137756517 but I think it goes wrong for graphs like: digraph { Enter_0 [label="Enter('frame')"] Enter_1 [label="Enter('frame')"] Exit_0 [label="Exit"] Exit_1 [label="Exit"] SRC -> Enter_0 Enter_0 -> Exit_0 Exit_0 -> A A -> Enter_1 Enter_1 -> Exit_1 } for which we try to create the cycle detection graph: digraph { SRC -> "FakeNode(frame)" "FakeNode(frame)" -> A A -> "FakeNode(frame)" "FakeNode(frame)" -> Exit } which, of course, is cyclic. I'm not sure how frequent this is in practice, so for now I'll add a workaround that avoids auto-clustering in this situation. But we may want to eventually implement a more principled fix, possibly based on some from of SCC detection. PiperOrigin-RevId: 237154773
This commit is contained in:
parent
10b1a5c578
commit
58e052bd77
@ -560,11 +560,13 @@ cc_library(
|
||||
":resource_operation_safety_analysis",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_bounds_check",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
@ -692,6 +694,7 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
@ -715,6 +718,7 @@ cc_library(
|
||||
":union_find",
|
||||
":xla_cluster_util",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
|
@ -768,7 +768,8 @@ Status DeadnessAnalysisImpl::GetInputPreds(
|
||||
auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
|
||||
if (it == predicate_map_.end()) {
|
||||
GraphCycles graph_cycles;
|
||||
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph_, &graph_cycles));
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateCycleDetectionGraph(&graph_, &graph_cycles).status());
|
||||
|
||||
// If we didn't return with an error above then the graph is probably
|
||||
// fine and we have a bug in deadness analysis.
|
||||
|
@ -1098,7 +1098,11 @@ Status MarkForCompilationPass::RunImpl(
|
||||
}
|
||||
|
||||
GraphCycles cycles;
|
||||
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles));
|
||||
TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
|
||||
CreateCycleDetectionGraph(graph, &cycles));
|
||||
if (!cycle_detection_graph_ok) {
|
||||
return Status::OK();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
|
||||
graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles));
|
||||
|
||||
|
@ -111,7 +111,8 @@ bool HasForwardedRefInput(const Node& node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
|
||||
xla::StatusOr<bool> CreateCycleDetectionGraph(const Graph* graph,
|
||||
GraphCycles* cycles) {
|
||||
for (int i = 0; i < graph->num_node_ids(); ++i) {
|
||||
// We rely on the node IDs in the cycle detection graph being consecutive
|
||||
// integers starting from 0.
|
||||
@ -174,9 +175,11 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
|
||||
}
|
||||
|
||||
if (!cycles->InsertEdge(src, dst)) {
|
||||
return errors::Internal(
|
||||
"Cycle detected when adding ", src_type, "->", dst_type,
|
||||
" edge: ", DescribeCycle(cycles, *graph, src, dst));
|
||||
// TODO(b/127521408): We can probably handle this situation with a more
|
||||
// sophisticated SCC based algorithm, but for now we bail out.
|
||||
VLOG(1) << "Cycle detected when adding " << src_type << "->" << dst_type
|
||||
<< " edge: " << DescribeCycle(cycles, *graph, src, dst);
|
||||
return false;
|
||||
}
|
||||
// Drop the original edge.
|
||||
continue;
|
||||
@ -194,7 +197,8 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
|
||||
DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node) {
|
||||
|
@ -20,8 +20,10 @@ limitations under the License.
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -53,7 +55,11 @@ bool HasForwardedRefInput(const Node& node);
|
||||
// Creates a graph representation to enable cycle detection when clustering.
|
||||
// This representation handles loops in graph by disconnecting each loop from
|
||||
// the enclosing graph.
|
||||
Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles);
|
||||
//
|
||||
// Returns true for success and false for valid graphs that we can't handle yet
|
||||
// (b/127521408).
|
||||
xla::StatusOr<bool> CreateCycleDetectionGraph(const Graph* graph,
|
||||
GraphCycles* cycles);
|
||||
|
||||
// Returns the XLA cluster in which `node` is placed if it is in an XLA cluster,
|
||||
// otherwise returns nullopt.
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
@ -44,7 +45,7 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) {
|
||||
FixupSourceAndSinkEdges(root.graph());
|
||||
|
||||
GraphCycles cycles;
|
||||
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles));
|
||||
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
|
||||
EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
|
||||
}
|
||||
|
||||
@ -63,10 +64,33 @@ TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) {
|
||||
FixupSourceAndSinkEdges(root.graph());
|
||||
|
||||
GraphCycles cycles;
|
||||
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles));
|
||||
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
|
||||
EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
|
||||
}
|
||||
|
||||
TEST(CreateCycleDetectionGraph, ReachingEnterExit) {
|
||||
// TODO(b/127521408): We can lift this limitation with some work.
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0));
|
||||
Output enter_0 =
|
||||
ops::internal::Enter(root.WithOpName("enter_0"), a, "frame_0");
|
||||
Output exit_0 = ops::internal::Exit(root.WithOpName("exit_0"), enter_0);
|
||||
|
||||
Output add = ops::Add(root.WithOpName("add"), exit_0, exit_0);
|
||||
|
||||
Output enter_1 =
|
||||
ops::internal::Enter(root.WithOpName("enter_1"), add, "frame_0");
|
||||
Output exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1);
|
||||
|
||||
FixupSourceAndSinkEdges(root.graph());
|
||||
|
||||
GraphCycles cycles;
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool ok,
|
||||
CreateCycleDetectionGraph(root.graph(), &cycles));
|
||||
EXPECT_FALSE(ok);
|
||||
}
|
||||
|
||||
void CheckPickDeviceResult(absl::string_view expected_result,
|
||||
bool allow_mixing_unknown_and_cpu,
|
||||
absl::Span<const absl::string_view> inputs) {
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/union_find.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
@ -208,7 +209,12 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
|
||||
}
|
||||
|
||||
GraphCycles cycles;
|
||||
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles));
|
||||
TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
|
||||
CreateCycleDetectionGraph(&graph, &cycles));
|
||||
if (!cycle_detection_graph_ok) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
|
||||
&graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles));
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user