Set up the generated gradient while loops properly so that a None gradient
is propagated through them when tf.stop_gradient is used.
Change: 144108792
This commit is contained in:
parent
6a661a027b
commit
28fd303882
tensorflow/python
@ -2069,6 +2069,22 @@ class ControlFlowTest(test.TestCase):
|
||||
grad_theta_stopped = array_ops.stop_gradient(grad_theta)
|
||||
gradients_impl.gradients(grad_theta_stopped, theta)
|
||||
|
||||
def testStopGradOnWhileGrad(self):
|
||||
with self.test_session():
|
||||
x = constant_op.constant(2.0, name="x")
|
||||
y = constant_op.constant(2.0, name="y")
|
||||
|
||||
c = lambda x: math_ops.less(x, 100.0)
|
||||
b = lambda x: math_ops.mul(x, y)
|
||||
rx = control_flow_ops.while_loop(c, b, [x])
|
||||
|
||||
rg = gradients_impl.gradients(rx, y)[0]
|
||||
rg = array_ops.stop_gradient(rg)
|
||||
r = math_ops.add(math_ops.square(y), rx)
|
||||
r = math_ops.add(r, rg)
|
||||
r = gradients_impl.gradients(r, y)[0]
|
||||
self.assertEqual(388.0, r.eval())
|
||||
|
||||
def testOneValueCond(self):
|
||||
with self.test_session():
|
||||
c = array_ops.placeholder(dtypes.int32, shape=[])
|
||||
|
@ -203,7 +203,7 @@ def _EnterGrad(op, grad):
|
||||
# Skip gradient computation, if the attribute `back_prop` is false.
|
||||
return grad
|
||||
if grad_ctxt.grad_state is None:
|
||||
# Pass the gradient grough if we are not in a gradient while context.
|
||||
# Pass the gradient through if we are not in a gradient while context.
|
||||
return grad
|
||||
if op.get_attr("is_constant"):
|
||||
# Add a gradient accumulator for each loop invariant.
|
||||
@ -216,6 +216,7 @@ def _EnterGrad(op, grad):
|
||||
raise TypeError("Type %s not supported" % type(grad))
|
||||
else:
|
||||
result = exit(grad)
|
||||
grad_ctxt.loop_exits.append(result)
|
||||
grad_ctxt.ExitResult([result])
|
||||
return result
|
||||
|
||||
|
@ -703,6 +703,7 @@ class GradLoopState(object):
|
||||
self._switch_map = {}
|
||||
self._unused_exits = []
|
||||
self._deferred_exits = []
|
||||
self._forward_loop_exits = list(forward_ctxt.loop_exits)
|
||||
self._pending_exits_count = len(forward_ctxt.loop_exits)
|
||||
|
||||
self._outer_grad_state = outer_grad_state
|
||||
@ -820,6 +821,11 @@ class GradLoopState(object):
|
||||
"""The list of "deferred" exits."""
|
||||
return self._deferred_exits
|
||||
|
||||
@property
|
||||
def forward_loop_exits(self):
|
||||
"""The list of exits of the forward loop."""
|
||||
return self._forward_loop_exits
|
||||
|
||||
@property
|
||||
def pending_exits_count(self):
|
||||
"""The number of exits we expect to see but haven't."""
|
||||
@ -1059,8 +1065,8 @@ class ControlFlowState(object):
|
||||
to backprop.
|
||||
"""
|
||||
loop_exits = []
|
||||
for forward_ctxt, grad_state in self._map.items():
|
||||
for y in forward_ctxt.loop_exits:
|
||||
for _, grad_state in self._map.items():
|
||||
for y in grad_state.forward_loop_exits:
|
||||
# pylint: disable=protected-access
|
||||
if pending_count[y.op._id] == 0:
|
||||
grad_state.pending_exits_count -= 1
|
||||
@ -1105,7 +1111,7 @@ class ControlFlowState(object):
|
||||
self._map[forward_ctxt] = grad_state
|
||||
|
||||
# We need to include all exits of a loop for backprop.
|
||||
for loop_exit in forward_ctxt.loop_exits:
|
||||
for loop_exit in grad_state.forward_loop_exits:
|
||||
if not between_ops[loop_exit.op._id]:
|
||||
between_ops[loop_exit.op._id] = True
|
||||
between_op_list.append(loop_exit.op)
|
||||
@ -2119,6 +2125,7 @@ class WhileContext(ControlFlowContext):
|
||||
merge_n.op._update_input(1, next_n)
|
||||
|
||||
total_iterations = exit(switch_n[0], name="f_count")
|
||||
self.loop_exits.append(total_iterations)
|
||||
self.ExitResult([total_iterations])
|
||||
self.Exit()
|
||||
return total_iterations, next_n
|
||||
@ -2163,6 +2170,7 @@ class WhileContext(ControlFlowContext):
|
||||
merge_count.op._update_input(1, next_count)
|
||||
|
||||
final_zero = exit(switch_count[0], name="b_count")
|
||||
self.loop_exits.append(final_zero)
|
||||
if outer_grad_state is not None:
|
||||
# Force the stack pops of i-th execution of an inner loop to be ordered
|
||||
# before the pops of (i+1)-th execution of the same inner loop.
|
||||
@ -2244,6 +2252,7 @@ class WhileContext(ControlFlowContext):
|
||||
merge_acc.op._update_input(1, next_acc) # pylint: disable=protected-access
|
||||
|
||||
acc_result = exit(switch_acc_false, name="b_acc")
|
||||
self.loop_exits.append(acc_result)
|
||||
self.ExitResult([acc_result])
|
||||
return acc_result
|
||||
|
||||
@ -2320,6 +2329,7 @@ class WhileContext(ControlFlowContext):
|
||||
xm.op._update_input(1, xn) # pylint: disable=protected-access
|
||||
|
||||
acc_exits = [exit(x[0], name="b_acc") for x in switch_acc]
|
||||
self.loop_exits.extend(acc_exits)
|
||||
|
||||
self.ExitResult(acc_exits)
|
||||
return ops.IndexedSlices(
|
||||
|
Loading…
Reference in New Issue
Block a user