From 51504ec873c6d670253e106e325fd8ba965dcf0c Mon Sep 17 00:00:00 2001 From: RJ Skerry-Ryan Date: Tue, 26 May 2020 09:57:48 -0700 Subject: [PATCH] Expand CompositeTensors in GradientTape.watch. For built-in and user-defined CompositeTensors this is useful to be able to watch the composite without having to manually pick specific tensors within it to watch. PiperOrigin-RevId: 313211503 Change-Id: I16a3fa178aa39a4e06d9b35e9fe40f06b10adcac --- tensorflow/python/eager/backprop.py | 2 +- tensorflow/python/eager/backprop_test.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 7a3dce7db4e..dc7bb7c4b11 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -882,7 +882,7 @@ class GradientTape(object): Raises: ValueError: if it encounters something that is not a tensor. """ - for t in nest.flatten(tensor): + for t in nest.flatten(tensor, expand_composites=True): if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)): raise ValueError("Passed in object of type {}, not tf.Tensor".format( type(t))) diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index b28aaa3a626..a0f98fc0a44 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util @@ -48,6 +49,7 @@ from tensorflow.python.ops import nn_grad # pylint: disable=unused-import from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variables from tensorflow.python.training import training @@ -1484,6 +1486,19 @@ class BackpropTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegexp(ValueError, 'ndarray'): g.watch(np.array(1.)) + def testWatchComposite(self): + """Test that tape.watch expands composites and watches component Tensors.""" + with backprop.GradientTape() as t: + values = constant_op.constant([1.0, 2.0], dtypes.float32) + s = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2]], + values=values, + dense_shape=[3, 4]) + t.watch(s) + z = sparse_ops.sparse_reduce_sum_v2(s) + result = t.gradient(z, values) + self.assertAllEqual(result, [1.0, 1.0]) + def testWatchedVariablesAfterNonPersistentGradientCall(self): with backprop.GradientTape(persistent=False) as tape: x = resource_variable_ops.ResourceVariable(1.0)