Expand CompositeTensors in GradientTape.watch.

For built-in and user-defined CompositeTensors this is useful to be able to watch the composite without having to manually pick specific tensors within it to watch.

PiperOrigin-RevId: 313211503
Change-Id: I16a3fa178aa39a4e06d9b35e9fe40f06b10adcac
This commit is contained in:
RJ Skerry-Ryan 2020-05-26 09:57:48 -07:00 committed by TensorFlower Gardener
parent 6aece71ebf
commit 51504ec873
2 changed files with 16 additions and 1 deletions

View File

@ -882,7 +882,7 @@ class GradientTape(object):
Raises:
ValueError: if it encounters something that is not a tensor.
"""
for t in nest.flatten(tensor):
for t in nest.flatten(tensor, expand_composites=True):
if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)):
raise ValueError("Passed in object of type {}, not tf.Tensor".format(
type(t)))

View File

@ -32,6 +32,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
@ -48,6 +49,7 @@ from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import training
@ -1484,6 +1486,19 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp(ValueError, 'ndarray'):
g.watch(np.array(1.))
def testWatchComposite(self):
"""Test that tape.watch expands composites and watches component Tensors."""
with backprop.GradientTape() as t:
values = constant_op.constant([1.0, 2.0], dtypes.float32)
s = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]],
values=values,
dense_shape=[3, 4])
t.watch(s)
z = sparse_ops.sparse_reduce_sum_v2(s)
result = t.gradient(z, values)
self.assertAllEqual(result, [1.0, 1.0])
def testWatchedVariablesAfterNonPersistentGradientCall(self):
with backprop.GradientTape(persistent=False) as tape:
x = resource_variable_ops.ResourceVariable(1.0)