Fix bug in tf.gradients.

This fixes a bug where we'd incorrectly match up gradient tensors with the
corresponding forward-pass tensors, because captured EagerTensors were being
filtered out of an op's input list. This could lead to errors about
incompatible gradient shapes or, even worse, the wrong gradient value.

PiperOrigin-RevId: 235271326
Author: Skye Wanderman-Milne, 2019-02-22 15:35:37 -08:00
Committed by: TensorFlower Gardener
Commit: 135fee1685 (parent: 8be0d24cc9)
3 changed files with 41 additions and 10 deletions
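
To make the failure mode concrete, here is a small framework-free sketch (hypothetical names, not TensorFlow internals) of how dropping a captured eager input from an op's input list before pairing it with the per-input gradients shifts every later pairing:

    # An op with three inputs, the middle one a captured EagerTensor, and the
    # per-input gradients computed for it, in the same order.
    inputs = ["x", "captured_eager_c", "y"]
    in_grads = ["grad_x", "grad_c", "grad_y"]

    # Old behavior: the captured eager input was filtered out of the input list
    # before zipping, so "y" gets paired with "grad_c", a gradient that belongs
    # to a different forward-pass tensor (hence shape errors or wrong values).
    buggy = list(zip([t for t in inputs if t != "captured_eager_c"], in_grads))
    assert buggy == [("x", "grad_x"), ("y", "grad_c")]

    # Fixed behavior: keep the two lists aligned and skip the captured eager
    # input only when recording its gradient.
    fixed = [(t, g) for t, g in zip(inputs, in_grads) if t != "captured_eager_c"]
    assert fixed == [("x", "grad_x"), ("y", "grad_y")]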


@@ -2595,6 +2595,22 @@ class ControlFlowTest(test.TestCase):
      self.evaluate(variables.global_variables_initializer())
      self.assertAllClose(216.0, g[0])

  def testWhileGrad_EagerResourceVariable(self):
    with context.eager_mode():
      a = resource_variable_ops.ResourceVariable(
          np.ones([2, 2], dtype=np.float32))
      v = constant_op.constant(1.0)

      @eager_function.defun
      def fn():
        r = control_flow_ops.while_loop(
            lambda i, _: i < 2,
            lambda i, x: (i + 1, x * math_ops.reduce_sum(a) * v),
            [0, 1.0])[1]
        return gradients_impl.gradients(r, [v])[0]

      self.assertEqual(self.evaluate(fn()), 32.)

  def testWhileGrad_ResourceVarInFunctionCall(self):

    @def_function.function
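
As a sanity check on the expected value in testWhileGrad_EagerResourceVariable: reduce_sum(a) is 4.0 for a = ones([2, 2]), and the loop body runs twice starting from x = 1.0, so r(v) = 1.0 * (4*v) * (4*v) = 16*v**2 and dr/dv = 32*v, which is 32.0 at v = 1.0, matching the assertion above.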


@@ -587,11 +587,12 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
    def Foo():
      x = constant_op.constant(10.0, name="x")
      y = math_ops.multiply(x, c, name="y")
      z = math_ops.multiply(y, 3.0, name="z")
      # Regression test for b/122564611.
      z = math_ops.multiply(c, y, name="z")
      g = gradients_impl.gradients(z, x)
      return g[0]

    self.assertEqual(Foo().numpy(), 6.0)
    self.assertEqual(Foo().numpy(), 4.0)


class StopGradientTest(test_util.TensorFlowTestCase):
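
The expected gradient in the hunk above changes from 6.0 to 4.0 because z itself changes. Assuming c is a constant captured from the enclosing test body with value 2.0 (it is defined outside the hunk shown; 2.0 is the value implied by both assertions), the old line computed z = y * 3 = 3 * x * c, giving dz/dx = 3 * c = 6.0, while the regression test computes z = c * y = c * (x * c) = x * c**2, giving dz/dx = c**2 = 4.0. The captured c is exactly the kind of eager input whose filtering previously misaligned the gradient pairing (b/122564611).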


@@ -478,14 +478,28 @@ def _MaybeCaptured(t):
  return t


# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
# _GradientsHelper a class with xs as a member variable.
def _NonEagerInputs(op, xs):
  """Returns the inputs of op, crossing closure boundaries where necessary.

  Does not return any captured EagerTensors, i.e., the number of tensors
  returned may be less than the actual number of inputs.

  Args:
    op: Operation
    xs: list of Tensors we are differentiating w.r.t.

  Returns:
    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
    is in a FuncGraph and has captured inputs.
  """
  return [t for t in _Inputs(op, xs) if not isinstance(t, ops.EagerTensor)]


# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
# _GradientsHelper a class with xs as a member variable.
def _Inputs(op, xs):
  """Returns the inputs of op, crossing closure boundaries where necessary.

  Args:
    op: Operation
    xs: list of Tensors we are differentiating w.r.t.
@@ -504,8 +518,6 @@ def _NonEagerInputs(op, xs):
      # direct input to op.
      if t not in xs:
        t = _MaybeCaptured(t)
      # Skip captured eager inputs.
      if isinstance(t, ops.EagerTensor): continue
      inputs.append(t)
    return inputs
  else:
@@ -736,9 +748,10 @@ def _GradientsHelper(ys,
          else:
            # If no grad_fn is defined or none of out_grads is available,
            # just propagate a list of None backwards.
            in_grads = [None] * len(_NonEagerInputs(op, xs))
          for i, (t_in, in_grad) in enumerate(zip(_NonEagerInputs(op, xs),
                                                  in_grads)):
            in_grads = [None] * len(_Inputs(op, xs))
          # Note: we don't filter out eager inputs here because the inputs need to
          # line up with in_grads.
          for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
            if in_grad is not None:
              if (isinstance(in_grad, ops.Tensor) and
                  t_in.dtype != dtypes.resource):
@@ -751,6 +764,7 @@
                    "Original input shape: %s. "
                    "Calculated input gradient shape: %s" %
                    (op.name, i, t_in.shape, in_grad.shape))
              if not isinstance(t_in, ops.EagerTensor):
                _SetGrad(grads, t_in, in_grad)
          if loop_state:
            loop_state.ExitGradWhileContext(op, before=False)
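
Taken together, the changes to the gradient machinery in the third file keep the op's full input list aligned with in_grads and only skip captured EagerTensors at the point where a gradient is recorded. A rough, framework-free sketch of the resulting control flow (record_grad and is_eager are hypothetical stand-ins for _SetGrad and the isinstance check, not real APIs):

    def accumulate_input_grads(inputs, in_grads, is_eager, record_grad):
      # The two lists must line up one-to-one, which is why eager inputs are
      # no longer filtered out before this point.
      assert len(inputs) == len(in_grads)
      for t_in, in_grad in zip(inputs, in_grads):
        if in_grad is None:
          continue
        # Skip captured eager inputs only here, after the pairing is settled.
        if not is_eager(t_in):
          record_grad(t_in, in_grad)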