From e321ae5e32859b459879706fea931f4b352be7f6 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sat, 6 May 2017 09:14:16 -0800 Subject: [PATCH] Fix bug when inlining function calls inside a control-flow construct. Add control dependencies from a control frame's pivot node onto all operators within the frame. Function inlining may introduce extra nodes in the graph, and those newly introduced nodes must depend on the enclosing control frame's pivot node. However, since the inlining transformation only looks locally at the function call node, it has no way to know that it should do this. There are two possible solutions: either (a) teach the function inliner about control flow frames, or (b) add a control dependency from the pivot node to function call node and have the inliner propagate that dependency. This change does the latter. When inlining a function call, propagate control dependencies from the function call node onto each inlined operator with no data dependencies, and onto any nested function calls or SymbolicGradients. Change: 155289198 --- tensorflow/core/common_runtime/function.cc | 76 +++++++++++------ .../core/common_runtime/function_test.cc | 84 +++++++++++++++++++ tensorflow/python/framework/function_test.py | 42 ++++++++++ tensorflow/python/ops/control_flow_ops.py | 10 +++ 4 files changed, 188 insertions(+), 24 deletions(-) diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 3644279b920..4be22f82606 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -829,7 +829,8 @@ static bool ValidateInlining(const Node* node, const FunctionBody* fbody) { // Given a "caller" in "graph", which is a function call of a function // to "fbody". Replaces the "caller" with fbody->graph and connects // edges properly. 
-static void InlineFunctionBody(Graph* g, Node* caller, +static void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, + Graph* g, Node* caller, const FunctionBody* fbody) { if (!ValidateInlining(caller, fbody)) { LOG(WARNING) << "Inlining mismatch: " << caller->DebugString() << " vs. " @@ -837,6 +838,23 @@ static void InlineFunctionBody(Graph* g, Node* caller, return; } + // Input edges. For data edges coming into "caller", we first compute the + // <src>:<src_output> for the i-th input in "inputs". + // If "caller" has any input control dependencies, we add a NoOp + // node "input_control_node", which depends on "caller"'s control inputs. + std::vector<Endpoint> inputs(caller->num_inputs()); + Node* input_control_node = nullptr; + for (const Edge* e : caller->in_edges()) { + if (e->IsControlEdge()) { + if (input_control_node == nullptr) { + input_control_node = AddNoOp(g); + } + g->AddControlEdge(e->src(), input_control_node); + } else { + inputs[e->dst_input()] = {e->src(), e->src_output()}; + } + } + // Duplicate fbody->graph into 'g'. First, we copy the nodes of // fbody->graph into 'g' except the source and sink nodes. We copy // edges among nodes in 'fbody->graph'. @@ -850,8 +868,35 @@ static void InlineFunctionBody(Graph* g, Node* caller, CHECK(n->IsOp()); NodeDef ndef = n->def(); ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name())); - node_map[n->id()] = g->AddNode(ndef, &s); + Node* clone = g->AddNode(ndef, &s); TF_CHECK_OK(s); + node_map[n->id()] = clone; + + // If there is an input control node, and one of: + // a) the node has no data or control inputs, or + // b) the node is a function call or SymbolicGradient, + // then add a control edge from the input control node to the clone. + // + // We must not execute any nodes if the original function call would not + // have executed. This is especially critical when the function call is + // inside a control-flow construct like tf.cond(). Case (a) ensures that + // such nodes do not run. 
+ // + // The purpose of case (b) is to ensure that instances of case (a) created + // by further inlining steps also receive the control dependency. + if (input_control_node) { + bool has_inputs = false; + for (const Edge* e : n->in_edges()) { + if (!e->src()->IsSource()) { + has_inputs = true; + break; + } + } + if (!has_inputs || flib_def.Find(clone->type_string()) != nullptr || + clone->type_string() == "SymbolicGradient") { + g->AddControlEdge(input_control_node, clone); + } + } } for (const Edge* e : fbody->graph->edges()) { if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() || @@ -865,29 +910,12 @@ static void InlineFunctionBody(Graph* g, Node* caller, // Connect input edges. // - // For data edges coming into "caller", we first compute the - // <src>:<src_output> for the i-th input in "inputs". We create one - // Identity node for each input. Then, we connect inputs[i] to to - // the i-th identity node added. The nodes that previously connects - // to the j-th output of i-th arg node are reconnected to th i-th + // We create one Identity node for each input. Then, we connect inputs[i] to + // the i-th identity node added. The nodes that previously connected + // to the j-th output of i-th arg node are reconnected to the i-th + // identity node. // - // If "caller" has any input control dependencies, we add a NoOp - // node "input_control_node". This "input_control_node" depends on - // what "caller" depends on, and the added identity nodes depend on - // "input_control_node". - std::vector<Endpoint> inputs(caller->num_inputs()); - Node* input_control_node = nullptr; - for (const Edge* e : caller->in_edges()) { - if (e->IsControlEdge()) { - if (input_control_node == nullptr) { - input_control_node = AddNoOp(g); - } - g->AddControlEdge(e->src(), input_control_node); - } else { - inputs[e->dst_input()] = {e->src(), e->src_output()}; - } - } + // The added identity nodes depend on "input_control_node". 
for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) { Node* arg = node_map[fbody->arg_nodes[i]->id()]; Node* n = AddIdentity(g, inputs[i]); @@ -982,7 +1010,7 @@ bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) { candidates.push_back({node, fbody}); } for (const auto& p : candidates) { - InlineFunctionBody(graph, p.first, p.second); + InlineFunctionBody(*fld, graph, p.first, p.second); } return !candidates.empty(); } diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index af1ff6aec03..dfa1ed8a7e4 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -391,6 +391,90 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { } } +// Verifies that control dependencies on the caller are added as control +// dependencies on any function calls created by inlining. +TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) { + Init({test::function::XTimesTwo(), test::function::XTimesFour()}); + + std::unique_ptr<Graph> g(new Graph(OpRegistry::Global())); + { + Scope s = Scope::NewRootScope(); + TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); + auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0); + auto c = ops::NoOp(s.WithOpName("c")); + auto b = Call(&s, "b", "XTimesFour", {a}); + s.graph()->AddControlEdge(c.operation.node(), b.node()); + auto ret = ops::_Retval(s.WithOpName("b_RetVal"), b, 0); + TF_ASSERT_OK(s.ToGraph(g.get())); + } + + ExpandInlineFunctions(lib_.get(), g.get()); + { + Scope s = Scope::NewRootScope(); + TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); + auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0); + auto c = ops::NoOp(s.WithOpName("c")); + auto func0 = + ops::NoOp(s.WithOpName("Func/_0").WithControlDependencies({c})); + auto func1 = ops::Identity( + s.WithOpName("Func/_1").WithControlDependencies({func0}), a); + auto b_x2 = Call(&s, "b/x2", "XTimesTwo", {func1}); + 
s.graph()->AddControlEdge(func0.operation.node(), b_x2.node()); + auto b_y = Call(&s, "b/y", "XTimesTwo", {b_x2}); + s.graph()->AddControlEdge(func0.operation.node(), b_y.node()); + auto func2 = ops::Identity(s.WithOpName("Func/_2"), b_y); + auto ret = ops::_Retval(s.WithOpName("b_RetVal"), func2, 0); + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + + GraphDef actual; + g->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(expected, actual); + } + + ExpandInlineFunctions(lib_.get(), g.get()); + { + Scope s = Scope::NewRootScope(); + TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); + auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0); + auto c = ops::NoOp(s.WithOpName("c")); + auto func0 = + ops::NoOp(s.WithOpName("Func/_0").WithControlDependencies({c})); + auto func1 = ops::Identity( + s.WithOpName("Func/_1").WithControlDependencies({func0}), a); + + auto func3 = + ops::NoOp(s.WithOpName("Func/_3").WithControlDependencies({func0})); + auto func4 = ops::Identity( + s.WithOpName("Func/_4").WithControlDependencies({func3}), func1); + auto b_x2_two = ops::Const( + s.WithOpName("b/x2/two").WithControlDependencies({func3}), 2LL); + auto b_x2_scale = ops::Cast(s.WithOpName("b/x2/scale"), b_x2_two, DT_FLOAT); + auto b_x2_y = ops::Mul(s.WithOpName("b/x2/y"), func4, b_x2_scale); + auto func5 = ops::Identity(s.WithOpName("Func/_5"), b_x2_y); + + auto func6 = + ops::NoOp(s.WithOpName("Func/_6").WithControlDependencies({func0})); + auto func7 = ops::Identity( + s.WithOpName("Func/_7").WithControlDependencies({func6}), func5); + auto b_y_two = ops::Const( + s.WithOpName("b/y/two").WithControlDependencies({func6}), 2LL); + auto b_y_scale = ops::Cast(s.WithOpName("b/y/scale"), b_y_two, DT_FLOAT); + auto b_y_y = ops::Mul(s.WithOpName("b/y/y"), func7, b_y_scale); + auto func8 = ops::Identity(s.WithOpName("Func/_8"), b_y_y); + + auto func2 = ops::Identity(s.WithOpName("Func/_2"), func8); + auto ret = ops::_Retval(s.WithOpName("b_RetVal"), func2, 0); + + GraphDef 
expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + + GraphDef actual; + g->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(expected, actual); + } +} + TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 39f00e52169..416ab263afc 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -324,6 +324,48 @@ class FunctionTest(test.TestCase): "assertion"): _ = MyFn(100.0).eval() + def testControlFlowStrictness(self): + """Inlined functions must not execute in an untaken control flow branch.""" + + @function.Defun(dtypes.int32) + def AssertFail(x): + # Assertion that always fails and does not have a data dependency on `x`. + assert_false = control_flow_ops.Assert(False, [42]) + with ops.control_dependencies([assert_false]): + return array_ops.identity(x) + + with ops.device("CPU"): + pred = array_ops.placeholder(dtypes.bool) + x = array_ops.placeholder(dtypes.int32) + cond = control_flow_ops.cond(pred, lambda: x + 1, lambda: AssertFail(x)) + # pylint: disable=unnecessary-lambda + loop = control_flow_ops.while_loop(lambda y: pred, + lambda y: AssertFail(y), [x]) + # pylint: enable=unnecessary-lambda + + # Enables inlining. + config = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions( + optimizer_options=config_pb2.OptimizerOptions( + opt_level=config_pb2.OptimizerOptions.L0, + do_common_subexpression_elimination=True, + do_function_inlining=True, + do_constant_folding=True))) + + with session.Session(config=config) as sess: + # Since the 'False' branch is not taken, the assertion should not fire. + self.assertEqual(4, sess.run(cond, {pred: True, x: 3})) + + # The assertion should still fire if the False branch is taken. 
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "assertion"): + sess.run(cond, {pred: False, x: 3}) + + # Similarly for loops. + self.assertEqual(3, sess.run(loop, {pred: False, x: 3})) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "assertion"): + sess.run(loop, {pred: True, x: 3}) + def testVar(self): @function.Defun(dtypes.float32) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 91be9ddbd78..713cb65a40d 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -1622,6 +1622,11 @@ class CondContext(ControlFlowContext): # pylint: enable=protected-access for x in op.outputs: self._values.add(x.name) + # pylint: disable=protected-access + if op.graph._is_function(op.type) or op.type == "SymbolicGradient": + op._add_control_input(self._pivot.op) + # pylint: enable=protected-access + if self._outer_context or not IsLoopExit(op): op.graph.prevent_fetching(op) @@ -2147,8 +2152,13 @@ class WhileContext(ControlFlowContext): def _MaybeAddControlDependency(self, op): """Add a control input to the op if it only depends on loop invariants.""" def _IsOpFree(op): + """Determines if `op` needs a control dependency.""" if op.control_inputs: return False + # pylint: disable=protected-access + if op.graph._is_function(op.type) or op.type == "SymbolicGradient": + return True + # pylint: enable=protected-access for x in op.inputs: if not _IsLoopConstantEnter(x.op): return False