Fix bug in tf.gradients.
This fixes a bug where gradient tensors could be incorrectly matched up with the corresponding forward-pass tensors because captured EagerTensors were filtered out of an op's input list. This could lead to errors about incompatible gradient shapes or, even worse, to the wrong gradient value.

PiperOrigin-RevId: 235271326
parent 8be0d24cc9
commit 135fee1685
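To make the failure mode concrete, here is a minimal, standalone Python sketch of the mismatch described in the commit message. It is not TensorFlow code; the input and gradient names are made up for illustration. Filtering captured EagerTensors out of an op's input list before pairing it with the per-input gradient list shifts every later gradient onto the wrong input:

# Hypothetical op with three forward-pass inputs; the second one is a
# captured EagerTensor. The gradient function returns one gradient per
# actual input.
inputs = ["x", "captured_eager_const", "y"]
in_grads = ["dx", "d_captured", "dy"]

# Old pairing: filter first, then zip. "y" gets "d_captured" and "dy" is
# dropped, which surfaces as a shape error or a silently wrong gradient.
filtered = [t for t in inputs if t != "captured_eager_const"]
print(list(zip(filtered, in_grads)))  # [('x', 'dx'), ('y', 'd_captured')]

# Fixed pairing: zip the unfiltered inputs with the gradients so they stay
# aligned, and skip the captured tensor only when recording its gradient.
for t, grad in zip(inputs, in_grads):
    if t == "captured_eager_const":
        continue
    print(t, "<-", grad)  # x <- dx, then y <- dy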
@@ -2595,6 +2595,22 @@ class ControlFlowTest(test.TestCase):
     self.evaluate(variables.global_variables_initializer())
     self.assertAllClose(216.0, g[0])
 
+  def testWhileGrad_EagerResourceVariable(self):
+    with context.eager_mode():
+      a = resource_variable_ops.ResourceVariable(
+          np.ones([2, 2], dtype=np.float32))
+      v = constant_op.constant(1.0)
+
+      @eager_function.defun
+      def fn():
+        r = control_flow_ops.while_loop(
+            lambda i, _: i < 2,
+            lambda i, x: (i + 1, x * math_ops.reduce_sum(a) * v),
+            [0, 1.0])[1]
+        return gradients_impl.gradients(r, [v])[0]
+
+      self.assertEqual(self.evaluate(fn()), 32.)
+
   def testWhileGrad_ResourceVarInFunctionCall(self):
 
     @def_function.function
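For reference, the expected value in the new test follows directly from the loop arithmetic: math_ops.reduce_sum(a) is 4.0 for the 2x2 variable of ones, so each of the two iterations multiplies x by 4.0 * v, giving r = 16 * v**2 and a gradient dr/dv = 32 * v, i.e. 32.0 at v = 1.0.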
@@ -587,11 +587,12 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
     def Foo():
       x = constant_op.constant(10.0, name="x")
       y = math_ops.multiply(x, c, name="y")
-      z = math_ops.multiply(y, 3.0, name="z")
+      # Regression test for b/122564611.
+      z = math_ops.multiply(c, y, name="z")
       g = gradients_impl.gradients(z, x)
       return g[0]
 
-    self.assertEqual(Foo().numpy(), 6.0)
+    self.assertEqual(Foo().numpy(), 4.0)
 
 
 class StopGradientTest(test_util.TensorFlowTestCase):
@@ -478,14 +478,28 @@ def _MaybeCaptured(t):
   return t
 
 
+# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
+# _GradientsHelper a class with xs as a member variable.
+def _NonEagerInputs(op, xs):
+  """Returns the inputs of op, crossing closure boundaries where necessary.
+
+  Does not return any captured EagerTensors, i.e., the number of tensors
+  returned may be less than the actual number of inputs.
+
+  Args:
+    op: Operation
+    xs: list of Tensors we are differentiating w.r.t.
+
+  Returns:
+    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
+    is in a FuncGraph and has captured inputs.
+  """
+  return [t for t in _Inputs(op, xs) if not isinstance(t, ops.EagerTensor)]
+
+
 # TODO(skyewm): plumbing xs through everywhere is ugly, consider making
 # _GradientsHelper a class with xs as a member variable.
-def _NonEagerInputs(op, xs):
+def _Inputs(op, xs):
   """Returns the inputs of op, crossing closure boundaries where necessary.
 
   Args:
     op: Operation
     xs: list of Tensors we are differentiating w.r.t.
@@ -504,8 +518,6 @@ def _NonEagerInputs(op, xs):
       # direct input to op.
       if t not in xs:
         t = _MaybeCaptured(t)
-      # Skip captured eager inputs.
-      if isinstance(t, ops.EagerTensor): continue
      inputs.append(t)
    return inputs
  else:
@@ -736,9 +748,10 @@ def _GradientsHelper(ys,
         else:
           # If no grad_fn is defined or none of out_grads is available,
           # just propagate a list of None backwards.
-          in_grads = [None] * len(_NonEagerInputs(op, xs))
-        for i, (t_in, in_grad) in enumerate(zip(_NonEagerInputs(op, xs),
-                                                in_grads)):
+          in_grads = [None] * len(_Inputs(op, xs))
+        # Note: we don't filter out eager inputs here because the inputs need to
+        # line up with in_grads.
+        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
           if in_grad is not None:
             if (isinstance(in_grad, ops.Tensor) and
                 t_in.dtype != dtypes.resource):
@@ -751,6 +764,7 @@ def _GradientsHelper(ys,
                     "Original input shape: %s. "
                     "Calculated input gradient shape: %s" %
                     (op.name, i, t_in.shape, in_grad.shape))
-            _SetGrad(grads, t_in, in_grad)
+            if not isinstance(t_in, ops.EagerTensor):
+              _SetGrad(grads, t_in, in_grad)
         if loop_state:
           loop_state.ExitGradWhileContext(op, before=False)