Add 'src_graph' argument to gradients_impl._GradientsHelper.
This allows the gradient graph to be built in a _FuncGraph separate from the forward graph (a _FuncGraph is necessary to capture needed tensors from the forward graph. It's up to the capturing logic what how to feed the forward tensors to the gradient graph). PiperOrigin-RevId: 197230736
This commit is contained in:
parent
f3ae367618
commit
92cdc99c98
@ -1192,20 +1192,18 @@ class ControlFlowState(object):
|
||||
to backprop.
|
||||
"""
|
||||
loop_exits = []
|
||||
for _, grad_state in self._map.items():
|
||||
# pylint: disable=protected-access
|
||||
for grad_state in self._map.values():
|
||||
for y in grad_state.forward_loop_exits:
|
||||
if pending_count[y.op._id] == 0:
|
||||
if pending_count[y.op] == 0:
|
||||
grad_state.pending_exits_count -= 1
|
||||
if y.op._id not in to_ops_set:
|
||||
if y.op not in to_ops_set:
|
||||
grad_state.unused_exits.append(y)
|
||||
if grad_state.pending_exits_count == 0:
|
||||
loop_exits.extend(grad_state.unused_exits)
|
||||
# Need to include Enters in backprop for higher-order gradients.
|
||||
for y in grad_state.forward_context.loop_enters:
|
||||
if pending_count[y.op._id] == 0:
|
||||
pending_count[y.op._id] = 1
|
||||
# pylint: enable=protected-access
|
||||
if pending_count[y.op] == 0:
|
||||
pending_count[y.op] = 1
|
||||
return loop_exits
|
||||
|
||||
def EnterGradWhileContext(self, op, before):
|
||||
@ -1243,8 +1241,8 @@ class ControlFlowState(object):
|
||||
|
||||
# We need to include all exits of a loop for backprop.
|
||||
for loop_exit in grad_state.forward_loop_exits:
|
||||
if not between_ops[loop_exit.op._id]:
|
||||
between_ops[loop_exit.op._id] = True
|
||||
if loop_exit.op not in between_ops:
|
||||
between_ops.add(loop_exit.op)
|
||||
between_op_list.append(loop_exit.op)
|
||||
|
||||
def ZerosLikeForExit(self, val):
|
||||
|
@ -112,14 +112,14 @@ def _MarkReachedOps(from_ops, reached_ops):
|
||||
|
||||
Args:
|
||||
from_ops: list of Operations.
|
||||
reached_ops: list of booleans, indexed by operation id.
|
||||
reached_ops: set of Operations.
|
||||
"""
|
||||
queue = collections.deque()
|
||||
queue.extend(from_ops)
|
||||
while queue:
|
||||
op = queue.popleft()
|
||||
if not reached_ops[op._id]:
|
||||
reached_ops[op._id] = True
|
||||
if op not in reached_ops:
|
||||
reached_ops.add(op)
|
||||
for output in op.outputs:
|
||||
if _IsBackpropagatable(output):
|
||||
queue.extend(output.consumers())
|
||||
@ -130,7 +130,7 @@ def _GatherInputs(to_ops, reached_ops):
|
||||
|
||||
Args:
|
||||
to_ops: list of Operations.
|
||||
reached_ops: list of booleans, indexed by operation id.
|
||||
reached_ops: set of Operations.
|
||||
|
||||
Returns:
|
||||
The list of all inputs of to_ops that are in reached_ops.
|
||||
@ -142,58 +142,57 @@ def _GatherInputs(to_ops, reached_ops):
|
||||
while queue:
|
||||
op = queue.popleft()
|
||||
# We are interested in this op.
|
||||
if reached_ops[op._id]:
|
||||
if op in reached_ops:
|
||||
inputs.append(op)
|
||||
# Clear the boolean so we won't add the inputs again.
|
||||
reached_ops[op._id] = False
|
||||
reached_ops.remove(op)
|
||||
for inp in op.inputs:
|
||||
queue.append(inp.op)
|
||||
return inputs
|
||||
|
||||
|
||||
def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops):
|
||||
def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
|
||||
"""Initialize the pending count for ops between two lists of Operations.
|
||||
|
||||
'pending_count[op._id]' indicates the number of backprop inputs
|
||||
'pending_count[op]' indicates the number of backprop inputs
|
||||
to this operation.
|
||||
|
||||
Args:
|
||||
graph: a Graph.
|
||||
to_ops: list of Operations.
|
||||
from_ops: list of Operations.
|
||||
colocate_gradients_with_ops: Python bool. See docstring of gradients().
|
||||
|
||||
Returns:
|
||||
A tuple containing: (1) the subset of to_ops ids reachable from from_ops
|
||||
by a path of zero or more backpropagatable tensors, (2) a list of integers
|
||||
indexed by operation id, indicating the number of backprop inputs to this
|
||||
operation, and (3) a ControlFlowState object which is not None if the ops
|
||||
between from_ops and to_ops contain control flow loops.
|
||||
A tuple containing: (1) the subset of to_ops reachable from from_ops by a
|
||||
path of zero or more backpropagatable tensors, (2) a mapping from operation
|
||||
to the number of backprop inputs to that op, and (3) a ControlFlowState
|
||||
object which is not None if the ops between from_ops and to_ops contain
|
||||
control flow loops.
|
||||
"""
|
||||
# Mark reachable ops from from_ops.
|
||||
reached_ops = [False] * (graph._last_id + 1)
|
||||
reached_ops = set()
|
||||
_MarkReachedOps(from_ops, reached_ops)
|
||||
# reached_ops[X] iff X is reachable from from_ops by a path of zero or more
|
||||
# X in reached_ops iff X is reachable from from_ops by a path of zero or more
|
||||
# backpropagatable tensors.
|
||||
|
||||
reachable_to_ops = set(op._id for op in to_ops if reached_ops[op._id]) # pylint: disable=protected-access
|
||||
reachable_to_ops = set(op for op in to_ops if op in reached_ops)
|
||||
|
||||
# Mark between ops.
|
||||
between_ops = [False] * (graph._last_id + 1)
|
||||
between_ops = set()
|
||||
between_op_list = []
|
||||
queue = collections.deque()
|
||||
queue.extend(to_ops)
|
||||
while queue:
|
||||
op = queue.popleft()
|
||||
# We are interested in this op.
|
||||
if reached_ops[op._id]:
|
||||
between_ops[op._id] = True
|
||||
if op in reached_ops:
|
||||
between_ops.add(op)
|
||||
between_op_list.append(op)
|
||||
# Clear the boolean so we won't add the inputs again.
|
||||
reached_ops[op._id] = False
|
||||
reached_ops.remove(op)
|
||||
for inp in op.inputs:
|
||||
queue.append(inp.op)
|
||||
# between_ops[X] iff X is on a path of zero or more backpropagatable tensors
|
||||
# X in between_ops iff X is on a path of zero or more backpropagatable tensors
|
||||
# between from_ops and to_ops
|
||||
|
||||
# 'loop_state' is None if there are no while loops.
|
||||
@ -201,11 +200,11 @@ def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops):
|
||||
between_op_list, between_ops, colocate_gradients_with_ops)
|
||||
|
||||
# Initialize pending count for between ops.
|
||||
pending_count = [0] * (graph._last_id + 1)
|
||||
pending_count = collections.defaultdict(int)
|
||||
for op in between_op_list:
|
||||
for x in op.inputs:
|
||||
if between_ops[x.op._id]:
|
||||
pending_count[x.op._id] += 1
|
||||
if x.op in between_ops:
|
||||
pending_count[x.op] += 1
|
||||
|
||||
return reachable_to_ops, pending_count, loop_state
|
||||
|
||||
@ -331,15 +330,15 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count):
|
||||
should stop. Operations in the returned set will not be differentiated.
|
||||
This set is defined as the subset of `from_ops` containing ops that have
|
||||
no predecessor in `from_ops`. `pending_count` is the result of
|
||||
`_PendingCount(g, xs, from_ops)`. An 'op' has predecessors in `from_ops`
|
||||
iff pending_count[op._id] > 0.
|
||||
`_PendingCount(xs, from_ops)`. An 'op' has predecessors in `from_ops`
|
||||
iff pending_count[op] > 0.
|
||||
|
||||
In addition, none of `stop_gradient_ops` will be differentiated.
|
||||
|
||||
Args:
|
||||
from_ops: list of Operations.
|
||||
stop_gradient_ops: list of Operations never to backprop through.
|
||||
pending_count: List of integers, indexed by operation id.
|
||||
pending_count: mapping from operation to number of backprop inputs.
|
||||
|
||||
Returns:
|
||||
The set of operations.
|
||||
@ -348,12 +347,12 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count):
|
||||
for op in from_ops:
|
||||
is_stop_op = True
|
||||
for inp in op.inputs:
|
||||
if pending_count[inp.op._id] > 0:
|
||||
if pending_count[inp.op] > 0:
|
||||
is_stop_op = False
|
||||
break
|
||||
if is_stop_op:
|
||||
stop_ops.add(op._id)
|
||||
stop_ops.update(op._id for op in stop_gradient_ops) # pylint: disable=protected-access
|
||||
stop_ops.add(op)
|
||||
stop_ops.update(op for op in stop_gradient_ops)
|
||||
return stop_ops
|
||||
|
||||
|
||||
@ -375,9 +374,7 @@ def _SymGrad(op, out_grads):
|
||||
f.name = op.type
|
||||
for k in op.node_def.attr:
|
||||
f.attr[k].CopyFrom(op.node_def.attr[k])
|
||||
# pylint: disable=protected-access
|
||||
in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f)
|
||||
# pylint: enable=protected-access
|
||||
return in_grads
|
||||
|
||||
|
||||
@ -535,13 +532,23 @@ def gradients(ys,
|
||||
gate_gradients, aggregation_method, stop_gradients)
|
||||
|
||||
|
||||
def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
|
||||
gate_gradients, aggregation_method, stop_gradients):
|
||||
def _GradientsHelper(ys,
|
||||
xs,
|
||||
grad_ys=None,
|
||||
name="gradients",
|
||||
colocate_gradients_with_ops=False,
|
||||
gate_gradients=False,
|
||||
aggregation_method=None,
|
||||
stop_gradients=None,
|
||||
src_graph=None):
|
||||
"""Implementation of gradients()."""
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError("tf.gradients not supported when eager execution "
|
||||
"is enabled. Use tf.contrib.eager.GradientTape "
|
||||
"instead.")
|
||||
if src_graph is None:
|
||||
src_graph = ops.get_default_graph()
|
||||
|
||||
ys = _AsList(ys)
|
||||
xs = _AsList(xs)
|
||||
stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
|
||||
@ -581,7 +588,7 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
|
||||
from_ops = [t.op for t in xs]
|
||||
stop_gradient_ops = [t.op for t in stop_gradients]
|
||||
reachable_to_ops, pending_count, loop_state = _PendingCount(
|
||||
ops.get_default_graph(), to_ops, from_ops, colocate_gradients_with_ops)
|
||||
to_ops, from_ops, colocate_gradients_with_ops)
|
||||
|
||||
# Iterate over the collected ops.
|
||||
#
|
||||
@ -603,12 +610,10 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
|
||||
for op in to_ops:
|
||||
# 'ready' handles the case where one output gradient relies on
|
||||
# another output's gradient.
|
||||
# pylint: disable=protected-access
|
||||
ready = (pending_count[op._id] == 0)
|
||||
if ready and op._id not in to_ops_set and op._id in reachable_to_ops:
|
||||
to_ops_set.add(op._id)
|
||||
ready = (pending_count[op] == 0)
|
||||
if ready and op not in to_ops_set and op in reachable_to_ops:
|
||||
to_ops_set.add(op)
|
||||
queue.append(op)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
if loop_state:
|
||||
loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
|
||||
@ -632,12 +637,12 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
|
||||
grad_fn = None
|
||||
func_call = None
|
||||
# pylint: disable=protected-access
|
||||
is_func_call = ops.get_default_graph()._is_function(op.type)
|
||||
is_func_call = src_graph._is_function(op.type)
|
||||
# pylint: enable=protected-access
|
||||
has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
|
||||
if has_out_grads and (op._id not in stop_ops):
|
||||
if has_out_grads and (op not in stop_ops):
|
||||
if is_func_call:
|
||||
func_call = ops.get_default_graph()._get_function(op.type)
|
||||
func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
|
||||
# Note that __defun is not set if the graph is
|
||||
# imported. If it's set, we prefer to access the original
|
||||
# defun.
|
||||
@ -687,7 +692,7 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
|
||||
out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
|
||||
with ops.name_scope(op.name + "_grad"):
|
||||
# pylint: disable=protected-access
|
||||
with ops.get_default_graph()._original_op(op):
|
||||
with src_graph._original_op(op):
|
||||
# pylint: enable=protected-access
|
||||
if grad_fn:
|
||||
# If grad_fn was found, do not use SymbolicGradient even for
|
||||
@ -754,13 +759,10 @@ def _HasAnyNotNoneGrads(grads, op):
|
||||
def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
|
||||
"""Update pending count for the inputs of op and enqueue ready ops."""
|
||||
for x in op.inputs:
|
||||
# pylint: disable=protected-access
|
||||
pending_count[x.op._id] -= 1
|
||||
ready = (pending_count[x.op._id] == 0)
|
||||
pending_count[x.op] -= 1
|
||||
ready = (pending_count[x.op] == 0)
|
||||
if loop_state and not ready:
|
||||
ready = (
|
||||
pending_count[x.op._id] > 0 and control_flow_util.IsLoopSwitch(x.op))
|
||||
# pylint: enable=protected-access
|
||||
ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op)
|
||||
if ready:
|
||||
if control_flow_util.IsLoopExit(x.op):
|
||||
# if x is an exit without real gradient, defer processing them.
|
||||
|
@ -57,11 +57,10 @@ from tensorflow.python.ops.nn_ops import bias_add
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
def _OpsBetween(graph, to_ops, from_ops):
|
||||
def _OpsBetween(to_ops, from_ops):
|
||||
"""Build the list of operations between two lists of Operations.
|
||||
|
||||
Args:
|
||||
graph: a Graph.
|
||||
to_ops: list of Operations.
|
||||
from_ops: list of Operations.
|
||||
|
||||
@ -72,13 +71,12 @@ def _OpsBetween(graph, to_ops, from_ops):
|
||||
TODO(touts): Think about returning an empty list if from_ops are not
|
||||
reachable from to_ops. Presently it returns to_ops in that case.
|
||||
"""
|
||||
# List of booleans, indexed by operation id, indicating if
|
||||
# an op is reached from the output of "input_ops".
|
||||
reached_ops = [False] * (graph._last_id + 1)
|
||||
# Ops that are reachable from the output of "input_ops".
|
||||
reached_ops = set()
|
||||
# We only care to reach up to "output_ops" so we mark the
|
||||
# output ops as reached to avoid recursing past them.
|
||||
for op in to_ops:
|
||||
reached_ops[op._id] = True
|
||||
reached_ops.add(op)
|
||||
gradients_impl._MarkReachedOps(from_ops, reached_ops)
|
||||
between_ops = gradients_impl._GatherInputs(to_ops, reached_ops)
|
||||
between_ops.sort(key=lambda x: -x._id)
|
||||
@ -95,18 +93,18 @@ class GradientsTest(test_util.TensorFlowTestCase):
|
||||
self.assertEquals(self._OpNames(ops1), self._OpNames(ops2))
|
||||
|
||||
def testOpsBetweenSimple(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
with ops.Graph().as_default():
|
||||
t1 = constant(1.0)
|
||||
t2 = constant(2.0)
|
||||
t3 = array_ops.stack([t1, t2])
|
||||
# Full graph
|
||||
self._assertOpListEqual([t3.op, t2.op, t1.op],
|
||||
_OpsBetween(g, [t3.op], [t1.op, t2.op]))
|
||||
_OpsBetween([t3.op], [t1.op, t2.op]))
|
||||
# Only t1, t3.
|
||||
self._assertOpListEqual([t3.op, t1.op], _OpsBetween(g, [t3.op], [t1.op]))
|
||||
self._assertOpListEqual([t3.op, t1.op], _OpsBetween([t3.op], [t1.op]))
|
||||
|
||||
def testOpsBetweenUnreachable(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
with ops.Graph().as_default():
|
||||
t1 = constant(1.0)
|
||||
t2 = constant(2.0)
|
||||
_ = array_ops.stack([t1, t2])
|
||||
@ -114,10 +112,10 @@ class GradientsTest(test_util.TensorFlowTestCase):
|
||||
t5 = constant(2.0)
|
||||
t6 = array_ops.stack([t4, t5])
|
||||
# Elements of to_ops are always listed.
|
||||
self._assertOpListEqual([t6.op], _OpsBetween(g, [t6.op], [t1.op]))
|
||||
self._assertOpListEqual([t6.op], _OpsBetween([t6.op], [t1.op]))
|
||||
|
||||
def testOpsBetweenCut(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
with ops.Graph().as_default():
|
||||
t1 = constant(1.0)
|
||||
t2 = constant(2.0)
|
||||
t3 = array_ops.stack([t1, t2])
|
||||
@ -126,10 +124,10 @@ class GradientsTest(test_util.TensorFlowTestCase):
|
||||
t6 = constant([2.0])
|
||||
t7 = array_ops.concat([t5, t6], 0)
|
||||
self._assertOpListEqual([t7.op, t5.op, t4.op],
|
||||
_OpsBetween(g, [t7.op], [t4.op]))
|
||||
_OpsBetween([t7.op], [t4.op]))
|
||||
|
||||
def testOpsBetweenCycle(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
with ops.Graph().as_default():
|
||||
t1 = constant(1.0)
|
||||
t2 = constant(2.0)
|
||||
t3 = array_ops.stack([t1, t2])
|
||||
@ -138,11 +136,11 @@ class GradientsTest(test_util.TensorFlowTestCase):
|
||||
t6 = array_ops.concat([t4, t5], 0)
|
||||
t7 = array_ops.concat([t6, t3], 0)
|
||||
self._assertOpListEqual([t6.op, t4.op, t3.op],
|
||||
_OpsBetween(g, [t6.op], [t3.op]))
|
||||
_OpsBetween([t6.op], [t3.op]))
|
||||
self._assertOpListEqual([t7.op, t6.op, t5.op, t4.op, t3.op, t1.op],
|
||||
_OpsBetween(g, [t7.op], [t1.op, t5.op]))
|
||||
_OpsBetween([t7.op], [t1.op, t5.op]))
|
||||
self._assertOpListEqual([t6.op, t5.op, t4.op, t3.op, t2.op],
|
||||
_OpsBetween(g, [t6.op], [t2.op, t5.op]))
|
||||
_OpsBetween([t6.op], [t2.op, t5.op]))
|
||||
|
||||
def testGradients(self):
|
||||
with ops.Graph().as_default():
|
||||
|
Loading…
x
Reference in New Issue
Block a user