diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 982ead7e945..20c03178ea1 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -2595,6 +2595,22 @@ class ControlFlowTest(test.TestCase):
       self.evaluate(variables.global_variables_initializer())
       self.assertAllClose(216.0, g[0])
 
+  def testWhileGrad_EagerResourceVariable(self):
+    with context.eager_mode():
+      a = resource_variable_ops.ResourceVariable(
+          np.ones([2, 2], dtype=np.float32))
+      v = constant_op.constant(1.0)
+
+      @eager_function.defun
+      def fn():
+        r = control_flow_ops.while_loop(
+            lambda i, _: i < 2,
+            lambda i, x: (i + 1, x * math_ops.reduce_sum(a) * v),
+            [0, 1.0])[1]
+        return gradients_impl.gradients(r, [v])[0]
+
+      self.assertEqual(self.evaluate(fn()), 32.)
+
   def testWhileGrad_ResourceVarInFunctionCall(self):
 
     @def_function.function
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 9caffa3ea8e..9b24473f57a 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -587,11 +587,12 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
     def Foo():
       x = constant_op.constant(10.0, name="x")
       y = math_ops.multiply(x, c, name="y")
-      z = math_ops.multiply(y, 3.0, name="z")
+      # Regression test for b/122564611.
+      z = math_ops.multiply(c, y, name="z")
       g = gradients_impl.gradients(z, x)
       return g[0]
 
-    self.assertEqual(Foo().numpy(), 6.0)
+    self.assertEqual(Foo().numpy(), 4.0)
 
 
 class StopGradientTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py
index 64c199ad29d..af46101726d 100644
--- a/tensorflow/python/ops/gradients_util.py
+++ b/tensorflow/python/ops/gradients_util.py
@@ -478,14 +478,28 @@ def _MaybeCaptured(t):
   return t
 
 
-# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
-# _GradientsHelper a class with xs as a member variable.
 def _NonEagerInputs(op, xs):
   """Returns the inputs of op, crossing closure boundaries where necessary.
 
   Does not return any captured EagerTensors, i.e., the number of tensors
   returned may be less than than the actual number of inputs.
 
+  Args:
+    op: Operation
+    xs: list of Tensors we are differentiating w.r.t.
+
+  Returns:
+    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
+    is in a FuncGraph and has captured inputs.
+  """
+  return [t for t in _Inputs(op, xs) if not isinstance(t, ops.EagerTensor)]
+
+
+# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
+# _GradientsHelper a class with xs as a member variable.
+def _Inputs(op, xs):
+  """Returns the inputs of op, crossing closure boundaries where necessary.
+
   Args:
     op: Operation
     xs: list of Tensors we are differentiating w.r.t.
@@ -504,8 +518,6 @@ def _NonEagerInputs(op, xs):
       # direct input to op.
       if t not in xs:
         t = _MaybeCaptured(t)
-      # Skip captured eager inputs.
-      if isinstance(t, ops.EagerTensor): continue
      inputs.append(t)
    return inputs
  else:
@@ -736,9 +748,10 @@ def _GradientsHelper(ys,
          else:
            # If no grad_fn is defined or none of out_grads is available,
            # just propagate a list of None backwards.
-            in_grads = [None] * len(_NonEagerInputs(op, xs))
-          for i, (t_in, in_grad) in enumerate(zip(_NonEagerInputs(op, xs),
-                                                  in_grads)):
+            in_grads = [None] * len(_Inputs(op, xs))
+          # Note: we don't filter out eager inputs here because the inputs need to
+          # line up with in_grads.
+          for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
            if in_grad is not None:
              if (isinstance(in_grad, ops.Tensor) and
                  t_in.dtype != dtypes.resource):
@@ -751,7 +764,8 @@ def _GradientsHelper(ys,
                      "Original input shape: %s. "
                      "Calculated input gradient shape: %s" %
                      (op.name, i, t_in.shape, in_grad.shape))
-              _SetGrad(grads, t_in, in_grad)
+              if not isinstance(t_in, ops.EagerTensor):
+                _SetGrad(grads, t_in, in_grad)
          if loop_state:
            loop_state.ExitGradWhileContext(op, before=False)
 
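Not part of the patch above: a short, self-contained sanity check of the two updated test expectations (32. and 4.0), written against the public tf.GradientTape API rather than the internal defun/gradients_impl helpers the tests use. The value c = 2.0 is an assumption inferred from the assertions (old expectation 3*c == 6.0, new expectation c*c == 4.0); everything else mirrors the values visible in the diff.

import tensorflow as tf

# testWhileGrad_EagerResourceVariable: two iterations of
# x = x * reduce_sum(ones([2, 2])) * v give r = (4*v)**2, so dr/dv = 32*v.
a = tf.ones([2, 2])            # stands in for the ResourceVariable of ones
v = tf.constant(1.0)
with tf.GradientTape() as tape:
  tape.watch(v)                # v is a constant, so it must be watched explicitly
  x = tf.constant(1.0)
  for _ in range(2):           # unrolled stand-in for the while_loop
    x = x * tf.reduce_sum(a) * v
print(tape.gradient(x, v).numpy())  # 32.0

# Updated Foo() in gradients_test.py: with c == 2.0 (assumed), z = c * (x * c)
# = c**2 * x, so dz/dx = c**2 = 4.0 (the old z = 3 * y gave dz/dx = 3*c = 6.0).
c = tf.constant(2.0)
x = tf.constant(10.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = x * c
  z = c * y                    # c feeds z both directly and through y
print(tape.gradient(z, x).numpy())  # 4.0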