Fix bug in tf.gradients.
This fixes a bug where we'd incorrectly match up gradient tensors with the corresponding forward-pass tensors due to filtering out captured EagerTensors. This could lead to errors about incompatible gradient shapes or, even worse, the wrong gradient value.

PiperOrigin-RevId: 235271326
commit 135fee1685
parent 8be0d24cc9
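To make the failure mode concrete, here is a minimal plain-Python sketch (illustrative stand-ins only, not TensorFlow code) of what goes wrong when an op's inputs are filtered before being zipped against a gradient list that was built for all inputs:

# Three op inputs, one of which is a captured EagerTensor, and one
# incoming gradient per input.
op_inputs = ["x", "captured_eager_c", "y"]
in_grads = ["grad_x", "grad_c", "grad_y"]

# Filtering the inputs but not the gradients shifts every pairing after
# the filtered element by one position:
filtered = [t for t in op_inputs if t != "captured_eager_c"]
print(list(zip(filtered, in_grads)))
# [('x', 'grad_x'), ('y', 'grad_c')] -- 'y' is paired with the wrong
# gradient, and 'grad_y' is silently dropped.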
@@ -2595,6 +2595,22 @@ class ControlFlowTest(test.TestCase):
     self.evaluate(variables.global_variables_initializer())
     self.assertAllClose(216.0, g[0])
 
+  def testWhileGrad_EagerResourceVariable(self):
+    with context.eager_mode():
+      a = resource_variable_ops.ResourceVariable(
+          np.ones([2, 2], dtype=np.float32))
+      v = constant_op.constant(1.0)
+
+      @eager_function.defun
+      def fn():
+        r = control_flow_ops.while_loop(
+            lambda i, _: i < 2,
+            lambda i, x: (i + 1, x * math_ops.reduce_sum(a) * v),
+            [0, 1.0])[1]
+        return gradients_impl.gradients(r, [v])[0]
+
+      self.assertEqual(self.evaluate(fn()), 32.)
+
   def testWhileGrad_ResourceVarInFunctionCall(self):
 
     @def_function.function
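For reference, the expected gradient in the new test follows directly from the loop arithmetic: math_ops.reduce_sum(a) is 4.0 (a 2x2 matrix of ones), so each of the two iterations multiplies the loop value by 4*v, giving r = (4v)^2 = 16v^2 and hence dr/dv = 32v, which is 32 at v = 1.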
@@ -587,11 +587,12 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
     def Foo():
       x = constant_op.constant(10.0, name="x")
       y = math_ops.multiply(x, c, name="y")
-      z = math_ops.multiply(y, 3.0, name="z")
+      # Regression test for b/122564611.
+      z = math_ops.multiply(c, y, name="z")
       g = gradients_impl.gradients(z, x)
       return g[0]
 
-    self.assertEqual(Foo().numpy(), 6.0)
+    self.assertEqual(Foo().numpy(), 4.0)
 
 
 class StopGradientTest(test_util.TensorFlowTestCase):
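The updated expectation is also consistent analytically, assuming c is a constant equal to 2.0 (both the old and new expected values imply this): before, z = 3y = 3cx gave dz/dx = 3c = 6; with the regression test's z = c * y = c^2 * x, dz/dx = c^2 = 4.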
@@ -478,14 +478,28 @@ def _MaybeCaptured(t):
   return t
 
 
-# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
-# _GradientsHelper a class with xs as a member variable.
 def _NonEagerInputs(op, xs):
   """Returns the inputs of op, crossing closure boundaries where necessary.
 
   Does not return any captured EagerTensors, i.e., the number of tensors
   returned may be less than than the actual number of inputs.
 
+  Args:
+    op: Operation
+    xs: list of Tensors we are differentiating w.r.t.
+
+  Returns:
+    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
+    is in a FuncGraph and has captured inputs.
+  """
+  return [t for t in _Inputs(op, xs) if not isinstance(t, ops.EagerTensor)]
+
+
+# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
+# _GradientsHelper a class with xs as a member variable.
+def _Inputs(op, xs):
+  """Returns the inputs of op, crossing closure boundaries where necessary.
+
   Args:
     op: Operation
     xs: list of Tensors we are differentiating w.r.t.
@@ -504,8 +518,6 @@ def _NonEagerInputs(op, xs):
       # direct input to op.
       if t not in xs:
         t = _MaybeCaptured(t)
-        # Skip captured eager inputs.
-        if isinstance(t, ops.EagerTensor): continue
       inputs.append(t)
     return inputs
   else:
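Taken together, the two hunks above split the old helper in two: _Inputs now returns one entry per op input, positions preserved, while _NonEagerInputs keeps the old filtering contract for callers that walk the backprop graph and must not see captured EagerTensors. A toy restatement of the invariant in plain Python (strings standing in for tensors, with an "eager:" prefix marking captured EagerTensors):

def toy_inputs(op_inputs):
    # Like _Inputs: every input, positions preserved.
    return list(op_inputs)

def toy_non_eager_inputs(op_inputs):
    # Like _NonEagerInputs: _Inputs plus a filter.
    return [t for t in toy_inputs(op_inputs) if not t.startswith("eager:")]

print(toy_inputs(["x", "eager:c", "y"]))            # ['x', 'eager:c', 'y']
print(toy_non_eager_inputs(["x", "eager:c", "y"]))  # ['x', 'y']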
@@ -736,9 +748,10 @@ def _GradientsHelper(ys,
             else:
               # If no grad_fn is defined or none of out_grads is available,
               # just propagate a list of None backwards.
-              in_grads = [None] * len(_NonEagerInputs(op, xs))
-          for i, (t_in, in_grad) in enumerate(zip(_NonEagerInputs(op, xs),
-                                                  in_grads)):
+              in_grads = [None] * len(_Inputs(op, xs))
+          # Note: we don't filter out eager inputs here because the inputs need to
+          # line up with in_grads.
+          for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
             if in_grad is not None:
               if (isinstance(in_grad, ops.Tensor) and
                   t_in.dtype != dtypes.resource):
@@ -751,6 +764,7 @@ def _GradientsHelper(ys,
                      "Original input shape: %s. "
                      "Calculated input gradient shape: %s" %
                      (op.name, i, t_in.shape, in_grad.shape))
-              _SetGrad(grads, t_in, in_grad)
+              if not isinstance(t_in, ops.EagerTensor):
+                _SetGrad(grads, t_in, in_grad)
           if loop_state:
             loop_state.ExitGradWhileContext(op, before=False)
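The last two hunks are the fix itself: iterate over every input so that index i stays aligned with in_grads, and skip captured EagerTensors only at the point where a gradient would be recorded. A plain-Python sketch of the corrected control flow (names are stand-ins, not the TensorFlow API):

op_inputs = ["x", "captured_eager_c", "y"]  # all inputs, as from _Inputs(op, xs)
in_grads = ["grad_x", "grad_c", "grad_y"]   # one gradient per input
grads = {}
for t_in, in_grad in zip(op_inputs, in_grads):  # positions line up
    if t_in != "captured_eager_c":  # stands in for "not isinstance(t_in, ops.EagerTensor)"
        grads[t_in] = in_grad       # stands in for _SetGrad(grads, t_in, in_grad)
print(grads)  # {'x': 'grad_x', 'y': 'grad_y'} -- correct pairings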