Convert Python Variable objects to tensors in custom_gradient, which allows nested custom_gradient functions. This allows a custom_gradient wrapped function to call through to another custom_gradient wrapped function.
PiperOrigin-RevId: 237295007
This commit is contained in:
parent
509b632a8a
commit
7dd20b844c
@@ -238,6 +238,9 @@ def _graph_mode_decorator(f, *args, **kwargs):
|
|||||||
original_tensors = all_tensors
|
original_tensors = all_tensors
|
||||||
with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
|
with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
|
||||||
all_tensors = array_ops.identity_n(all_tensors)
|
all_tensors = array_ops.identity_n(all_tensors)
|
||||||
|
|
||||||
|
original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]
|
||||||
|
|
||||||
# Propagate handle data for happier shape inference for resource variables.
|
# Propagate handle data for happier shape inference for resource variables.
|
||||||
for i, t in enumerate(original_tensors):
|
for i, t in enumerate(original_tensors):
|
||||||
if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
|
if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
|
||||||
|
@@ -1033,6 +1033,42 @@ class CustomGradientTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllEqual(g.eval(), [2.0])
|
self.assertAllEqual(g.eval(), [2.0])
|
||||||
self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])
|
self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])
|
||||||
|
|
||||||
|
def testRecursiveCustomGradient(self):
|
||||||
|
@custom_gradient.custom_gradient
|
||||||
|
def F(x):
|
||||||
|
out = core_layers.dense(x, 3, use_bias=False)
|
||||||
|
|
||||||
|
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
|
||||||
|
self.assertEqual(1, len(variables))
|
||||||
|
grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
|
||||||
|
return grads[0], [array_ops.ones((4, 3))]
|
||||||
|
|
||||||
|
return out, Grad
|
||||||
|
|
||||||
|
@custom_gradient.custom_gradient
|
||||||
|
def DoubleF(x):
|
||||||
|
out = F(x)
|
||||||
|
|
||||||
|
def Grad(out_grad, variables=None): # pylint: disable=redefined-outer-name
|
||||||
|
self.assertEqual(1, len(variables))
|
||||||
|
grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
|
||||||
|
return grads[0], [array_ops.ones((4, 3))]
|
||||||
|
|
||||||
|
return out, Grad
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
x = array_ops.ones((2, 4))
|
||||||
|
with variable_scope.variable_scope("f", use_resource=True) as vs:
|
||||||
|
y = DoubleF(x)
|
||||||
|
all_vars = vs.global_variables()
|
||||||
|
assert len(all_vars) == 1
|
||||||
|
grads = gradients.gradients(y, [x, all_vars[0]])
|
||||||
|
for g in grads:
|
||||||
|
self.assertIsNotNone(g)
|
||||||
|
with session.Session() as sess:
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
dw = sess.run(math_ops.reduce_sum(grads[1]))
|
||||||
|
self.assertEqual(12., dw)
|
||||||
|
|
||||||
|
|
||||||
class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
|
class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user