Add grad_pass_through utility to custom_gradients.

grad_pass_through lets users wrap any op (including non-differentiable ops such as assign operations) to produce a new op whose gradients 'pass through' to the inputs.
An example use case is defining an approximate 'differentiable' moving average, where the forward pass uses the current value of the moving average and the backward pass uses only the last value fed to the op.
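
A minimal sketch of that moving-average use case (the `SimpleEMA` class, its decay constant, and the variable names below are illustrative, not part of this change):

```python
import tensorflow as tf


class SimpleEMA(object):
  """Illustrative exponential moving average stored in a tf.Variable."""

  def __init__(self, decay=0.99):
    self._decay = decay
    self._average = tf.Variable(0.0)

  def __call__(self, value):
    # assign() is not differentiable; wrapping the update in
    # tf.grad_pass_through keeps the updated average as the forward value
    # while letting gradients flow straight through to `value`.
    return tf.grad_pass_through(
        lambda v: self._average.assign(
            self._decay * self._average + (1.0 - self._decay) * v))(value)


ema = SimpleEMA()
x = tf.Variable(2.0)
with tf.GradientTape() as tape:
  y = ema(x ** 2)  # forward value: the updated running average
# The gradient behaves as if y were x**2 directly, i.e. 2 * x = 4.0.
print(tape.gradient(y, x))
```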

PiperOrigin-RevId: 257770003
A. Unique TensorFlower 2019-07-12 02:44:28 -07:00 committed by TensorFlower Gardener
parent 6bcff21ad8
commit cbe94fd66e
5 changed files with 125 additions and 1 deletion

View File

@@ -3920,6 +3920,7 @@ cuda_py_test(
":framework_test_lib",
":functional_ops",
":gradients",
":init_ops",
":list_ops",
":math_grad",
":math_ops",
@@ -3927,6 +3928,7 @@ cuda_py_test(
":nn_ops",
":platform_test",
":state_grad",
":state_ops",
":tensor_array_grad",
":tensor_array_ops",
":test_ops",

View File

@@ -74,7 +74,7 @@ def copy_handle_data(source_t, target_t):
shapes, types = zip(*[(pair.shape, pair.dtype)
for pair in handle_data.shape_and_type])
ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
shapes = [[d.size for d in s.dim]
shapes = [[d.size for d in s.dim] # pylint: disable=g-complex-comprehension
if not s.unknown_rank else None for s in shapes]
pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
target_t._op._graph._c_graph, # pylint: disable=protected-access
@@ -394,3 +394,56 @@ def recompute_grad(f):
    return result, grad

  return inner


@tf_export("grad_pass_through")
def grad_pass_through(f):
  """Creates a grad-pass-through op with the forward behavior provided in f.

  Use this function to wrap any op, maintaining its behavior in the forward
  pass, but replacing the original op in the backward graph with an identity.
  For example:

  ```python
  x = tf.Variable(1.0, name="x")
  z = tf.Variable(3.0, name="z")

  with tf.GradientTape() as tape:
    # y will evaluate to 9.0
    y = tf.grad_pass_through(x.assign)(z**2)
  # grads will evaluate to 6.0
  grads = tape.gradient(y, z)
  ```

  Another example is a 'differentiable' moving average approximation, where
  gradients are allowed to flow into the last value fed to the moving average,
  but the moving average is still used for the forward pass:

  ```python
  x = ...  # Some scalar value
  # A moving average object, we don't need to know how this is implemented
  moving_average = MovingAverage()

  with tf.GradientTape() as tape:
    # mavg_x will evaluate to the current running average value
    mavg_x = tf.grad_pass_through(moving_average)(x)
  grads = tape.gradient(mavg_x, x)  # grads will evaluate to 1.0
  ```

  Args:
    f: function `f(*x)` that returns a `Tensor` or nested structure of `Tensor`
      outputs.

  Returns:
    A function `h(x)` which returns the same values as `f(x)` and whose
    gradients are the same as those of an identity function.
  """
  @custom_gradient
  def _grad_pass_through_op(*args, **kwargs):
    def grad(*args, **kwargs):
      variables = kwargs.get("variables")
      if variables is not None:
        # Variables involved in the wrapped op will not receive gradients.
        return args, [None] * len(variables)
      return args
    return f(*args, **kwargs), grad
  return tf_decorator.make_decorator(f, _grad_pass_through_op)

View File

@@ -44,12 +44,14 @@ from tensorflow.python.ops import data_flow_ops # pylint: disable=unused-import
from tensorflow.python.ops import functional_ops # pylint: disable=unused-import
from tensorflow.python.ops import gradients
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_grad # pylint: disable=unused-import
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import unconnected_gradients
@@ -1389,5 +1391,64 @@ class VariablesGradientTest(test_util.TensorFlowTestCase):
      self.assertAllClose(g, g_re)


class GradPassThroughTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def test_gradients_v1(self):
    x = variable_scope.get_variable(
        name="x", shape=(), initializer=init_ops.constant_initializer(1.0),
        use_resource=True)
    z = variable_scope.get_variable(
        name="z", shape=(), initializer=init_ops.constant_initializer(3.0),
        use_resource=True)

    # Verify that the assign op is not differentiable.
    y = state_ops.assign(x, z**2)
    grads = gradients.gradients(y, z)
    self.assertIsNone(grads[0])

    # Verify that when the (non-differentiable) assign op is wrapped with
    # grad_pass_through, gradients are correctly forwarded to the inputs.
    # Form an input as a quadratic function of the variable z and check that
    # the gradient of the output with respect to z is correct.
    y = custom_gradient.grad_pass_through(
        lambda v: state_ops.assign(x, v))(z**2)
    grads = gradients.gradients(y, z)
    with self.cached_session() as sess:
      sess.run(variables.global_variables_initializer())
      self.assertAllClose(grads[0].eval(), 6.0)

    # Verify that variables involved in the wrapped op do not receive
    # gradients.
    y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
    grads = gradients.gradients(y, x)
    self.assertIsNone(grads[0])

  @test_util.run_v2_only
  def test_gradients_v2(self):
    x = variables.Variable(1.0, name="x")
    z = variables.Variable(3.0, name="z")

    # Verify that the assign op is not differentiable.
    with backprop.GradientTape() as tape:
      y = x.assign(z**2)
    grads = tape.gradient(y, z)
    self.assertIsNone(grads)

    # Verify that when the (non-differentiable) assign op is wrapped with
    # grad_pass_through, gradients are correctly forwarded to the inputs.
    # Form an input as a quadratic function of the variable z and check that
    # the gradient of the output with respect to z is correct.
    with backprop.GradientTape() as tape:
      y = custom_gradient.grad_pass_through(x.assign)(z**2)
    grads = tape.gradient(y, z)
    self.assertAllClose(grads, 6.0)

    # Verify that variables involved in the wrapped op do not receive
    # gradients.
    with backprop.GradientTape() as tape:
      y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
    grads = tape.gradient(y, x)
    self.assertIsNone(grads)


if __name__ == "__main__":
  googletest.main()

View File

@@ -1388,6 +1388,10 @@ tf_module {
name: "global_variables_initializer"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "grad_pass_through"
argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "gradients"
argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "

View File

@@ -664,6 +664,10 @@ tf_module {
name: "get_static_value"
argspec: "args=[\'tensor\', \'partial\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
name: "grad_pass_through"
argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "gradients"
argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "