From cbe94fd66eb2b77932a8d2f8ab04cdcac6ba7645 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 12 Jul 2019 02:44:28 -0700 Subject: [PATCH] Add grad_pass_through utility to custom_gradients. grad_pass_through will let user wrap any op (including non differentiable ops such as assign operations) and produce a new op where gradients can 'pass through' to the inputs. An example use case is defining an approximate 'differentiable' moving avarage, where in the forward pass we use the current value of the moving average, and in the backward pass we only use last value fed to the op. PiperOrigin-RevId: 257770003 --- tensorflow/python/BUILD | 2 + tensorflow/python/ops/custom_gradient.py | 55 ++++++++++++++++- tensorflow/python/ops/gradients_test.py | 61 +++++++++++++++++++ .../tools/api/golden/v1/tensorflow.pbtxt | 4 ++ .../tools/api/golden/v2/tensorflow.pbtxt | 4 ++ 5 files changed, 125 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 93f43ab338a..15e447efede 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3920,6 +3920,7 @@ cuda_py_test( ":framework_test_lib", ":functional_ops", ":gradients", + ":init_ops", ":list_ops", ":math_grad", ":math_ops", @@ -3927,6 +3928,7 @@ cuda_py_test( ":nn_ops", ":platform_test", ":state_grad", + ":state_ops", ":tensor_array_grad", ":tensor_array_ops", ":test_ops", diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py index e2bbdc7f788..12b4feb68e5 100644 --- a/tensorflow/python/ops/custom_gradient.py +++ b/tensorflow/python/ops/custom_gradient.py @@ -74,7 +74,7 @@ def copy_handle_data(source_t, target_t): shapes, types = zip(*[(pair.shape, pair.dtype) for pair in handle_data.shape_and_type]) ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] - shapes = [[d.size for d in s.dim] + shapes = [[d.size for d in s.dim] # pylint: disable=g-complex-comprehension if not s.unknown_rank else None for s in shapes] pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper( target_t._op._graph._c_graph, # pylint: disable=protected-access @@ -394,3 +394,56 @@ def recompute_grad(f): return result, grad return inner + + +@tf_export("grad_pass_through") +def grad_pass_through(f): + """Creates a grad-pass-through op with the forward behavior provided in f. + + Use this function to wrap any op, maintaining its behavior in the forward + pass, but replacing the original op in the backward graph with an identity. + For example: + + ```python + x = tf.Variable(1.0, name="x") + z = tf.Variable(3.0, name="z") + + with tf.GradientTape() as tape: + # y will evaluate to 9.0 + y = tf.grad_pass_through(x.assign)(z**2) + # grads will evaluate to 6.0 + grads = tape.gradient(y, z) + ``` + + Another example is a 'differentiable' moving average approximation, where + gradients are allowed to flow into the last value fed to the moving average, + but the moving average is still used for the forward pass: + + ```python + x = ... # Some scalar value + # A moving average object, we don't need to know how this is implemented + moving_average = MovingAverage() + with backprop.GradientTape() as tape: + # mavg_x will evaluate to the current running average value + mavg_x = tf.grad_pass_through(moving_average)(x) + grads = tape.gradient(mavg_x, x) # grads will evaluate to 1.0 + ``` + + Args: + f: function `f(*x)` that returns a `Tensor` or nested structure of `Tensor` + outputs. + + Returns: + A function `h(x)` which returns the same values as `f(x)` and whose + gradients are the same as those of an identity function. + """ + @custom_gradient + def _grad_pass_through_op(*args, **kwargs): + def grad(*args, **kwargs): + variables = kwargs.get("variables") + if variables is not None: + # Variables involved in the wrapped op will not receive gradients. + return args, [None] * len(variables) + return args + return f(*args, **kwargs), grad + return tf_decorator.make_decorator(f, _grad_pass_through_op) diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 1b6dcd35bf3..be98f2a6279 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -44,12 +44,14 @@ from tensorflow.python.ops import data_flow_ops # pylint: disable=unused-import from tensorflow.python.ops import functional_ops # pylint: disable=unused-import from tensorflow.python.ops import gradients from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import init_ops from tensorflow.python.ops import list_ops from tensorflow.python.ops import math_grad # pylint: disable=unused-import from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_grad # pylint: disable=unused-import from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_grad # pylint: disable=unused-import +from tensorflow.python.ops import state_ops from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import unconnected_gradients @@ -1389,5 +1391,64 @@ class VariablesGradientTest(test_util.TensorFlowTestCase): self.assertAllClose(g, g_re) +class GradPassThroughTest(test_util.TensorFlowTestCase): + + @test_util.run_v1_only("b/120545219") + def test_gradients_v1(self): + x = variable_scope.get_variable( + name="x", shape=(), initializer=init_ops.constant_initializer(1.0), + use_resource=True) + z = variable_scope.get_variable( + name="z", shape=(), initializer=init_ops.constant_initializer(3.0), + use_resource=True) + + # Verify that assign op is not differentiable + y = state_ops.assign(x, z**2) + grads = gradients.gradients(y, z) + self.assertIsNone(grads[0]) + + # Verify that when the (non differentiable) assign op is wrapped with + # grad_pass_through, gradients are correctly forwarded to the inputs. + # Form an input as quadratic function of variable z and check that the + # gradient of output wrt to z is correct. + y = custom_gradient.grad_pass_through( + lambda v: state_ops.assign(x, v))(z**2) + grads = gradients.gradients(y, z) + with self.cached_session() as sess: + sess.run(variables.global_variables_initializer()) + self.assertAllClose(grads[0].eval(), 6.0) + + # Verify that variables involved in the wrapped op do not receive gradients. + y = custom_gradient.grad_pass_through(lambda v: x * v)(z) + grads = gradients.gradients(y, x) + self.assertIsNone(grads[0]) + + @test_util.run_v2_only + def test_gradients_v2(self): + x = variables.Variable(1.0, name="x") + z = variables.Variable(3.0, name="z") + + # Verify that assign op is not differentiable + with backprop.GradientTape() as tape: + y = x.assign(z**2) + grads = tape.gradient(y, z) + self.assertIsNone(grads) + + # Verify that when the (non differentiable) assign op is wrapped with + # grad_pass_through, gradients are correctly forwarded to the inputs. + # Form an input as quadratic function of variable z and check that the + # gradient of output wrt to z is correct. + with backprop.GradientTape() as tape: + y = custom_gradient.grad_pass_through(x.assign)(z**2) + grads = tape.gradient(y, z) + self.assertAllClose(grads, 6.0) + + # Verify that variables involved in the wrapped op do not receive gradients. + with backprop.GradientTape() as tape: + y = custom_gradient.grad_pass_through(lambda v: x * v)(z) + grads = tape.gradient(y, x) + self.assertIsNone(grads) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 50b3b399a9f..178daad4a2a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1388,6 +1388,10 @@ tf_module { name: "global_variables_initializer" argspec: "args=[], varargs=None, keywords=None, defaults=None" } + member_method { + name: "grad_pass_through" + argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "gradients" argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 6d15fa0c841..33c4610d97b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -664,6 +664,10 @@ tf_module { name: "get_static_value" argspec: "args=[\'tensor\', \'partial\'], varargs=None, keywords=None, defaults=[\'False\'], " } + member_method { + name: "grad_pass_through" + argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "gradients" argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "