Expand CompositeTensors in GradientTape.watch.
For built-in and user-defined CompositeTensors, this makes it possible to watch the composite directly, without manually picking out its component tensors to watch. PiperOrigin-RevId: 313211503 Change-Id: I16a3fa178aa39a4e06d9b35e9fe40f06b10adcac
This commit is contained in:
parent
6aece71ebf
commit
51504ec873
@ -882,7 +882,7 @@ class GradientTape(object):
|
||||
Raises:
|
||||
ValueError: if it encounters something that is not a tensor.
|
||||
"""
|
||||
for t in nest.flatten(tensor):
|
||||
for t in nest.flatten(tensor, expand_composites=True):
|
||||
if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)):
|
||||
raise ValueError("Passed in object of type {}, not tf.Tensor".format(
|
||||
type(t)))
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -48,6 +49,7 @@ from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training import training
|
||||
|
||||
@ -1484,6 +1486,19 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
|
||||
with self.assertRaisesRegexp(ValueError, 'ndarray'):
|
||||
g.watch(np.array(1.))
|
||||
|
||||
def testWatchComposite(self):
  """Watching a SparseTensor expands it and watches its component Tensors."""
  with backprop.GradientTape() as tape:
    # Component values the gradient will be taken with respect to.
    component_values = constant_op.constant([1.0, 2.0], dtypes.float32)
    sp_input = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]],
        values=component_values,
        dense_shape=[3, 4])
    # Watching the composite should watch `component_values` transitively.
    tape.watch(sp_input)
    total = sparse_ops.sparse_reduce_sum_v2(sp_input)
  grad = tape.gradient(total, component_values)
  # d(sum)/d(values) is 1 for every component value.
  self.assertAllEqual(grad, [1.0, 1.0])
|
||||
|
||||
def testWatchedVariablesAfterNonPersistentGradientCall(self):
|
||||
with backprop.GradientTape(persistent=False) as tape:
|
||||
x = resource_variable_ops.ResourceVariable(1.0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user