Convert Python Variable objects to tensors in custom_gradient, enabling nested custom_gradient functions: a custom_gradient-wrapped function can now call through to another custom_gradient-wrapped function.

PiperOrigin-RevId: 237295007
A. Unique TensorFlower 2019-03-07 12:07:16 -08:00 committed by TensorFlower Gardener
parent 509b632a8a
commit 7dd20b844c
2 changed files with 39 additions and 0 deletions
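
As a rough illustration of the pattern this commit enables, the sketch below nests one custom_gradient function inside another while a captured variable participates in the gradient. It is written against the eager-style tf.custom_gradient API rather than the graph-mode decorator modified here, and the names inner, outer, and w are hypothetical, not part of this change; the graph-mode test added below exercises the same pattern with resource variables.

import tensorflow as tf

w = tf.Variable(2.0)

@tf.custom_gradient
def inner(x):
  y = x * w
  def grad(dy, variables=None):  # `variables` receives the captured Variable(s)
    return dy * w, [dy * x]      # gradient w.r.t. x, and w.r.t. each captured variable
  return y, grad

@tf.custom_gradient
def outer(x):
  y = inner(x)  # a custom_gradient function calling through to another one
  def grad(dy, variables=None):
    return dy * w, [dy * x]
  return y, grad

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = outer(x)
print(tape.gradient(y, [x, w]))  # expected: dy/dx = w = 2.0, dy/dw = x = 3.0
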


@@ -238,6 +238,9 @@ def _graph_mode_decorator(f, *args, **kwargs):
  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)
  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]
  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
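
For context on the new convert_to_tensor line above: calling convert_to_tensor on a Variable yields an ordinary Tensor holding the variable's value, so the rest of the decorator can treat captured Variables and plain Tensors uniformly. A minimal sketch of that behavior (assuming TF 2.x eager execution, not the graph-mode path in this file):

import tensorflow as tf

v = tf.Variable([[1.0, 2.0], [3.0, 4.0]])
t = tf.convert_to_tensor(v)  # reads the variable and returns a plain Tensor

print(isinstance(v, tf.Variable))  # True
print(isinstance(t, tf.Tensor))    # True: downstream ops can treat it like any tensor
print(t.numpy())                   # [[1. 2.] [3. 4.]]
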


@@ -1033,6 +1033,42 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
      self.assertAllEqual(g.eval(), [2.0])
      self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])

  def testRecursiveCustomGradient(self):

    @custom_gradient.custom_gradient
    def F(x):
      out = core_layers.dense(x, 3, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad

    @custom_gradient.custom_gradient
    def DoubleF(x):
      out = F(x)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad

    with ops.Graph().as_default():
      x = array_ops.ones((2, 4))
      with variable_scope.variable_scope("f", use_resource=True) as vs:
        y = DoubleF(x)
        all_vars = vs.global_variables()
        assert len(all_vars) == 1
      grads = gradients.gradients(y, [x, all_vars[0]])
      for g in grads:
        self.assertIsNotNone(g)
      with session.Session() as sess:
        self.evaluate(variables.global_variables_initializer())
        dw = sess.run(math_ops.reduce_sum(grads[1]))
        self.assertEqual(12., dw)

class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):