diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 3644279b920..4be22f82606 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -829,7 +829,8 @@ static bool ValidateInlining(const Node* node, const FunctionBody* fbody) {
 // Given a "caller" in "graph", which is a function call of a function
 // to "fbody". Replaces the "caller" with fbody->graph and connects
 // edges properly.
-static void InlineFunctionBody(Graph* g, Node* caller,
+static void InlineFunctionBody(const FunctionLibraryDefinition& flib_def,
+                               Graph* g, Node* caller,
                                const FunctionBody* fbody) {
   if (!ValidateInlining(caller, fbody)) {
     LOG(WARNING) << "Inlining mismatch: " << caller->DebugString() << " vs. "
@@ -837,6 +838,23 @@ static void InlineFunctionBody(Graph* g, Node* caller,
     return;
   }

+  // Input edges. For data edges coming into "caller", we first compute the
+  // <src>:<src_output> for the i-th input in "inputs".
+  // If "caller" has any input control dependencies, we add a NoOp
+  // node "input_control_node", which depends on "caller"'s control inputs.
+  std::vector<Endpoint> inputs(caller->num_inputs());
+  Node* input_control_node = nullptr;
+  for (const Edge* e : caller->in_edges()) {
+    if (e->IsControlEdge()) {
+      if (input_control_node == nullptr) {
+        input_control_node = AddNoOp(g);
+      }
+      g->AddControlEdge(e->src(), input_control_node);
+    } else {
+      inputs[e->dst_input()] = {e->src(), e->src_output()};
+    }
+  }
+
   // Duplicate fbody->graph into 'g'. First, we copy the nodes of
   // fbody->graph into 'g' except the source and sink nodes. We copy
   // edges among nodes in 'fbody->graph'.
@@ -850,8 +868,35 @@ static void InlineFunctionBody(Graph* g, Node* caller,
     CHECK(n->IsOp());
     NodeDef ndef = n->def();
     ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name()));
-    node_map[n->id()] = g->AddNode(ndef, &s);
+    Node* clone = g->AddNode(ndef, &s);
     TF_CHECK_OK(s);
+    node_map[n->id()] = clone;
+
+    // If there is an input control node, and one of:
+    // a) the node has no data or control inputs, or
+    // b) the node is a function call or SymbolicGradient,
+    // then add a control edge from the input control node to the clone.
+    //
+    // We must not execute any nodes if the original function call would not
+    // have executed. This is especially critical when the function call is
+    // inside a control-flow construct like tf.cond(). Case (a) ensures that
+    // such nodes do not run.
+    //
+    // The purpose of case (b) is to ensure that instances of case (a) created
+    // by further inlining steps also receive the control dependency.
+    if (input_control_node) {
+      bool has_inputs = false;
+      for (const Edge* e : n->in_edges()) {
+        if (!e->src()->IsSource()) {
+          has_inputs = true;
+          break;
+        }
+      }
+      if (!has_inputs || flib_def.Find(clone->type_string()) != nullptr ||
+          clone->type_string() == "SymbolicGradient") {
+        g->AddControlEdge(input_control_node, clone);
+      }
+    }
   }
   for (const Edge* e : fbody->graph->edges()) {
     if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
@@ -865,29 +910,12 @@ static void InlineFunctionBody(Graph* g, Node* caller,

   // Connect input edges.
   //
-  // For data edges coming into "caller", we first compute the
-  // <src>:<src_output> for the i-th input in "inputs". We create one
-  // Identity node for each input. Then, we connect inputs[i] to to
-  // the i-th identity node added. The nodes that previously connects
-  // to the j-th output of i-th arg node are reconnected to th i-th
-  // identity node.
+  // We create one Identity node for each input. Then, we connect inputs[i] to
+  // the i-th identity node added. The nodes that previously connected
+  // to the j-th output of i-th arg node are reconnected to the i-th
+  // identity node.
   //
-  // If "caller" has any input control dependencies, we add a NoOp
-  // node "input_control_node". This "input_control_node" depends on
-  // what "caller" depends on, and the added identity nodes depend on
-  // "input_control_node".
-  std::vector<Endpoint> inputs(caller->num_inputs());
-  Node* input_control_node = nullptr;
-  for (const Edge* e : caller->in_edges()) {
-    if (e->IsControlEdge()) {
-      if (input_control_node == nullptr) {
-        input_control_node = AddNoOp(g);
-      }
-      g->AddControlEdge(e->src(), input_control_node);
-    } else {
-      inputs[e->dst_input()] = {e->src(), e->src_output()};
-    }
-  }
+  // The added identity nodes depend on "input_control_node".
   for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
     Node* arg = node_map[fbody->arg_nodes[i]->id()];
     Node* n = AddIdentity(g, inputs[i]);
@@ -982,7 +1010,7 @@ bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
     candidates.push_back({node, fbody});
   }
   for (const auto& p : candidates) {
-    InlineFunctionBody(graph, p.first, p.second);
+    InlineFunctionBody(*fld, graph, p.first, p.second);
   }
   return !candidates.empty();
 }
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index af1ff6aec03..dfa1ed8a7e4 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -391,6 +391,90 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
   }
 }

+// Verifies that control dependencies on the caller are added as control
+// dependencies on any function calls created by inlining.
+TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) {
+  Init({test::function::XTimesTwo(), test::function::XTimesFour()});
+
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  {
+    Scope s = Scope::NewRootScope();
+    TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
+    auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
+    auto c = ops::NoOp(s.WithOpName("c"));
+    auto b = Call(&s, "b", "XTimesFour", {a});
+    s.graph()->AddControlEdge(c.operation.node(), b.node());
+    auto ret = ops::_Retval(s.WithOpName("b_RetVal"), b, 0);
+    TF_ASSERT_OK(s.ToGraph(g.get()));
+  }
+
+  ExpandInlineFunctions(lib_.get(), g.get());
+  {
+    Scope s = Scope::NewRootScope();
+    TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
+    auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
+    auto c = ops::NoOp(s.WithOpName("c"));
+    auto func0 =
+        ops::NoOp(s.WithOpName("Func/_0").WithControlDependencies({c}));
+    auto func1 = ops::Identity(
+        s.WithOpName("Func/_1").WithControlDependencies({func0}), a);
+    auto b_x2 = Call(&s, "b/x2", "XTimesTwo", {func1});
+    s.graph()->AddControlEdge(func0.operation.node(), b_x2.node());
+    auto b_y = Call(&s, "b/y", "XTimesTwo", {b_x2});
+    s.graph()->AddControlEdge(func0.operation.node(), b_y.node());
+    auto func2 = ops::Identity(s.WithOpName("Func/_2"), b_y);
+    auto ret = ops::_Retval(s.WithOpName("b_RetVal"), func2, 0);
+    GraphDef expected;
+    TF_ASSERT_OK(s.ToGraphDef(&expected));
+
+    GraphDef actual;
+    g->ToGraphDef(&actual);
+    TF_EXPECT_GRAPH_EQ(expected, actual);
+  }
+
+  ExpandInlineFunctions(lib_.get(), g.get());
+  {
+    Scope s = Scope::NewRootScope();
+    TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
+    auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
+    auto c = ops::NoOp(s.WithOpName("c"));
+    auto func0 =
+        ops::NoOp(s.WithOpName("Func/_0").WithControlDependencies({c}));
+    auto func1 = ops::Identity(
+        s.WithOpName("Func/_1").WithControlDependencies({func0}), a);
+
+    auto func3 =
+        ops::NoOp(s.WithOpName("Func/_3").WithControlDependencies({func0}));
+    auto func4 = ops::Identity(
+        s.WithOpName("Func/_4").WithControlDependencies({func3}), func1);
+    auto b_x2_two = ops::Const(
+        s.WithOpName("b/x2/two").WithControlDependencies({func3}), 2LL);
+    auto b_x2_scale = ops::Cast(s.WithOpName("b/x2/scale"), b_x2_two, DT_FLOAT);
+    auto b_x2_y = ops::Mul(s.WithOpName("b/x2/y"), func4, b_x2_scale);
+    auto func5 = ops::Identity(s.WithOpName("Func/_5"), b_x2_y);
+
+    auto func6 =
+        ops::NoOp(s.WithOpName("Func/_6").WithControlDependencies({func0}));
+    auto func7 = ops::Identity(
+        s.WithOpName("Func/_7").WithControlDependencies({func6}), func5);
+    auto b_y_two = ops::Const(
+        s.WithOpName("b/y/two").WithControlDependencies({func6}), 2LL);
+    auto b_y_scale = ops::Cast(s.WithOpName("b/y/scale"), b_y_two, DT_FLOAT);
+    auto b_y_y = ops::Mul(s.WithOpName("b/y/y"), func7, b_y_scale);
+    auto func8 = ops::Identity(s.WithOpName("Func/_8"), b_y_y);
+
+    auto func2 = ops::Identity(s.WithOpName("Func/_2"), func8);
+    auto ret = ops::_Retval(s.WithOpName("b_RetVal"), func2, 0);
+
+    GraphDef expected;
+    TF_ASSERT_OK(s.ToGraphDef(&expected));
+
+    GraphDef actual;
+    g->ToGraphDef(&actual);
+    TF_EXPECT_GRAPH_EQ(expected, actual);
+  }
+}
+
 TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
   Init({test::function::XTimesTwo(), test::function::XTimesFour(),
         test::function::XTimes16()});
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 39f00e52169..416ab263afc 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -324,6 +324,48 @@ class FunctionTest(test.TestCase):
                                  "assertion"):
         _ = MyFn(100.0).eval()

+  def testControlFlowStrictness(self):
+    """Inlined functions must not execute in an untaken control flow branch."""
+
+    @function.Defun(dtypes.int32)
+    def AssertFail(x):
+      # Assertion that always fails and does not have a data dependency on `x`.
+      assert_false = control_flow_ops.Assert(False, [42])
+      with ops.control_dependencies([assert_false]):
+        return array_ops.identity(x)
+
+    with ops.device("CPU"):
+      pred = array_ops.placeholder(dtypes.bool)
+      x = array_ops.placeholder(dtypes.int32)
+      cond = control_flow_ops.cond(pred, lambda: x + 1, lambda: AssertFail(x))
+      # pylint: disable=unnecessary-lambda
+      loop = control_flow_ops.while_loop(lambda y: pred,
+                                         lambda y: AssertFail(y), [x])
+      # pylint: enable=unnecessary-lambda
+
+    # Enables inlining.
+    config = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
+        optimizer_options=config_pb2.OptimizerOptions(
+            opt_level=config_pb2.OptimizerOptions.L0,
+            do_common_subexpression_elimination=True,
+            do_function_inlining=True,
+            do_constant_folding=True)))
+
+    with session.Session(config=config) as sess:
+      # Since the 'False' branch is not taken, the assertion should not fire.
+      self.assertEqual(4, sess.run(cond, {pred: True, x: 3}))
+
+      # The assertion should still fire if the False branch is taken.
+      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+                                   "assertion"):
+        sess.run(cond, {pred: False, x: 3})
+
+      # Similarly for loops.
+      self.assertEqual(3, sess.run(loop, {pred: False, x: 3}))
+      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+                                   "assertion"):
+        sess.run(loop, {pred: True, x: 3})
+
   def testVar(self):

     @function.Defun(dtypes.float32)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 91be9ddbd78..713cb65a40d 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1622,6 +1622,11 @@ class CondContext(ControlFlowContext):
       # pylint: enable=protected-access
       for x in op.outputs:
         self._values.add(x.name)
+    # pylint: disable=protected-access
+    if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
+      op._add_control_input(self._pivot.op)
+    # pylint: enable=protected-access
+
     if self._outer_context or not IsLoopExit(op):
       op.graph.prevent_fetching(op)

@@ -2147,8 +2152,13 @@ class WhileContext(ControlFlowContext):
   def _MaybeAddControlDependency(self, op):
     """Add a control input to the op if it only depends on loop invariants."""
     def _IsOpFree(op):
+      """Determines if `op` needs a control dependency."""
       if op.control_inputs:
         return False
+      # pylint: disable=protected-access
+      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
+        return True
+      # pylint: enable=protected-access
       for x in op.inputs:
         if not _IsLoopConstantEnter(x.op):
           return False
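
Note for reviewers: the graph-level effect of this change is easiest to see from the Python side. Below is a minimal sketch of the failure mode being fixed, distilled from testControlFlowStrictness above. It assumes a TF 1.x build with function inlining enabled (the ConfigProto mirrors the one in the test, trimmed to the relevant options); it is an illustration, not part of the patch.

```python
import tensorflow as tf
from tensorflow.python.framework import function


@function.Defun(tf.int32)
def AssertFail(x):
  # The Assert has no data dependency on `x`; after naive inlining it has no
  # inputs at all, so nothing stops it from running even when the enclosing
  # cond branch is untaken.
  assert_false = tf.Assert(False, [42])
  with tf.control_dependencies([assert_false]):
    return tf.identity(x)


pred = tf.placeholder(tf.bool)
x = tf.placeholder(tf.int32)
out = tf.cond(pred, lambda: x + 1, lambda: AssertFail(x))

# Enable function inlining, as in testControlFlowStrictness.
config = tf.ConfigProto(graph_options=tf.GraphOptions(
    optimizer_options=tf.OptimizerOptions(
        opt_level=tf.OptimizerOptions.L0, do_function_inlining=True)))

with tf.Session(config=config) as sess:
  # Without this change the inlined Assert could fire here even though the
  # True branch is taken; with it, the Assert inherits the caller's control
  # dependency (the "Func/_0" NoOp in the C++ test above) and stays dead.
  print(sess.run(out, {pred: True, x: 3}))  # Prints 4.
```

Case (b) in the new InlineFunctionBody comment is what preserves this property across repeated inlining passes: a nested call such as b/x2 receives the control edge before it is itself expanded, so the NoOp introduced by its own expansion (Func/_3 in the test) is anchored to Func/_0 rather than left free to run unconditionally.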