Fix bug when inlining function calls inside a control-flow construct.

Add control dependencies from a control frame's pivot node onto all operators within the frame. Function inlining may introduce extra nodes in the graph, and those newly introduced nodes must depend on the enclosing control frame's pivot node. However, since the inlining transformation only looks locally at the function call node, it has no way to know that it should do this.

There are two possible solutions: either (a) teach the function inliner about control flow frames, or (b) add a control dependency from the pivot node to function call node and have the inliner propagate that dependency. This change does the latter.

When inlining a function call, propagate control dependencies from the function call node onto each inlined operator with no data dependencies, and onto any nested function calls or SymbolicGradients.
Change: 155289198
This commit is contained in:
Peter Hawkins 2017-05-06 09:14:16 -08:00 committed by TensorFlower Gardener
parent 1182c93e3e
commit e321ae5e32
4 changed files with 188 additions and 24 deletions

View File

@ -829,7 +829,8 @@ static bool ValidateInlining(const Node* node, const FunctionBody* fbody) {
// Given a "caller" in "graph", which is a function call of a function
// to "fbody". Replaces the "caller" with fbody->graph and connects
// edges properly.
static void InlineFunctionBody(Graph* g, Node* caller,
static void InlineFunctionBody(const FunctionLibraryDefinition& flib_def,
Graph* g, Node* caller,
const FunctionBody* fbody) {
if (!ValidateInlining(caller, fbody)) {
LOG(WARNING) << "Inlining mismatch: " << caller->DebugString() << " vs. "
@ -837,6 +838,23 @@ static void InlineFunctionBody(Graph* g, Node* caller,
return;
}
// Input edges. For data edges coming into "caller", we first compute the
// <src>:<src_output> for the i-th input in "inputs".
// If "caller" has any input control dependencies, we add a NoOp
// node "input_control_node", which depends on "caller"'s control inputs.
std::vector<Endpoint> inputs(caller->num_inputs());
Node* input_control_node = nullptr;
for (const Edge* e : caller->in_edges()) {
if (e->IsControlEdge()) {
if (input_control_node == nullptr) {
input_control_node = AddNoOp(g);
}
g->AddControlEdge(e->src(), input_control_node);
} else {
inputs[e->dst_input()] = {e->src(), e->src_output()};
}
}
// Duplicate fbody->graph into 'g'. First, we copy the nodes of
// fbody->graph into 'g' except the source and sink nodes. We copy
// edges among nodes in 'fbody->graph'.
@ -850,8 +868,35 @@ static void InlineFunctionBody(Graph* g, Node* caller,
CHECK(n->IsOp());
NodeDef ndef = n->def();
ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name()));
node_map[n->id()] = g->AddNode(ndef, &s);
Node* clone = g->AddNode(ndef, &s);
TF_CHECK_OK(s);
node_map[n->id()] = clone;
// If there is an input control node, and one of:
// a) the node has no data or control inputs, or
// b) the node is a function call or SymbolicGradient,
// then add a control edge from the input control node to the clone.
//
// We must not execute any nodes if the original function call would not
// have executed. This is especially critical when the function call is
// inside a control-flow construct like tf.cond(). Case (a) ensures that
// such nodes do not run.
//
// The purpose of case (b) is to ensure that instances of case (a) created
// by further inlining steps also receive the control dependency.
if (input_control_node) {
bool has_inputs = false;
for (const Edge* e : n->in_edges()) {
if (!e->src()->IsSource()) {
has_inputs = true;
break;
}
}
if (!has_inputs || flib_def.Find(clone->type_string()) != nullptr ||
clone->type_string() == "SymbolicGradient") {
g->AddControlEdge(input_control_node, clone);
}
}
}
for (const Edge* e : fbody->graph->edges()) {
if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
@ -865,29 +910,12 @@ static void InlineFunctionBody(Graph* g, Node* caller,
// Connect input edges.
//
// For data edges coming into "caller", we first compute the
// <src>:<src_output> for the i-th input in "inputs". We create one
// Identity node for each input. Then, we connect inputs[i] to to
// the i-th identity node added. The nodes that previously connects
// to the j-th output of i-th arg node are reconnected to th i-th
// We create one Identity node for each input. Then, we connect inputs[i] to
// the i-th identity node added. The nodes that previously connected
// to the j-th output of i-th arg node are reconnected to the i-th
// identity node.
//
// If "caller" has any input control dependencies, we add a NoOp
// node "input_control_node". This "input_control_node" depends on
// what "caller" depends on, and the added identity nodes depend on
// "input_control_node".
std::vector<Endpoint> inputs(caller->num_inputs());
Node* input_control_node = nullptr;
for (const Edge* e : caller->in_edges()) {
if (e->IsControlEdge()) {
if (input_control_node == nullptr) {
input_control_node = AddNoOp(g);
}
g->AddControlEdge(e->src(), input_control_node);
} else {
inputs[e->dst_input()] = {e->src(), e->src_output()};
}
}
// The added identity nodes depend on "input_control_node".
for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
Node* arg = node_map[fbody->arg_nodes[i]->id()];
Node* n = AddIdentity(g, inputs[i]);
@ -982,7 +1010,7 @@ bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
candidates.push_back({node, fbody});
}
for (const auto& p : candidates) {
InlineFunctionBody(graph, p.first, p.second);
InlineFunctionBody(*fld, graph, p.first, p.second);
}
return !candidates.empty();
}

View File

@ -391,6 +391,90 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
}
}
// Verifies that control dependencies on the caller are added as control
// dependencies on any function calls created by inlining.
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) {
Init({test::function::XTimesTwo(), test::function::XTimesFour()});
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
{
Scope s = Scope::NewRootScope();
TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
auto c = ops::NoOp(s.WithOpName("c"));
auto b = Call(&s, "b", "XTimesFour", {a});
s.graph()->AddControlEdge(c.operation.node(), b.node());
auto ret = ops::_Retval(s.WithOpName("b_RetVal"), b, 0);
TF_ASSERT_OK(s.ToGraph(g.get()));
}
ExpandInlineFunctions(lib_.get(), g.get());
{
Scope s = Scope::NewRootScope();
TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
auto c = ops::NoOp(s.WithOpName("c"));
auto func0 =
ops::NoOp(s.WithOpName("Func/_0").WithControlDependencies({c}));
auto func1 = ops::Identity(
s.WithOpName("Func/_1").WithControlDependencies({func0}), a);
auto b_x2 = Call(&s, "b/x2", "XTimesTwo", {func1});
s.graph()->AddControlEdge(func0.operation.node(), b_x2.node());
auto b_y = Call(&s, "b/y", "XTimesTwo", {b_x2});
s.graph()->AddControlEdge(func0.operation.node(), b_y.node());
auto func2 = ops::Identity(s.WithOpName("Func/_2"), b_y);
auto ret = ops::_Retval(s.WithOpName("b_RetVal"), func2, 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
ExpandInlineFunctions(lib_.get(), g.get());
{
Scope s = Scope::NewRootScope();
TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
auto c = ops::NoOp(s.WithOpName("c"));
auto func0 =
ops::NoOp(s.WithOpName("Func/_0").WithControlDependencies({c}));
auto func1 = ops::Identity(
s.WithOpName("Func/_1").WithControlDependencies({func0}), a);
auto func3 =
ops::NoOp(s.WithOpName("Func/_3").WithControlDependencies({func0}));
auto func4 = ops::Identity(
s.WithOpName("Func/_4").WithControlDependencies({func3}), func1);
auto b_x2_two = ops::Const(
s.WithOpName("b/x2/two").WithControlDependencies({func3}), 2LL);
auto b_x2_scale = ops::Cast(s.WithOpName("b/x2/scale"), b_x2_two, DT_FLOAT);
auto b_x2_y = ops::Mul(s.WithOpName("b/x2/y"), func4, b_x2_scale);
auto func5 = ops::Identity(s.WithOpName("Func/_5"), b_x2_y);
auto func6 =
ops::NoOp(s.WithOpName("Func/_6").WithControlDependencies({func0}));
auto func7 = ops::Identity(
s.WithOpName("Func/_7").WithControlDependencies({func6}), func5);
auto b_y_two = ops::Const(
s.WithOpName("b/y/two").WithControlDependencies({func6}), 2LL);
auto b_y_scale = ops::Cast(s.WithOpName("b/y/scale"), b_y_two, DT_FLOAT);
auto b_y_y = ops::Mul(s.WithOpName("b/y/y"), func7, b_y_scale);
auto func8 = ops::Identity(s.WithOpName("Func/_8"), b_y_y);
auto func2 = ops::Identity(s.WithOpName("Func/_2"), func8);
auto ret = ops::_Retval(s.WithOpName("b_RetVal"), func2, 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
}
TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});

View File

@ -324,6 +324,48 @@ class FunctionTest(test.TestCase):
"assertion"):
_ = MyFn(100.0).eval()
def testControlFlowStrictness(self):
"""Inlined functions must not execute in a untaken control flow branch."""
@function.Defun(dtypes.int32)
def AssertFail(x):
# Assertion that always fails and does not have a data dependency on `x`.
assert_false = control_flow_ops.Assert(False, [42])
with ops.control_dependencies([assert_false]):
return array_ops.identity(x)
with ops.device("CPU"):
pred = array_ops.placeholder(dtypes.bool)
x = array_ops.placeholder(dtypes.int32)
cond = control_flow_ops.cond(pred, lambda: x + 1, lambda: AssertFail(x))
# pylint: disable=unnecessary-lambda
loop = control_flow_ops.while_loop(lambda y: pred,
lambda y: AssertFail(y), [x])
# pylint: enable=unnecessary-lambda
# Enables inlining.
config = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
optimizer_options=config_pb2.OptimizerOptions(
opt_level=config_pb2.OptimizerOptions.L0,
do_common_subexpression_elimination=True,
do_function_inlining=True,
do_constant_folding=True)))
with session.Session(config=config) as sess:
# Since the 'False' branch is not taken, the assertion should not fire.
self.assertEqual(4, sess.run(cond, {pred: True, x: 3}))
# The assertion should still fire if the False branch is taken.
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"assertion"):
sess.run(cond, {pred: False, x: 3})
# Similarly for loops.
self.assertEqual(3, sess.run(loop, {pred: False, x: 3}))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"assertion"):
sess.run(loop, {pred: True, x: 3})
def testVar(self):
@function.Defun(dtypes.float32)

View File

@ -1622,6 +1622,11 @@ class CondContext(ControlFlowContext):
# pylint: enable=protected-access
for x in op.outputs:
self._values.add(x.name)
# pylint: disable=protected-access
if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
op._add_control_input(self._pivot.op)
# pylint: enable=protected-access
if self._outer_context or not IsLoopExit(op):
op.graph.prevent_fetching(op)
@ -2147,8 +2152,13 @@ class WhileContext(ControlFlowContext):
def _MaybeAddControlDependency(self, op):
"""Add a control input to the op if it only depends on loop invariants."""
def _IsOpFree(op):
"""Determines if `op` needs a control dependency."""
if op.control_inputs:
return False
# pylint: disable=protected-access
if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
return True
# pylint: enable=protected-access
for x in op.inputs:
if not _IsLoopConstantEnter(x.op):
return False