Add 'src_graph' argument to gradients_impl._GradientsHelper.

This allows the gradient graph to be built in a _FuncGraph separate
from the forward graph (a _FuncGraph is necessary to capture the needed
tensors from the forward graph; it's up to the capturing logic how to
feed the forward tensors to the gradient graph).

PiperOrigin-RevId: 197230736
Skye Wanderman-Milne, 2018-05-18 18:29:54 -07:00 (committed by TensorFlower Gardener)
commit 92cdc99c98 (parent f3ae367618)
3 changed files with 75 additions and 77 deletions
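
For orientation, a minimal sketch (mine, not part of the commit) of the call this change enables. _GradientsHelper and _FuncGraph are private TF 1.x APIs, so treat the usage as illustrative:

    import tensorflow as tf  # assumes a TF 1.x build containing this change
    from tensorflow.python.ops import gradients_impl

    with tf.Graph().as_default() as forward_graph:
      x = tf.constant(3.0)
      y = x * x
      # src_graph=None (the default) keeps the old behavior: differentiate
      # ops in the current default graph, exactly like tf.gradients.
      dy_dx, = gradients_impl._GradientsHelper([y], [x])

    # With the new argument, the gradient ops could instead be created while
    # a different graph (e.g. a _FuncGraph) is the default, by passing
    # src_graph=forward_graph; the capturing logic then decides how forward
    # tensors are fed into the gradient graph.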

tensorflow/python/ops/control_flow_ops.py

@@ -1192,20 +1192,18 @@ class ControlFlowState(object):
       to backprop.
     """
     loop_exits = []
-    for _, grad_state in self._map.items():
-      # pylint: disable=protected-access
+    for grad_state in self._map.values():
       for y in grad_state.forward_loop_exits:
-        if pending_count[y.op._id] == 0:
+        if pending_count[y.op] == 0:
           grad_state.pending_exits_count -= 1
-          if y.op._id not in to_ops_set:
+          if y.op not in to_ops_set:
             grad_state.unused_exits.append(y)
           if grad_state.pending_exits_count == 0:
             loop_exits.extend(grad_state.unused_exits)
       # Need to include Enters in backprop for higher-order gradients.
       for y in grad_state.forward_context.loop_enters:
-        if pending_count[y.op._id] == 0:
-          pending_count[y.op._id] = 1
-      # pylint: enable=protected-access
+        if pending_count[y.op] == 0:
+          pending_count[y.op] = 1
     return loop_exits

   def EnterGradWhileContext(self, op, before):
@@ -1243,8 +1241,8 @@ class ControlFlowState(object):
       # We need to include all exits of a loop for backprop.
       for loop_exit in grad_state.forward_loop_exits:
-        if not between_ops[loop_exit.op._id]:
-          between_ops[loop_exit.op._id] = True
+        if loop_exit.op not in between_ops:
+          between_ops.add(loop_exit.op)
           between_op_list.append(loop_exit.op)

   def ZerosLikeForExit(self, val):

tensorflow/python/ops/gradients_impl.py

@@ -112,14 +112,14 @@ def _MarkReachedOps(from_ops, reached_ops):

   Args:
     from_ops: list of Operations.
-    reached_ops: list of booleans, indexed by operation id.
+    reached_ops: set of Operations.
   """
   queue = collections.deque()
   queue.extend(from_ops)
   while queue:
     op = queue.popleft()
-    if not reached_ops[op._id]:
-      reached_ops[op._id] = True
+    if op not in reached_ops:
+      reached_ops.add(op)
       for output in op.outputs:
         if _IsBackpropagatable(output):
           queue.extend(output.consumers())
@@ -130,7 +130,7 @@ def _GatherInputs(to_ops, reached_ops):

   Args:
     to_ops: list of Operations.
-    reached_ops: list of booleans, indexed by operation id.
+    reached_ops: set of Operations.

   Returns:
     The list of all inputs of to_ops that are in reached_ops.
@@ -142,58 +142,57 @@ def _GatherInputs(to_ops, reached_ops):
   while queue:
     op = queue.popleft()
     # We are interested in this op.
-    if reached_ops[op._id]:
+    if op in reached_ops:
       inputs.append(op)
       # Clear the boolean so we won't add the inputs again.
-      reached_ops[op._id] = False
+      reached_ops.remove(op)
       for inp in op.inputs:
         queue.append(inp.op)
   return inputs


-def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops):
+def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
   """Initialize the pending count for ops between two lists of Operations.

-  'pending_count[op._id]' indicates the number of backprop inputs
+  'pending_count[op]' indicates the number of backprop inputs
   to this operation.

   Args:
-    graph: a Graph.
     to_ops: list of Operations.
     from_ops: list of Operations.
     colocate_gradients_with_ops: Python bool. See docstring of gradients().

   Returns:
-    A tuple containing: (1) the subset of to_ops ids reachable from from_ops
-    by a path of zero or more backpropagatable tensors, (2) a list of integers
-    indexed by operation id, indicating the number of backprop inputs to this
-    operation, and (3) a ControlFlowState object which is not None if the ops
-    between from_ops and to_ops contain control flow loops.
+    A tuple containing: (1) the subset of to_ops reachable from from_ops by a
+    path of zero or more backpropagatable tensors, (2) a mapping from operation
+    to the number of backprop inputs to that op, and (3) a ControlFlowState
+    object which is not None if the ops between from_ops and to_ops contain
+    control flow loops.
   """
   # Mark reachable ops from from_ops.
-  reached_ops = [False] * (graph._last_id + 1)
+  reached_ops = set()
   _MarkReachedOps(from_ops, reached_ops)
-  # reached_ops[X] iff X is reachable from from_ops by a path of zero or more
+  # X in reached_ops iff X is reachable from from_ops by a path of zero or more
   # backpropagatable tensors.

-  reachable_to_ops = set(op._id for op in to_ops if reached_ops[op._id])  # pylint: disable=protected-access
+  reachable_to_ops = set(op for op in to_ops if op in reached_ops)

   # Mark between ops.
-  between_ops = [False] * (graph._last_id + 1)
+  between_ops = set()
   between_op_list = []
   queue = collections.deque()
   queue.extend(to_ops)
   while queue:
     op = queue.popleft()
     # We are interested in this op.
-    if reached_ops[op._id]:
-      between_ops[op._id] = True
+    if op in reached_ops:
+      between_ops.add(op)
       between_op_list.append(op)
       # Clear the boolean so we won't add the inputs again.
-      reached_ops[op._id] = False
+      reached_ops.remove(op)
       for inp in op.inputs:
         queue.append(inp.op)
-  # between_ops[X] iff X is on a path of zero or more backpropagatable tensors
+  # X in between_ops iff X is on a path of zero or more backpropagatable tensors
   # between from_ops and to_ops

   # 'loop_state' is None if there are no while loops.
@@ -201,11 +200,11 @@ def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops):
       between_op_list, between_ops, colocate_gradients_with_ops)

   # Initialize pending count for between ops.
-  pending_count = [0] * (graph._last_id + 1)
+  pending_count = collections.defaultdict(int)
   for op in between_op_list:
     for x in op.inputs:
-      if between_ops[x.op._id]:
-        pending_count[x.op._id] += 1
+      if x.op in between_ops:
+        pending_count[x.op] += 1

   return reachable_to_ops, pending_count, loop_state
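
The old code sized dense lists by graph._last_id, which only works when all ops live in one graph; a defaultdict keyed by the (hashable) Operation objects has no such restriction. A standalone toy sketch of the pattern (hypothetical data, with strings standing in for Operations):

    import collections

    # Toy dependency graph: op -> list of its input ops.
    inputs = {"y": ["x"], "z": ["x", "y"], "x": []}
    between_ops = {"x", "y", "z"}

    # pending_count[op] = number of consumers of 'op' that backprop must
    # process before 'op' itself is ready (the invariant _PendingCount sets up).
    pending_count = collections.defaultdict(int)
    for op in between_ops:
        for inp in inputs[op]:
            if inp in between_ops:
                pending_count[inp] += 1

    assert pending_count["x"] == 2 and pending_count["y"] == 1
    # Unlike a list indexed by op id, missing keys read as 0 automatically:
    assert pending_count["w"] == 0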
@@ -331,15 +330,15 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count):
   should stop. Operations in the returned set will not be differentiated.
   This set is defined as the subset of `from_ops` containing ops that have
   no predecessor in `from_ops`. `pending_count` is the result of
-  `_PendingCount(g, xs, from_ops)`. An 'op' has predecessors in `from_ops`
-  iff pending_count[op._id] > 0.
+  `_PendingCount(xs, from_ops)`. An 'op' has predecessors in `from_ops`
+  iff pending_count[op] > 0.

   In addition, none of `stop_gradient_ops` will be differentiated.

   Args:
     from_ops: list of Operations.
     stop_gradient_ops: list of Operations never to backprop through.
-    pending_count: List of integers, indexed by operation id.
+    pending_count: mapping from operation to number of backprop inputs.

   Returns:
     The set of operations.
@@ -348,12 +347,12 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count):
   for op in from_ops:
     is_stop_op = True
     for inp in op.inputs:
-      if pending_count[inp.op._id] > 0:
+      if pending_count[inp.op] > 0:
         is_stop_op = False
         break
     if is_stop_op:
-      stop_ops.add(op._id)
-  stop_ops.update(op._id for op in stop_gradient_ops)  # pylint: disable=protected-access
+      stop_ops.add(op)
+  stop_ops.update(op for op in stop_gradient_ops)
   return stop_ops
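
A toy illustration (hypothetical data, not TF code) of the rule the hunk above implements: an op in from_ops is a stop op iff none of its inputs has pending_count > 0.

    inputs = {"a": [], "b": ["a"], "c": ["b"]}  # op -> its input ops
    pending_count = {"a": 1, "b": 1}            # as _PendingCount might fill it

    from_ops = ["a", "c"]
    stop_ops = set(op for op in from_ops
                   if not any(pending_count.get(inp, 0) > 0
                              for inp in inputs[op]))
    assert stop_ops == {"a"}  # "c" is skipped: its input "b" is still pending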
@@ -375,9 +374,7 @@ def _SymGrad(op, out_grads):
   f.name = op.type
   for k in op.node_def.attr:
     f.attr[k].CopyFrom(op.node_def.attr[k])
-  # pylint: disable=protected-access
   in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f)
-  # pylint: enable=protected-access
   return in_grads
@@ -535,13 +532,23 @@ def gradients(ys,
                           gate_gradients, aggregation_method, stop_gradients)


-def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
-                     gate_gradients, aggregation_method, stop_gradients):
+def _GradientsHelper(ys,
+                     xs,
+                     grad_ys=None,
+                     name="gradients",
+                     colocate_gradients_with_ops=False,
+                     gate_gradients=False,
+                     aggregation_method=None,
+                     stop_gradients=None,
+                     src_graph=None):
   """Implementation of gradients()."""
   if context.executing_eagerly():
     raise RuntimeError("tf.gradients not supported when eager execution "
                        "is enabled. Use tf.contrib.eager.GradientTape "
                        "instead.")
+  if src_graph is None:
+    src_graph = ops.get_default_graph()
+
   ys = _AsList(ys)
   xs = _AsList(xs)
   stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
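
Worth noting (my reading of the hunk above, not text from the commit): the parameters also gain defaults, so callers no longer have to pass every argument positionally.

    # Before: every argument was required, e.g.
    #   _GradientsHelper(ys, xs, None, "gradients", False, False, None, None)
    # After: the short form suffices,
    #   _GradientsHelper(ys, xs)
    # and a gradient-function builder can opt in to the new behavior with
    #   _GradientsHelper(ys, xs, src_graph=forward_graph)
    # where forward_graph (hypothetical name) holds the forward ops.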
@@ -581,7 +588,7 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
     from_ops = [t.op for t in xs]
     stop_gradient_ops = [t.op for t in stop_gradients]
     reachable_to_ops, pending_count, loop_state = _PendingCount(
-        ops.get_default_graph(), to_ops, from_ops, colocate_gradients_with_ops)
+        to_ops, from_ops, colocate_gradients_with_ops)

     # Iterate over the collected ops.
     #
@@ -603,12 +610,10 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
     for op in to_ops:
       # 'ready' handles the case where one output gradient relies on
       # another output's gradient.
-      # pylint: disable=protected-access
-      ready = (pending_count[op._id] == 0)
-      if ready and op._id not in to_ops_set and op._id in reachable_to_ops:
-        to_ops_set.add(op._id)
+      ready = (pending_count[op] == 0)
+      if ready and op not in to_ops_set and op in reachable_to_ops:
+        to_ops_set.add(op)
         queue.append(op)
-      # pylint: enable=protected-access

     if loop_state:
       loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
@@ -632,12 +637,12 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
        grad_fn = None
        func_call = None
        # pylint: disable=protected-access
-       is_func_call = ops.get_default_graph()._is_function(op.type)
+       is_func_call = src_graph._is_function(op.type)
        # pylint: enable=protected-access
        has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
-       if has_out_grads and (op._id not in stop_ops):
+       if has_out_grads and (op not in stop_ops):
          if is_func_call:
-           func_call = ops.get_default_graph()._get_function(op.type)
+           func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
            # Note that __defun is not set if the graph is
            # imported. If it's set, we prefer to access the original
            # defun.
@@ -687,7 +692,7 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
                out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
          with ops.name_scope(op.name + "_grad"):
            # pylint: disable=protected-access
-           with ops.get_default_graph()._original_op(op):
+           with src_graph._original_op(op):
              # pylint: enable=protected-access
              if grad_fn:
                # If grad_fn was found, do not use SymbolicGradient even for
@@ -754,13 +759,10 @@ def _HasAnyNotNoneGrads(grads, op):

 def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
   """Update pending count for the inputs of op and enqueue ready ops."""
   for x in op.inputs:
-    # pylint: disable=protected-access
-    pending_count[x.op._id] -= 1
-    ready = (pending_count[x.op._id] == 0)
+    pending_count[x.op] -= 1
+    ready = (pending_count[x.op] == 0)
     if loop_state and not ready:
-      ready = (
-          pending_count[x.op._id] > 0 and control_flow_util.IsLoopSwitch(x.op))
-    # pylint: enable=protected-access
+      ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op)
     if ready:
       if control_flow_util.IsLoopExit(x.op):
         # if x is an exit without real gradient, defer processing them.

tensorflow/python/ops/gradients_test.py

@@ -57,11 +57,10 @@ from tensorflow.python.ops.nn_ops import bias_add
 from tensorflow.python.platform import googletest


-def _OpsBetween(graph, to_ops, from_ops):
+def _OpsBetween(to_ops, from_ops):
   """Build the list of operations between two lists of Operations.

   Args:
-    graph: a Graph.
     to_ops: list of Operations.
     from_ops: list of Operations.
@@ -72,13 +71,12 @@ def _OpsBetween(graph, to_ops, from_ops):
   TODO(touts): Think about returning an empty list if from_ops are not
   reachable from to_ops. Presently it returns to_ops in that case.
   """
-  # List of booleans, indexed by operation id, indicating if
-  # an op is reached from the output of "input_ops".
-  reached_ops = [False] * (graph._last_id + 1)
+  # Ops that are reachable from the output of "input_ops".
+  reached_ops = set()
   # We only care to reach up to "output_ops" so we mark the
   # output ops as reached to avoid recursing past them.
   for op in to_ops:
-    reached_ops[op._id] = True
+    reached_ops.add(op)
   gradients_impl._MarkReachedOps(from_ops, reached_ops)
   between_ops = gradients_impl._GatherInputs(to_ops, reached_ops)
   between_ops.sort(key=lambda x: -x._id)
@@ -95,18 +93,18 @@ class GradientsTest(test_util.TensorFlowTestCase):
     self.assertEquals(self._OpNames(ops1), self._OpNames(ops2))

   def testOpsBetweenSimple(self):
-    with ops.Graph().as_default() as g:
+    with ops.Graph().as_default():
       t1 = constant(1.0)
       t2 = constant(2.0)
       t3 = array_ops.stack([t1, t2])
       # Full graph
       self._assertOpListEqual([t3.op, t2.op, t1.op],
-                              _OpsBetween(g, [t3.op], [t1.op, t2.op]))
+                              _OpsBetween([t3.op], [t1.op, t2.op]))
       # Only t1, t3.
-      self._assertOpListEqual([t3.op, t1.op], _OpsBetween(g, [t3.op], [t1.op]))
+      self._assertOpListEqual([t3.op, t1.op], _OpsBetween([t3.op], [t1.op]))

   def testOpsBetweenUnreachable(self):
-    with ops.Graph().as_default() as g:
+    with ops.Graph().as_default():
       t1 = constant(1.0)
       t2 = constant(2.0)
       _ = array_ops.stack([t1, t2])
@@ -114,10 +112,10 @@ class GradientsTest(test_util.TensorFlowTestCase):
       t5 = constant(2.0)
       t6 = array_ops.stack([t4, t5])
       # Elements of to_ops are always listed.
-      self._assertOpListEqual([t6.op], _OpsBetween(g, [t6.op], [t1.op]))
+      self._assertOpListEqual([t6.op], _OpsBetween([t6.op], [t1.op]))

   def testOpsBetweenCut(self):
-    with ops.Graph().as_default() as g:
+    with ops.Graph().as_default():
       t1 = constant(1.0)
       t2 = constant(2.0)
       t3 = array_ops.stack([t1, t2])
@@ -126,10 +124,10 @@ class GradientsTest(test_util.TensorFlowTestCase):
       t6 = constant([2.0])
       t7 = array_ops.concat([t5, t6], 0)
       self._assertOpListEqual([t7.op, t5.op, t4.op],
-                              _OpsBetween(g, [t7.op], [t4.op]))
+                              _OpsBetween([t7.op], [t4.op]))

   def testOpsBetweenCycle(self):
-    with ops.Graph().as_default() as g:
+    with ops.Graph().as_default():
       t1 = constant(1.0)
       t2 = constant(2.0)
       t3 = array_ops.stack([t1, t2])
@@ -138,11 +136,11 @@ class GradientsTest(test_util.TensorFlowTestCase):
       t6 = array_ops.concat([t4, t5], 0)
       t7 = array_ops.concat([t6, t3], 0)
       self._assertOpListEqual([t6.op, t4.op, t3.op],
-                              _OpsBetween(g, [t6.op], [t3.op]))
+                              _OpsBetween([t6.op], [t3.op]))
       self._assertOpListEqual([t7.op, t6.op, t5.op, t4.op, t3.op, t1.op],
-                              _OpsBetween(g, [t7.op], [t1.op, t5.op]))
+                              _OpsBetween([t7.op], [t1.op, t5.op]))
       self._assertOpListEqual([t6.op, t5.op, t4.op, t3.op, t2.op],
-                              _OpsBetween(g, [t6.op], [t2.op, t5.op]))
+                              _OpsBetween([t6.op], [t2.op, t5.op]))

   def testGradients(self):
     with ops.Graph().as_default():