diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 93f43ab338a..15e447efede 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3920,6 +3920,7 @@ cuda_py_test(
         ":framework_test_lib",
         ":functional_ops",
         ":gradients",
+        ":init_ops",
         ":list_ops",
         ":math_grad",
         ":math_ops",
@@ -3927,6 +3928,7 @@ cuda_py_test(
         ":nn_ops",
         ":platform_test",
         ":state_grad",
+        ":state_ops",
         ":tensor_array_grad",
         ":tensor_array_ops",
         ":test_ops",
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index e2bbdc7f788..12b4feb68e5 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -74,7 +74,7 @@ def copy_handle_data(source_t, target_t):
     shapes, types = zip(*[(pair.shape, pair.dtype)
                           for pair in handle_data.shape_and_type])
     ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
-    shapes = [[d.size for d in s.dim]
+    shapes = [[d.size for d in s.dim]  # pylint: disable=g-complex-comprehension
               if not s.unknown_rank else None for s in shapes]
     pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
         target_t._op._graph._c_graph,  # pylint: disable=protected-access
@@ -394,3 +394,56 @@ def recompute_grad(f):
       return result, grad
 
   return inner
+
+
+@tf_export("grad_pass_through")
+def grad_pass_through(f):
+  """Creates a grad-pass-through op with the forward behavior provided in f.
+
+  Use this function to wrap any op, maintaining its behavior in the forward
+  pass, but replacing the original op in the backward graph with an identity.
+  For example:
+
+  ```python
+  x = tf.Variable(1.0, name="x")
+  z = tf.Variable(3.0, name="z")
+
+  with tf.GradientTape() as tape:
+    # y will evaluate to 9.0
+    y = tf.grad_pass_through(x.assign)(z**2)
+  # grads will evaluate to 6.0
+  grads = tape.gradient(y, z)
+  ```
+
+  Another example is a 'differentiable' moving average approximation, where
+  gradients are allowed to flow into the last value fed to the moving average,
+  but the moving average is still used for the forward pass:
+
+  ```python
+  x = ...  # Some scalar value
+  # A moving average object, we don't need to know how this is implemented
+  moving_average = MovingAverage()
+  with backprop.GradientTape() as tape:
+    # mavg_x will evaluate to the current running average value
+    mavg_x = tf.grad_pass_through(moving_average)(x)
+  grads = tape.gradient(mavg_x, x)  # grads will evaluate to 1.0
+  ```
+
+  Args:
+    f: function `f(*x)` that returns a `Tensor` or nested structure of `Tensor`
+      outputs.
+
+  Returns:
+    A function `h(x)` which returns the same values as `f(x)` and whose
+    gradients are the same as those of an identity function.
+  """
+  @custom_gradient
+  def _grad_pass_through_op(*args, **kwargs):
+    def grad(*args, **kwargs):
+      variables = kwargs.get("variables")
+      if variables is not None:
+        # Variables involved in the wrapped op will not receive gradients.
+        return args, [None] * len(variables)
+      return args
+    return f(*args, **kwargs), grad
+  return tf_decorator.make_decorator(f, _grad_pass_through_op)
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 1b6dcd35bf3..be98f2a6279 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -44,12 +44,14 @@ from tensorflow.python.ops import data_flow_ops  # pylint: disable=unused-import
 from tensorflow.python.ops import functional_ops  # pylint: disable=unused-import
 from tensorflow.python.ops import gradients
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_grad  # pylint: disable=unused-import
+from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import unconnected_gradients
@@ -1389,5 +1391,64 @@ class VariablesGradientTest(test_util.TensorFlowTestCase):
     self.assertAllClose(g, g_re)
 
 
+class GradPassThroughTest(test_util.TensorFlowTestCase):
+
+  @test_util.run_v1_only("b/120545219")
+  def test_gradients_v1(self):
+    x = variable_scope.get_variable(
+        name="x", shape=(), initializer=init_ops.constant_initializer(1.0),
+        use_resource=True)
+    z = variable_scope.get_variable(
+        name="z", shape=(), initializer=init_ops.constant_initializer(3.0),
+        use_resource=True)
+
+    # Verify that assign op is not differentiable
+    y = state_ops.assign(x, z**2)
+    grads = gradients.gradients(y, z)
+    self.assertIsNone(grads[0])
+
+    # Verify that when the (non differentiable) assign op is wrapped with
+    # grad_pass_through, gradients are correctly forwarded to the inputs.
+    # Form an input as quadratic function of variable z and check that the
+    # gradient of output wrt to z is correct.
+    y = custom_gradient.grad_pass_through(
+        lambda v: state_ops.assign(x, v))(z**2)
+    grads = gradients.gradients(y, z)
+    with self.cached_session() as sess:
+      sess.run(variables.global_variables_initializer())
+      self.assertAllClose(grads[0].eval(), 6.0)
+
+    # Verify that variables involved in the wrapped op do not receive gradients.
+    y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
+    grads = gradients.gradients(y, x)
+    self.assertIsNone(grads[0])
+
+  @test_util.run_v2_only
+  def test_gradients_v2(self):
+    x = variables.Variable(1.0, name="x")
+    z = variables.Variable(3.0, name="z")
+
+    # Verify that assign op is not differentiable
+    with backprop.GradientTape() as tape:
+      y = x.assign(z**2)
+    grads = tape.gradient(y, z)
+    self.assertIsNone(grads)
+
+    # Verify that when the (non differentiable) assign op is wrapped with
+    # grad_pass_through, gradients are correctly forwarded to the inputs.
+    # Form an input as quadratic function of variable z and check that the
+    # gradient of output wrt to z is correct.
+    with backprop.GradientTape() as tape:
+      y = custom_gradient.grad_pass_through(x.assign)(z**2)
+    grads = tape.gradient(y, z)
+    self.assertAllClose(grads, 6.0)
+
+    # Verify that variables involved in the wrapped op do not receive gradients.
+    with backprop.GradientTape() as tape:
+      y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
+    grads = tape.gradient(y, x)
+    self.assertIsNone(grads)
+
+
 if __name__ == "__main__":
   googletest.main()
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 50b3b399a9f..178daad4a2a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1388,6 +1388,10 @@ tf_module {
   member_method {
     name: "global_variables_initializer"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "grad_pass_through"
+    argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "gradients"
     argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 6d15fa0c841..33c4610d97b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -664,6 +664,10 @@ tf_module {
   member_method {
     name: "get_static_value"
    argspec: "args=[\'tensor\', \'partial\'], varargs=None, keywords=None, defaults=[\'False\'], "
   }
+  member_method {
+    name: "grad_pass_through"
+    argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "gradients"
     argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "
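
Not part of the patch above: a minimal end-to-end usage sketch of the `tf.grad_pass_through` API this change adds, assuming TF 2.x eager execution. The `MovingAverage` class here is a hypothetical stand-in for the object the docstring example deliberately leaves unimplemented; the expected gradient value follows the docstring and test (d(z**2)/dz = 6.0 at z = 3.0).

```python
import tensorflow as tf


class MovingAverage(object):
  """Toy exponential moving average; illustrative only."""

  def __init__(self, decay=0.9):
    self._decay = decay
    self._average = tf.Variable(0.0, trainable=False)

  def __call__(self, x):
    # Update the running average in place and return the new value; the
    # underlying assign op is not differentiable on its own.
    return self._average.assign(
        self._decay * self._average + (1.0 - self._decay) * x)


z = tf.Variable(3.0, name="z")
moving_average = MovingAverage(decay=0.9)

with tf.GradientTape() as tape:
  # Forward pass returns the updated running average of z**2 ...
  mavg = tf.grad_pass_through(moving_average)(z ** 2)

# ... while gradients flow as if the wrapped op were an identity, so the
# gradient w.r.t. z is d(z**2)/dz = 2 * 3.0 = 6.0.
print(tape.gradient(mavg, z))  # -> 6.0
```

The moving-average variable inside `MovingAverage` receives no gradient, matching the `[None] * len(variables)` behavior in `_grad_pass_through_op` above.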